From f7ea583082c670103fb2cebd6035b944c71d64c4 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Fri, 27 Mar 2020 21:58:51 -0400 Subject: Move ruy's code to a ruy/ subdirectory. The motivation is that having source files in the repository root runs into a number of corner cases with copybara setups and with external CMake build systems, so enclosing all code in ruy/ avoids that while generally making our setup much more similar to that of other related projects (TensorFlow, IREE). PiperOrigin-RevId: 303448881 --- ruy/BUILD | 954 ++++ ruy/allocator.cc | 51 + ruy/allocator.h | 185 + ruy/allocator_test.cc | 103 + ruy/benchmark.cc | 196 + ruy/block_map.cc | 486 ++ ruy/block_map.h | 161 + ruy/block_map_test.cc | 263 + ruy/blocking_counter.cc | 49 + ruy/blocking_counter.h | 62 + ruy/build_defs.bzl | 54 + ruy/build_defs.bzl.opensource | 40 + ruy/check_macros.h | 138 + ruy/check_macros_test.cc | 153 + ruy/common.h | 73 + ruy/context.cc | 109 + ruy/context.h | 109 + ruy/context_test.cc | 63 + ruy/cpu_cache_size.h | 81 + ruy/detect_arm.cc | 73 + ruy/detect_arm.h | 29 + ruy/detect_x86.cc | 101 + ruy/detect_x86.h | 49 + ruy/dispatch.h | 482 ++ ruy/example.cc | 136 + ruy/example_advanced.cc | 83 + ruy/have_built_path_for.h | 32 + ruy/have_built_path_for_avx2.cc | 35 + ruy/have_built_path_for_avx512.cc | 35 + ruy/have_built_path_for_avxvnni.cc | 39 + ruy/have_built_path_for_sse42.cc | 39 + ruy/internal_matrix.h | 388 ++ ruy/kernel.h | 31 + ruy/kernel_arm.h | 211 + ruy/kernel_arm32.cc | 2499 +++++++++ ruy/kernel_arm64.cc | 7835 +++++++++++++++++++++++++++++ ruy/kernel_avx2.cc | 1664 ++++++ ruy/kernel_avx512.cc | 1820 +++++++ ruy/kernel_avxvnni.cc | 435 ++ ruy/kernel_common.h | 481 ++ ruy/kernel_sse42.cc | 428 ++ ruy/kernel_x86.h | 222 + ruy/matrix.h | 182 + ruy/opt_set.h | 51 + ruy/pack.h | 98 + ruy/pack_arm.cc | 1936 +++++++ ruy/pack_arm.h | 497 ++ ruy/pack_avx2.cc | 816 +++ ruy/pack_avx512.cc | 693 +++ ruy/pack_avxvnni.cc | 478 ++ ruy/pack_common.h | 246 + ruy/pack_sse42.cc | 471 ++ ruy/pack_x86.h | 461 ++ ruy/path.h | 162 + ruy/platform.h | 156 + ruy/pmu.cc | 281 ++ ruy/pmu.h | 44 + ruy/prepack.h | 108 + ruy/prepacked_cache.cc | 82 + ruy/prepacked_cache.h | 130 + ruy/prepacked_cache_test.cc | 210 + ruy/profiler/BUILD | 52 + ruy/profiler/README.md | 149 + ruy/profiler/instrumentation.cc | 130 + ruy/profiler/instrumentation.h | 203 + ruy/profiler/profiler.cc | 109 + ruy/profiler/profiler.h | 106 + ruy/profiler/test.cc | 167 + ruy/profiler/test_instrumented_library.cc | 59 + ruy/profiler/test_instrumented_library.h | 23 + ruy/profiler/treeview.cc | 248 + ruy/profiler/treeview.h | 130 + ruy/ruy.h | 42 + ruy/ruy_advanced.h | 69 + ruy/ruy_test.bzl | 34 + ruy/ruy_test_ext.bzl | 19 + ruy/ruy_test_ext.bzl.opensource | 7 + ruy/side_pair.h | 64 + ruy/size_util.h | 93 + ruy/size_util_test.cc | 101 + ruy/spec.h | 118 + ruy/test.h | 2125 ++++++++ ruy/test_fast.cc | 110 + ruy/test_slow.cc | 71 + ruy/test_special_specs.cc | 163 + ruy/thread_pool.cc | 200 + ruy/thread_pool.h | 102 + ruy/time.h | 81 + ruy/trace.cc | 325 ++ ruy/trace.h | 73 + ruy/trmul.cc | 401 ++ ruy/trmul.h | 38 + ruy/trmul_params.h | 67 + ruy/tune.cc | 161 + ruy/tune.h | 163 + ruy/tune_test.cc | 53 + ruy/tune_tool.cc | 56 + ruy/wait.cc | 69 + ruy/wait.h | 73 + ruy/wait_test.cc | 117 + 100 files changed, 33950 insertions(+) create mode 100644 ruy/BUILD create mode 100644 ruy/allocator.cc create mode 100644 ruy/allocator.h create mode 100644 ruy/allocator_test.cc create mode 100644 ruy/benchmark.cc create mode 100644 ruy/block_map.cc create mode 100644 
ruy/block_map.h create mode 100644 ruy/block_map_test.cc create mode 100644 ruy/blocking_counter.cc create mode 100644 ruy/blocking_counter.h create mode 100644 ruy/build_defs.bzl create mode 100644 ruy/build_defs.bzl.opensource create mode 100644 ruy/check_macros.h create mode 100644 ruy/check_macros_test.cc create mode 100644 ruy/common.h create mode 100644 ruy/context.cc create mode 100644 ruy/context.h create mode 100644 ruy/context_test.cc create mode 100644 ruy/cpu_cache_size.h create mode 100644 ruy/detect_arm.cc create mode 100644 ruy/detect_arm.h create mode 100644 ruy/detect_x86.cc create mode 100644 ruy/detect_x86.h create mode 100644 ruy/dispatch.h create mode 100644 ruy/example.cc create mode 100644 ruy/example_advanced.cc create mode 100644 ruy/have_built_path_for.h create mode 100644 ruy/have_built_path_for_avx2.cc create mode 100644 ruy/have_built_path_for_avx512.cc create mode 100644 ruy/have_built_path_for_avxvnni.cc create mode 100644 ruy/have_built_path_for_sse42.cc create mode 100644 ruy/internal_matrix.h create mode 100644 ruy/kernel.h create mode 100644 ruy/kernel_arm.h create mode 100644 ruy/kernel_arm32.cc create mode 100644 ruy/kernel_arm64.cc create mode 100644 ruy/kernel_avx2.cc create mode 100644 ruy/kernel_avx512.cc create mode 100644 ruy/kernel_avxvnni.cc create mode 100644 ruy/kernel_common.h create mode 100644 ruy/kernel_sse42.cc create mode 100644 ruy/kernel_x86.h create mode 100644 ruy/matrix.h create mode 100644 ruy/opt_set.h create mode 100644 ruy/pack.h create mode 100644 ruy/pack_arm.cc create mode 100644 ruy/pack_arm.h create mode 100644 ruy/pack_avx2.cc create mode 100644 ruy/pack_avx512.cc create mode 100644 ruy/pack_avxvnni.cc create mode 100644 ruy/pack_common.h create mode 100644 ruy/pack_sse42.cc create mode 100644 ruy/pack_x86.h create mode 100644 ruy/path.h create mode 100644 ruy/platform.h create mode 100644 ruy/pmu.cc create mode 100644 ruy/pmu.h create mode 100644 ruy/prepack.h create mode 100644 ruy/prepacked_cache.cc create mode 100644 ruy/prepacked_cache.h create mode 100644 ruy/prepacked_cache_test.cc create mode 100644 ruy/profiler/BUILD create mode 100644 ruy/profiler/README.md create mode 100644 ruy/profiler/instrumentation.cc create mode 100644 ruy/profiler/instrumentation.h create mode 100644 ruy/profiler/profiler.cc create mode 100644 ruy/profiler/profiler.h create mode 100644 ruy/profiler/test.cc create mode 100644 ruy/profiler/test_instrumented_library.cc create mode 100644 ruy/profiler/test_instrumented_library.h create mode 100644 ruy/profiler/treeview.cc create mode 100644 ruy/profiler/treeview.h create mode 100644 ruy/ruy.h create mode 100644 ruy/ruy_advanced.h create mode 100644 ruy/ruy_test.bzl create mode 100644 ruy/ruy_test_ext.bzl create mode 100644 ruy/ruy_test_ext.bzl.opensource create mode 100644 ruy/side_pair.h create mode 100644 ruy/size_util.h create mode 100644 ruy/size_util_test.cc create mode 100644 ruy/spec.h create mode 100644 ruy/test.h create mode 100644 ruy/test_fast.cc create mode 100644 ruy/test_slow.cc create mode 100644 ruy/test_special_specs.cc create mode 100644 ruy/thread_pool.cc create mode 100644 ruy/thread_pool.h create mode 100644 ruy/time.h create mode 100644 ruy/trace.cc create mode 100644 ruy/trace.h create mode 100644 ruy/trmul.cc create mode 100644 ruy/trmul.h create mode 100644 ruy/trmul_params.h create mode 100644 ruy/tune.cc create mode 100644 ruy/tune.h create mode 100644 ruy/tune_test.cc create mode 100644 ruy/tune_tool.cc create mode 100644 ruy/wait.cc create mode 100644 ruy/wait.h 
create mode 100644 ruy/wait_test.cc (limited to 'ruy') diff --git a/ruy/BUILD b/ruy/BUILD new file mode 100644 index 0000000..0b19193 --- /dev/null +++ b/ruy/BUILD @@ -0,0 +1,954 @@ +# Ruy is not BLAS + +load(":build_defs.bzl", "ruy_copts_avx2", "ruy_copts_avxvnni", "ruy_copts_base", "ruy_copts_skylake", "ruy_copts_sse42") +load(":ruy_test_ext.bzl", "ruy_test_ext_defines", "ruy_test_ext_deps") +load(":ruy_test.bzl", "ruy_benchmark", "ruy_test") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +config_setting( + name = "windows", + values = {"cpu": "x64_windows"}, +) + +config_setting( + name = "armeabi-v7a", + values = {"cpu": "armeabi-v7a"}, +) + +config_setting( + name = "x86_64", + values = {"cpu": "k8"}, +) + +config_setting( + name = "optimized", + values = { + "compilation_mode": "opt", + }, + visibility = ["//visibility:public"], +) + +cc_library( + name = "platform", + hdrs = ["platform.h"], + copts = ruy_copts_base(), +) + +cc_library( + name = "check_macros", + hdrs = ["check_macros.h"], + copts = ruy_copts_base(), +) + +cc_test( + name = "check_macros_test", + srcs = ["check_macros_test.cc"], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "opt_set", + hdrs = ["opt_set.h"], + copts = ruy_copts_base(), +) + +cc_library( + name = "time", + hdrs = ["time.h"], + copts = ruy_copts_base(), +) + +cc_library( + name = "wait", + srcs = ["wait.cc"], + hdrs = ["wait.h"], + copts = ruy_copts_base(), + deps = [":time"], +) + +cc_test( + name = "wait_test", + srcs = ["wait_test.cc"], + deps = [ + ":platform", + ":wait", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "size_util", + hdrs = ["size_util.h"], + copts = ruy_copts_base(), + deps = [":check_macros"], +) + +cc_test( + name = "size_util_test", + srcs = ["size_util_test.cc"], + deps = [ + ":size_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "tune", + srcs = [ + "tune.cc", + ], + hdrs = [ + "tune.h", + ], + copts = ruy_copts_base(), + deps = [ + ":opt_set", + ":platform", + ":time", + ], +) + +cc_library( + name = "prepacked_cache", + srcs = [ + "prepacked_cache.cc", + ], + hdrs = [ + "prepacked_cache.h", + ], + copts = ruy_copts_base(), + deps = [ + ":allocator", + ":matrix", + ":opt_set", + ":platform", + ":time", + "//ruy/profiler:instrumentation", + ], +) + +cc_test( + name = "tune_test", + srcs = ["tune_test.cc"], + deps = [ + ":tune", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "prepacked_cache_test", + srcs = ["prepacked_cache_test.cc"], + deps = [ + ":prepacked_cache", + ":ruy", + ":time", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( + name = "tune_tool", + srcs = ["tune_tool.cc"], + deps = [ + ":tune", + ], +) + +cc_library( + name = "allocator", + srcs = [ + "allocator.cc", + ], + hdrs = [ + "allocator.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":size_util", + ], +) + +cc_test( + name = "allocator_test", + srcs = ["allocator_test.cc"], + deps = [ + ":allocator", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "side_pair", + hdrs = ["side_pair.h"], + copts = ruy_copts_base(), + deps = [":check_macros"], +) + +cc_library( + name = "block_map", + srcs = [ + "block_map.cc", + ], + hdrs = [ + "block_map.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":opt_set", + ":path", + ":side_pair", + ":size_util", + "//ruy/profiler:instrumentation", + ], +) + 
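The fine-grained targets above and below are internal building blocks; client code normally depends only on the top-level ":ruy" cc_library defined near the end of this BUILD file and includes "ruy/ruy.h". For orientation, a minimal C++ sketch of that public API follows, in the spirit of the example.cc added by this change. This is illustrative code, not part of the patch; names such as ruy::MakeSimpleLayout, ruy::BasicSpec, ruy::Mul and ruy::kAllPaths are assumed from the matrix.h, spec.h, path.h and ruy.h headers in this patch and should be checked against them.

#include "ruy/ruy.h"

// Multiplies two 2x2 float matrices. Layouts and data pointers are set on
// ruy::Matrix<T>; the Spec carries accumulation/output options; the Context
// holds the thread pool and CPU-detection state.
void ExampleFloatMul() {
  const float lhs_data[] = {1, 2, 3, 4};
  const float rhs_data[] = {1, 2, 3, 4};
  float dst_data[4] = {};

  ruy::Context context;

  ruy::Matrix<float> lhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
  lhs.data = lhs_data;

  ruy::Matrix<float> rhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout);
  rhs.data = rhs_data;

  ruy::Matrix<float> dst;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout);
  dst.data = dst_data;

  ruy::BasicSpec<float, float> spec;
  ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
}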
+cc_test( + name = "block_map_test", + srcs = ["block_map_test.cc"], + deps = [ + ":block_map", + ":cpu_cache_size", + ":path", + ":side_pair", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "blocking_counter", + srcs = [ + "blocking_counter.cc", + ], + hdrs = [ + "blocking_counter.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":wait", + ], +) + +cc_library( + name = "thread_pool", + srcs = [ + "thread_pool.cc", + ], + hdrs = [ + "thread_pool.h", + ], + copts = ruy_copts_base(), + deps = [ + ":blocking_counter", + ":check_macros", + ":wait", + ], +) + +cc_library( + name = "detect_arm", + srcs = [ + "detect_arm.cc", + ], + hdrs = [ + "detect_arm.h", + ], + copts = ruy_copts_base(), +) + +cc_library( + name = "detect_x86", + srcs = [ + "detect_x86.cc", + ], + hdrs = [ + "detect_x86.h", + ], + copts = ruy_copts_base(), + deps = [ + ":platform", + ], +) + +cc_library( + name = "path", + hdrs = ["path.h"], + copts = ruy_copts_base(), + deps = [ + ":platform", + ":size_util", + ], +) + +cc_library( + name = "cpu_cache_size", + hdrs = ["cpu_cache_size.h"], + copts = ruy_copts_base(), + deps = [ + ":path", + ":platform", + ], +) + +cc_library( + name = "trace", + srcs = [ + "trace.cc", + ], + hdrs = [ + "trace.h", + ], + copts = ruy_copts_base(), + deps = [ + ":block_map", + ":check_macros", + ":side_pair", + ":time", + ], +) + +cc_library( + name = "matrix", + hdrs = ["matrix.h"], + copts = ruy_copts_base(), + deps = [":check_macros"], +) + +cc_library( + name = "spec", + hdrs = ["spec.h"], + copts = ruy_copts_base(), + deps = [ + ":cpu_cache_size", + ":matrix", + ], +) + +cc_library( + name = "internal_matrix", + hdrs = ["internal_matrix.h"], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":common", + ":matrix", + ":size_util", + ], +) + +cc_library( + name = "common", + hdrs = [ + "common.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":matrix", + ":opt_set", + ":path", + ":platform", + ], +) + +cc_library( + name = "kernel_common", + hdrs = [ + "kernel.h", + "kernel_arm.h", + "kernel_common.h", + "kernel_x86.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":common", + ":internal_matrix", + ":matrix", + ":opt_set", + ":path", + ":platform", + ":side_pair", + ":size_util", + ":spec", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_common", + hdrs = [ + "pack.h", + "pack_arm.h", + "pack_common.h", + "pack_x86.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":common", + ":internal_matrix", + ":matrix", + ":opt_set", + ":path", + ":platform", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "kernel_arm", + srcs = [ + "kernel_arm32.cc", + "kernel_arm64.cc", + ], + copts = ruy_copts_base(), + deps = [ + ":common", + ":kernel_common", + ":opt_set", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_arm", + srcs = [ + "pack_arm.cc", + ], + copts = ruy_copts_base(), + deps = [ + ":common", + ":opt_set", + ":pack_common", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +# AVX-512 compilation units. +# +# These must use the same compiler options. 
+RUY_COPTS_BUILT_FOR_AVX512 = ruy_copts_base() + ruy_copts_skylake() + +cc_library( + name = "kernel_avx512", + srcs = [ + "kernel_avx512.cc", + ], + copts = RUY_COPTS_BUILT_FOR_AVX512, + deps = [ + ":check_macros", + ":kernel_common", + ":opt_set", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_avx512", + srcs = [ + "pack_avx512.cc", + ], + copts = RUY_COPTS_BUILT_FOR_AVX512, + deps = [ + ":check_macros", + ":matrix", + ":opt_set", + ":pack_common", + ":path", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "have_built_path_for_avx512", + srcs = [ + "have_built_path_for_avx512.cc", + ], + hdrs = [ + "have_built_path_for.h", + ], + copts = RUY_COPTS_BUILT_FOR_AVX512, + deps = [ + ":opt_set", + ":platform", + ], +) +# End: AVX-512 compilation units. + +# AVX2 compilation units. +# +# These must use the same compiler options. +RUY_COPTS_BUILT_FOR_AVX2 = ruy_copts_base() + ruy_copts_avx2() + +cc_library( + name = "kernel_avx2", + srcs = [ + "kernel_avx2.cc", + ], + copts = RUY_COPTS_BUILT_FOR_AVX2, + deps = [ + ":check_macros", + ":kernel_common", + ":opt_set", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_avx2", + srcs = [ + "pack_avx2.cc", + ], + copts = RUY_COPTS_BUILT_FOR_AVX2, + deps = [ + ":check_macros", + ":matrix", + ":opt_set", + ":pack_common", + ":path", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "have_built_path_for_avx2", + srcs = [ + "have_built_path_for_avx2.cc", + ], + hdrs = [ + "have_built_path_for.h", + ], + copts = RUY_COPTS_BUILT_FOR_AVX2, + deps = [ + ":opt_set", + ":platform", + ], +) +# End: AVX2 compilation units. + +# SSE42 compilation units. +# +# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +# Optimization is not finished. In particular the dimensions of the kernel +# blocks can be changed as desired. +# +# These must use the same compiler options. +RUY_COPTS_BUILT_FOR_SSE42 = ruy_copts_base() + ruy_copts_sse42() + +cc_library( + name = "kernel_sse42", + srcs = [ + "kernel_sse42.cc", + ], + copts = RUY_COPTS_BUILT_FOR_SSE42, + deps = [ + ":check_macros", + ":kernel_common", + ":opt_set", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_sse42", + srcs = [ + "pack_sse42.cc", + ], + copts = RUY_COPTS_BUILT_FOR_SSE42, + deps = [ + ":check_macros", + ":matrix", + ":opt_set", + ":pack_common", + ":path", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "have_built_path_for_sse42", + srcs = [ + "have_built_path_for_sse42.cc", + ], + hdrs = [ + "have_built_path_for.h", + ], + copts = RUY_COPTS_BUILT_FOR_SSE42, + deps = [ + ":opt_set", + ":platform", + ], +) +# End: SSE42 compilation units. + +# AVX-VNNI compilation units. +# +# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +# Optimization is not finished. In particular the dimensions of the kernel +# blocks can be changed as desired. +# +# These must use the same compiler options. 
+RUY_COPTS_BUILT_FOR_AVX_VNNI = ruy_copts_base() + ruy_copts_avxvnni() + +cc_library( + name = "kernel_avxvnni", + srcs = [ + "kernel_avxvnni.cc", + ], + copts = RUY_COPTS_BUILT_FOR_AVX_VNNI, + deps = [ + ":check_macros", + ":kernel_common", + ":opt_set", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_avxvnni", + srcs = [ + "pack_avxvnni.cc", + ], + copts = RUY_COPTS_BUILT_FOR_AVX_VNNI, + deps = [ + ":check_macros", + ":matrix", + ":opt_set", + ":pack_common", + ":path", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "have_built_path_for_avxvnni", + srcs = [ + "have_built_path_for_avxvnni.cc", + ], + hdrs = [ + "have_built_path_for.h", + ], + copts = RUY_COPTS_BUILT_FOR_AVX_VNNI, + deps = [ + ":opt_set", + ":platform", + ], +) +# End: AVX-VNNI compilation units. + +cc_library( + name = "kernel", + hdrs = [ + "kernel.h", + "kernel_common.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":common", + ":internal_matrix", + ":kernel_arm", # fixdeps: keep + ":kernel_avx2", # fixdeps: keep + ":kernel_avx512", # fixdeps: keep + ":kernel_avxvnni", # fixdeps: keep + ":kernel_common", + ":kernel_sse42", # fixdeps: keep + ":matrix", + ":opt_set", + ":path", + ":platform", + ":side_pair", + ":size_util", + ":spec", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack", + hdrs = [ + "pack.h", + "pack_common.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":common", + ":internal_matrix", + ":matrix", + ":opt_set", + ":pack_arm", # fixdeps: keep + ":pack_avx2", # fixdeps: keep + ":pack_avx512", # fixdeps: keep + ":pack_avxvnni", # fixdeps: keep + ":pack_common", + ":pack_sse42", # fixdeps: keep + ":path", + ":platform", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "have_built_path_for", + hdrs = [ + "have_built_path_for.h", + ], + deps = [ + ":have_built_path_for_avx2", + ":have_built_path_for_avx512", + ":have_built_path_for_avxvnni", + ":have_built_path_for_sse42", + ":platform", + ], +) + +cc_library( + name = "context", + srcs = [ + "context.cc", + ], + hdrs = [ + "context.h", + ], + copts = ruy_copts_base(), + deps = [ + ":allocator", + ":check_macros", + ":detect_arm", + ":detect_x86", + ":have_built_path_for", + ":path", + ":platform", + ":prepacked_cache", + ":thread_pool", + ":trace", + ":tune", + ], +) + +cc_test( + name = "context_test", + srcs = ["context_test.cc"], + deps = [ + ":context", + ":path", + ":platform", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "trmul_params", + hdrs = ["trmul_params.h"], + copts = ruy_copts_base(), + deps = [ + ":internal_matrix", + ":side_pair", + ":tune", + ], +) + +cc_library( + name = "trmul", + srcs = ["trmul.cc"], + hdrs = ["trmul.h"], + copts = ruy_copts_base(), + deps = [ + ":allocator", + ":block_map", + ":check_macros", + ":common", + ":context", + ":internal_matrix", + ":matrix", + ":opt_set", + ":side_pair", + ":size_util", + ":spec", + ":thread_pool", + ":trace", + ":trmul_params", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +# The main library. 
+cc_library( + name = "ruy", + srcs = [ + "dispatch.h", + "prepack.h", + ], + hdrs = [ + "ruy.h", + "ruy_advanced.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":common", + ":context", + ":internal_matrix", + ":kernel", + ":matrix", + ":opt_set", + ":pack", + ":path", + ":prepacked_cache", + ":side_pair", + ":size_util", + ":spec", + ":trmul", + ":trmul_params", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +# Usage examples. +cc_binary( + name = "example", + srcs = ["example.cc"], + deps = [":ruy"], +) + +# Usage examples of the advanced API. +cc_binary( + name = "example_advanced", + srcs = ["example_advanced.cc"], + deps = [":ruy"], +) + +# Small library to query PMU counters, for benchmark only +cc_library( + name = "pmu", + testonly = True, + srcs = ["pmu.cc"], + hdrs = ["pmu.h"], + copts = ruy_copts_base(), + deps = [":check_macros"], +) + +# Testing framework. +cc_library( + name = "test_lib", + testonly = True, + hdrs = ["test.h"], + copts = ruy_copts_base(), + # need defines, not copts, because it's controlling a header, test.h + defines = ruy_test_ext_defines(), + linkopts = select({ + ":windows": [], + "//conditions:default": ["-lm"], + }), + deps = [ + ":matrix", + ":pmu", + ":ruy", + ":spec", + ":time", + "@com_google_googletest//:gtest", + ":platform", + "//ruy/profiler:profiler", + ] + ruy_test_ext_deps(), +) + +ruy_benchmark( + name = "benchmark", + srcs = ["benchmark.cc"], + copts = ruy_copts_base(), + lhs_rhs_accum_dst = [ + ("f32", "f32", "f32", "f32"), + ("u8", "u8", "i32", "u8"), + ("i8", "i8", "i32", "u8"), + ("i8", "i8", "i32", "i8"), + ("u8", "u8", "i32", "i16"), + ("i8", "i8", "i32", "i32"), + ], + deps = [ + "//ruy:test_lib", + "//ruy/profiler:instrumentation", + ], +) + +ruy_test( + name = "test_fast", + srcs = ["test_fast.cc"], + copts = ruy_copts_base(), + lhs_rhs_accum_dst = [ + ("f32", "f32", "f32", "f32"), + ("f64", "f32", "f64", "f32"), + ("f32", "f64", "f64", "f64"), + ("u8", "u8", "i32", "u8"), + ("i8", "i8", "i32", "i8"), + ("i8", "u8", "i32", "i8"), + ("u8", "u8", "i32", "i16"), + ("i8", "i8", "i32", "i32"), + ("i8", "u8", "i32", "i32"), + ], + deps = [ + "@com_google_googletest//:gtest_main", + "//ruy:test_lib", + ], +) + +ruy_test( + name = "test_slow", + srcs = ["test_slow.cc"], + copts = ruy_copts_base(), + lhs_rhs_accum_dst = [ + ("f32", "f32", "f32", "f32"), + ("u8", "u8", "i32", "u8"), + ("i8", "i8", "i32", "i8"), + ("u8", "u8", "i32", "i16"), + ("i8", "i8", "i32", "i32"), + ], + tags = ["slow"], + deps = [ + "@com_google_googletest//:gtest_main", + "//ruy:test_lib", + ], +) + +ruy_test( + name = "test_special_specs", + srcs = ["test_special_specs.cc"], + copts = ruy_copts_base(), + lhs_rhs_accum_dst = [ + ("f32", "f32", "f32", "f32"), + ("u8", "u8", "i32", "u8"), + ("u8", "u8", "i32", "i16"), + ], + deps = [ + "@com_google_googletest//:gtest_main", + "//ruy:test_lib", + ], +) diff --git a/ruy/allocator.cc b/ruy/allocator.cc new file mode 100644 index 0000000..d8fb738 --- /dev/null +++ b/ruy/allocator.cc @@ -0,0 +1,51 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/allocator.h" + +#include +#include + +#ifdef _WIN32 +#include +#endif + +namespace ruy { + +namespace detail { + +void *SystemAlignedAlloc(std::ptrdiff_t num_bytes) { +#ifdef _WIN32 + return _aligned_malloc(num_bytes, kMinimumBlockAlignment); +#else + void *ptr; + if (posix_memalign(&ptr, kMinimumBlockAlignment, num_bytes)) { + return nullptr; + } + return ptr; +#endif +} + +void SystemAlignedFree(void *ptr) { +#ifdef _WIN32 + _aligned_free(ptr); +#else + free(ptr); +#endif +} + +} // namespace detail + +} // namespace ruy diff --git a/ruy/allocator.h b/ruy/allocator.h new file mode 100644 index 0000000..b0379b1 --- /dev/null +++ b/ruy/allocator.h @@ -0,0 +1,185 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_ + +#include +#include +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/size_util.h" + +namespace ruy { + +namespace detail { + +inline void* VoidPtrAdd(void* p, std::ptrdiff_t offset) { + RUY_DCHECK(p); + std::uintptr_t addr = reinterpret_cast(p) + offset; + return reinterpret_cast(addr); +} + +// Minimum alignment for blocks. +// +// Considerations: +// - This needs to be at least the alignment of any usual data type. +// - It's useful that this is at least the size of a cache line to limit +// possible cache side effects (if only on performance behavior). +// - It's useful that this is at least the size of SIMD registers, as +// some SIMD instruction sets have at least performance behavior +// differences (e.g. NEON) or even different requirements (e.g. SSE) +// based on that. +// - It's useful that this is at least the size of an "exclusive reservation +// granule" on ARM, meaning that if we use this Allocator to allocate +// an atomic variable, there will be no side effects from other things +// contending for exclusive/atomic memory accesses to it. While the +// ARM reference manual mentions that this granule size may be as large +// as 2048 bytes, in practice we observe it to be 64 bytes. It can +// be queried cheaply, at runtime, from userspace, if needed. +static constexpr std::ptrdiff_t kMinimumBlockAlignment = 64; + +// Primitive allocation functions obtaining aligned memory from the +// operating system. +void* SystemAlignedAlloc(std::ptrdiff_t num_bytes); +void SystemAlignedFree(void* ptr); + +// Specialized allocator designed to converge to a steady-state where all +// allocations are bump-ptr allocations from an already-allocated buffer. +// +// To support these constraints, this allocator only supports two +// operations. 
+// - AllocateAlignedBytes: allocates a pointer to storage of a specified +// size, which must be aligned to kMinimumBlockAlignment. +// - FreeAll: frees all previous allocations (but retains the internal +// buffer to minimize future calls into the system allocator). +// +// This class is specialized for supporting just those two operations +// under this specific steady-state usage pattern. Extending this class +// with new allocation interfaces that don't fit that pattern is probably not +// the right choice. Instead, build a new class on top of +// SystemAlignedAlloc/SystemAlignedFree. +// +// All operations happen on aligned blocks for simplicity. +class AlignedAllocator { + public: + void operator=(const AlignedAllocator&) = delete; + ~AlignedAllocator() { + FreeAll(); + SystemAlignedFree(ptr_); + } + + void* AllocateAlignedBytes(std::ptrdiff_t num_bytes) { + RUY_DCHECK_GT(num_bytes, 0); + RUY_DCHECK((num_bytes & (kMinimumBlockAlignment - 1)) == 0); + if (void* p = AllocateFast(num_bytes)) { + return p; + } + return AllocateSlow(num_bytes); + } + + void FreeAll() { + current_ = 0; + if (fallback_blocks_.empty()) { + return; + } + + // No rounding-up of the size means linear instead of logarithmic + // bound on the number of allocation in some worst-case calling patterns. + // This is considered worth it because minimizing memory usage is important + // and actual calling patterns in applications that we care about still + // reach the no-further-allocations steady state in a small finite number + // of iterations. + std::ptrdiff_t new_size = size_ + fallback_blocks_total_size_; + SystemAlignedFree(ptr_); + ptr_ = SystemAlignedAlloc(new_size); + size_ = new_size; + + for (void* p : fallback_blocks_) { + SystemAlignedFree(p); + } + fallback_blocks_.clear(); + fallback_blocks_total_size_ = 0; + } + + private: + void* AllocateFast(std::ptrdiff_t num_bytes) { + if (current_ + num_bytes > size_) { + return nullptr; + } + void* ret = VoidPtrAdd(ptr_, current_); + current_ += num_bytes; + return ret; + } + + void* AllocateSlow(std::ptrdiff_t num_bytes) { + void* p = SystemAlignedAlloc(num_bytes); + fallback_blocks_total_size_ += num_bytes; + fallback_blocks_.push_back(p); + return p; + } + + // Theory of operation: + // + // - ptr_, current_, and size_ implement a basic bump-ptr allocator. + // + // - in AllocateAlignedBytes, the fast path is just a bump-ptr + // allocation. If our bump-ptr allocator doesn't have enough space for an + // allocation, then we allocate a block from the system allocator to + // service the allocation request. We save that block in fallback_blocks_ + // and track the total size of the fallback blocks in + // fallback_blocks_total_size_. + // + // - in FreeAll, the fast path just resets the bump-ptr allocator. If + // there are any fallback blocks, we free them and reallocate the + // bump-ptr allocator's buffer so that the next sequence of allocations + // will hopefully not need any fallback blocks. + void* ptr_ = nullptr; + std::ptrdiff_t current_ = 0; + std::ptrdiff_t size_ = 0; + std::vector fallback_blocks_; + std::ptrdiff_t fallback_blocks_total_size_ = 0; +}; + +} // namespace detail + +// The main Allocator class, with a convenient interface for allocating a +// typed buffer. 
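Before the class definition that follows, a brief usage sketch of the interface it declares (AllocateBytes, Allocate<Pointer>, FreeAll). This is illustrative code, not part of the patch; the caller and buffer sizes are hypothetical, and the behavior follows the comments above and below.

#include <cstdint>

#include "ruy/allocator.h"

// Hypothetical caller showing the intended steady-state usage pattern.
void UseScratchBuffers(ruy::Allocator* allocator) {
  float* accum = nullptr;
  std::int32_t* sums = nullptr;
  allocator->Allocate(1024, &accum);  // 1024 floats from the arena.
  allocator->Allocate(256, &sums);    // 256 int32s, 64-byte aligned as well.
  // ... use accum and sums for the duration of one multiplication ...
  allocator->FreeAll();  // Invalidates both pointers at once; the arena's
                         // buffer is retained so the next round of
                         // allocations stays on the fast bump-ptr path.
}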
+class Allocator { + public: + void* AllocateBytes(std::ptrdiff_t num_bytes) { + if (num_bytes == 0) { + return nullptr; + } + return aligned.AllocateAlignedBytes( + round_up_pot(num_bytes, detail::kMinimumBlockAlignment)); + } + template + void Allocate(std::ptrdiff_t count, Pointer* out) { + using T = typename std::pointer_traits::element_type; + *out = static_cast(AllocateBytes(count * sizeof(T))); + } + + void FreeAll() { aligned.FreeAll(); } + + private: + detail::AlignedAllocator aligned; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_ diff --git a/ruy/allocator_test.cc b/ruy/allocator_test.cc new file mode 100644 index 0000000..7f46a66 --- /dev/null +++ b/ruy/allocator_test.cc @@ -0,0 +1,103 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/allocator.h" + +#include "testing/base/public/gunit.h" + +namespace ruy { +namespace { + +TEST(AllocatorTest, ReturnsValidMemory) { + Allocator allocator; + int *p; + allocator.Allocate(1, &p); + ASSERT_NE(p, nullptr); + + // If this is bogus memory, ASan will cause this test to fail. + *p = 42; + + allocator.FreeAll(); +} + +TEST(AllocatorTest, NoLeak) { + Allocator allocator; + // Allocate and free some ridiculously large total amount of memory, so + // that a leak will hopefully cause some sort of resource exhaustion. + // + // Despite the large number of allocations, this test is actually quite + // fast, since our fast-path allocation logic is very fast. + constexpr int kNumAllocations = 100 * 1024; + constexpr int kAllocationSize = 1024 * 1024; + for (int i = 0; i < kNumAllocations; i++) { + char *p; + allocator.Allocate(kAllocationSize, &p); + allocator.FreeAll(); + } +} + +TEST(AllocatorTest, IncreasingSizes) { + Allocator allocator; + // Allocate sizes that increase by small amounts across FreeAll calls. + for (int i = 1; i < 100 * 1024; i++) { + char *p; + allocator.Allocate(i, &p); + allocator.FreeAll(); + } +} + +TEST(AllocatorTest, ManySmallAllocations) { + Allocator allocator; + // Allocate many small allocations between FreeAll calls. + for (int i = 0; i < 10 * 1024; i += 100) { + for (int j = 0; j < i; j++) { + char *p; + allocator.Allocate(1, &p); + } + allocator.FreeAll(); + } +} + +TEST(AllocatorTest, DestructorHandlesMainBumpPtr) { + // This is a white-box test. + Allocator allocator; + allocator.AllocateBytes(1); + allocator.FreeAll(); + // After the call to FreeAll, the allocator will consolidate all of the memory + // into the main bump-ptr allocator's block, which we then expect to be freed + // in the destructor. + // + // We have no test assertions -- we primarily expect that this trigger a leak + // checker and cause the test to fail. +} + +TEST(AllocatorTest, DestructorHandlesFallbackBlocks) { + // This is a white-box test. 
+ Allocator allocator; + // Since we just created the allocator, this will allocate a fallback block, + // which we then expect to be freed in the destructor. + // + // We have no test assertions -- we primarily expect that this trigger a leak + // checker and cause the test to fail. + allocator.AllocateBytes(1); +} + +} // namespace +} // namespace ruy + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/benchmark.cc b/ruy/benchmark.cc new file mode 100644 index 0000000..6ce0b32 --- /dev/null +++ b/ruy/benchmark.cc @@ -0,0 +1,196 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "ruy/test.h" + +namespace ruy { + +using LhsScalar = RUY_TEST_LHSSCALAR; +using RhsScalar = RUY_TEST_RHSSCALAR; +using AccumScalar = RUY_TEST_ACCUMSCALAR; +using DstScalar = RUY_TEST_DSTSCALAR; +using TestSetType = + TestSet>; + +struct BenchmarkShape { + int rows; + int depth; + int cols; + int symm_lhs; + int symm_rhs; +}; + +template +std::vector>> BenchmarkRCC( + const BenchmarkShape& shape) { + TestSetType test_set; + test_set.rows = shape.rows; + test_set.depth = shape.depth; + test_set.cols = shape.cols; + test_set.lhs_order = Order::kRowMajor; + test_set.rhs_order = Order::kColMajor; + test_set.dst_order = Order::kColMajor; + test_set.layout_style = LayoutStyle::kPackedLinear; + test_set.benchmark = true; + const int asymmetry_lhs = shape.symm_lhs ? 0 : 1; + const int asymmetry_rhs = shape.symm_rhs ? 
0 : 1; + test_set.lhs_zero_point = SymmetricZeroPoint() + asymmetry_lhs; + test_set.rhs_zero_point = SymmetricZeroPoint() + asymmetry_rhs; + test_set.use_specified_zero_points = true; + test_set.perchannel = GetBoolEnvVarOrFalse("PERCHANNEL"); + test_set.benchmark_prepack_lhs = GetBoolEnvVarOrFalse("PREPACK_LHS"); + test_set.benchmark_prepack_rhs = GetBoolEnvVarOrFalse("PREPACK_RHS"); + test_set.Run(); + return std::move(test_set.results); +} + +std::vector ParseCommaSeparatedInts( + const std::string& comma_separated_ints) { + std::vector result; + for (std::size_t pos = 0; pos < comma_separated_ints.size();) { + std::size_t delim_pos = comma_separated_ints.find(',', pos); + if (delim_pos == std::string::npos) { + delim_pos = comma_separated_ints.size(); + } + result.push_back( + std::stoi(comma_separated_ints.substr(pos, delim_pos - pos))); + pos = delim_pos + 1; + } + return result; +} + +void Benchmark() { + const bool symm_lhs = std::is_floating_point::value || + GetBoolEnvVarOrFalse("SYMM_LHS"); + const bool symm_rhs = std::is_floating_point::value || + GetBoolEnvVarOrFalse("SYMM_RHS"); + const bool benchmark_cubic = GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC") || + GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC_LIST"); + const int explicit_rows = GetIntEnvVarOrZero("ROWS"); + const int explicit_cols = GetIntEnvVarOrZero("COLS"); + const int explicit_depth = GetIntEnvVarOrZero("DEPTH"); + + std::vector shapes; + + if (benchmark_cubic) { + std::vector sizes; + const char* benchmark_cubic_list_env = getenv("RUY_BENCHMARK_CUBIC_LIST"); + if (benchmark_cubic_list_env) { + sizes = ParseCommaSeparatedInts(benchmark_cubic_list_env); + } else { + // Often 8 is used for this multiplier, but to check teeny sizes one can + // use 1. + static constexpr int cubic_size_multiplier = 8; + for (int i = 2 * cubic_size_multiplier; + i <= (512 * cubic_size_multiplier); i *= 2) { + sizes.push_back(i); + if (i < (512 * cubic_size_multiplier)) { + sizes.push_back(i * 3 / 2); + } + } + } + for (int i : sizes) { + BenchmarkShape shape; + // Even in cubic mode, one may still override an individual dimension + // to allow testing a batch of rectangular sizes. + shape.rows = explicit_rows ? explicit_rows : i; + shape.cols = explicit_cols ? explicit_cols : i; + shape.depth = explicit_depth ? 
explicit_depth : i; + shape.symm_lhs = symm_lhs; + shape.symm_rhs = symm_rhs; + shapes.push_back(shape); + } + } else { + BenchmarkShape shape; + shape.rows = explicit_rows; + shape.cols = explicit_cols; + shape.depth = explicit_depth; + if (!shape.rows || !shape.depth || !shape.cols) { + fprintf(stderr, + "Please specify positive sizes with these env vars: ROWS, DEPTH, " + "COLS.\n"); + exit(1); + } + shape.symm_lhs = symm_lhs; + shape.symm_rhs = symm_rhs; + shapes.push_back(shape); + } + + for (int i = 0; i < shapes.size(); i++) { + const auto& shape = shapes[i]; + const auto& results = BenchmarkRCC(shape); + if (i == 0) { + if (benchmark_cubic) { + printf("size"); + for (const auto& result : results) { + if (results.size() > 1) { + printf(",%s:Gop/s", PathName(*result).c_str()); + } else { + printf(",Gop/s"); + } + if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { + printf( + ",l1_refill,l2_refill,l3_refill,l1tlb_refill,l2tlb_refill," + "mispred,frontend_stall,backend_stall"); + } + } + printf("\n"); + } else { + printf("path,shape,Gop/s\n"); + } + fflush(stdout); + } + if (benchmark_cubic) { + printf("%d", shape.rows); + for (const auto& result : results) { + printf(",%.4g", 2.0e-9 * shape.rows * shape.cols * shape.depth / + result->latency); + if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { + printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", + result->l1_refill_rate, result->l2_refill_rate, + result->l3_refill_rate, result->l1tlb_refill_rate, + result->l2tlb_refill_rate, result->mispred_rate, + result->frontend_stall_rate, result->backend_stall_rate); + } + } + printf("\n"); + fflush(stdout); + } else { + for (const auto& result : results) { + printf( + "%s,%dx%dx%d,%.4g", PathName(*result).c_str(), shape.rows, + shape.depth, shape.cols, + 2.0e-9 * shape.rows * shape.cols * shape.depth / result->latency); + if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { + printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", + result->l1_refill_rate, result->l2_refill_rate, + result->l3_refill_rate, result->l1tlb_refill_rate, + result->l2tlb_refill_rate, result->mispred_rate, + result->frontend_stall_rate, result->backend_stall_rate); + } + printf("\n"); + } + fflush(stdout); + } + } +} + +} // namespace ruy + +int main() { ruy::Benchmark(); } diff --git a/ruy/block_map.cc b/ruy/block_map.cc new file mode 100644 index 0000000..e1e6166 --- /dev/null +++ b/ruy/block_map.cc @@ -0,0 +1,486 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/block_map.h" + +#include +#include + +#ifdef RUY_MAKEBLOCKMAP_DEBUG +#include +#include +#include +#endif + +#include "ruy/check_macros.h" +#include "ruy/opt_set.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/size_util.h" + +namespace ruy { + +namespace { + +void DecodeTraversalLinear(int size_log2, std::uint32_t square_index, + SidePair* local_pos) { + (*local_pos)[Side::kLhs] = square_index & ((1 << size_log2) - 1); + (*local_pos)[Side::kRhs] = square_index >> size_log2; +} + +void DecodeTraversalFractalZ(std::uint32_t square_index, + SidePair* local_pos) { + const std::uint32_t n1 = square_index; + const std::uint32_t n2 = (n1 & 0x99999999u) | ((n1 & 0x44444444u) >> 1) | + ((n1 & 0x22222222u) << 1); + const std::uint32_t n4 = (n2 & 0xc3c3c3c3u) | ((n2 & 0x30303030u) >> 2) | + ((n2 & 0x0c0c0c0cu) << 2); + const std::uint32_t n8 = (n4 & 0xf00ff00fu) | ((n4 & 0x0f000f00u) >> 4) | + ((n4 & 0x00f000f0u) << 4); + const std::uint32_t n16 = (n8 & 0xff0000ffu) | ((n8 & 0x00ff0000u) >> 8) | + ((n8 & 0x0000ff00u) << 8); + (*local_pos)[Side::kLhs] = n16 & 0xffff; + (*local_pos)[Side::kRhs] = n16 >> 16; +} + +void DecodeTraversalFractalU(std::uint32_t square_index, + SidePair* local_pos) { + DecodeTraversalFractalZ(square_index, local_pos); + // Change fractal z-order to u-order + (*local_pos)[Side::kLhs] ^= (*local_pos)[Side::kRhs]; +} + +// Code inspired by the sample code in +// https://en.wikipedia.org/wiki/Hilbert_curve +// The main optimization is to avoid hard-to-predict conditional branches +// based on the bits of the square_index parameter. +void DecodeTraversalFractalHilbert(int size_log2, std::uint32_t square_index, + SidePair* local_pos) { + std::uint32_t t = square_index; + std::uint32_t x = 0; + std::uint32_t y = 0; + // Easy-to-predict for loop, the number of iterations is the same for + // an entire GEMM. + for (int sb = 0; sb < size_log2; sb++) { + std::uint32_t s = 1 << sb; + bool rx = t & 2; + bool ry = (t & 1) ^ rx; + std::uint32_t tmp = rx ? (s - 1 - x) : x; + x = ry ? x : rx ? (s - 1 - y) : y; + y = ry ? (y + s) : tmp; + x = rx ? 
(x + s) : x; + t >>= 2; + } + (*local_pos)[Side::kLhs] = y; + (*local_pos)[Side::kRhs] = x; +} + +} // end anonymous namespace + +void GetBlockByIndex(const BlockMap& block_map, int index, + SidePair* block) { + profiler::ScopeLabel label("GetBlockByIndex"); + const std::uint32_t index_u32 = index; + + const std::uint32_t num_blocks_per_local_curve = + 1u << (2 * block_map.num_blocks_base_log2); + const std::uint32_t square_index = + index_u32 & (num_blocks_per_local_curve - 1); + + const int size_log2 = block_map.num_blocks_base_log2; + SidePair local_pos; + switch (block_map.traversal_order) { + case BlockMapTraversalOrder::kFractalZ: + DecodeTraversalFractalZ(square_index, &local_pos); + break; + case BlockMapTraversalOrder::kFractalU: + DecodeTraversalFractalU(square_index, &local_pos); + break; + case BlockMapTraversalOrder::kFractalHilbert: + DecodeTraversalFractalHilbert(size_log2, square_index, &local_pos); + break; + default: + RUY_DCHECK(block_map.traversal_order == BlockMapTraversalOrder::kLinear); + DecodeTraversalLinear(size_log2, square_index, &local_pos); + break; + } + + const std::uint32_t rectangular_index = + index_u32 >> 2 * block_map.num_blocks_base_log2; + for (Side side : {Side::kLhs, Side::kRhs}) { + const std::uint32_t mask = (1u << block_map.rectangularness_log2[side]) - 1; + const int rectangular_offset = (rectangular_index & mask) + << block_map.num_blocks_base_log2; + (*block)[side] = local_pos[side] + rectangular_offset; + } +} + +BlockMapTraversalOrder GetTraversalOrder(int rows, int cols, int depth, + int lhs_scalar_size, + int rhs_scalar_size, + int local_data_cache_size, + int shared_data_cache_size) { + const int kFractalOptSets = + RUY_OPT_FRACTAL_Z | RUY_OPT_FRACTAL_U | RUY_OPT_FRACTAL_HILBERT; + const int working_set_size = + (lhs_scalar_size * rows + rhs_scalar_size * cols) * depth; + if (RUY_OPT_ENABLED(kFractalOptSets) && + (working_set_size > local_data_cache_size)) { + if (RUY_OPT_ENABLED(RUY_OPT_FRACTAL_HILBERT) && + (working_set_size > shared_data_cache_size)) { + return BlockMapTraversalOrder::kFractalHilbert; + } else if (RUY_OPT_ENABLED(RUY_OPT_FRACTAL_U)) { + return BlockMapTraversalOrder::kFractalU; + } else { + return BlockMapTraversalOrder::kFractalZ; + } + } else { + return BlockMapTraversalOrder::kLinear; + } +} + +namespace { + +int floor_log2_quotient(int num, int denom) { + if (num <= denom) { + return 0; + } + int log2_quotient = floor_log2(num) - ceil_log2(denom); + if ((denom << (log2_quotient + 1)) <= num) { + log2_quotient++; + } + return log2_quotient; +} + +// Computes the rectangularness of the matrix shape (rows, cols). This is +// essentially just the log2 of the quotient (rows / cols). The kernel_rows and +// kernel_cols only get into the picture for clamping bounds but don't affect +// the generic computation. +void GetRectangularness(int rows, int cols, int kernel_rows, int kernel_cols, + int* rows_rectangularness_log2, + int* cols_rectangularness_log2) { + *rows_rectangularness_log2 = 0; + *cols_rectangularness_log2 = 0; + + // In GEMV-ish cases, that is when kernel blocks are as narrow as the kernel + // itself, we risk having too small kernel blocks for good kernel + // amortization. We avoid that by limiting recangularness so that kernel + // blocks are not too tiny at least in that dimension. Specifically, we try to + // have at least (2^min_kernel_inner_loop_runs_log2) kernels fitting in each + // kernel block along the large dimension. 
+ const int min_kernel_inner_loop_runs_log2 = 3; + if (rows > cols) { + int cols_of_kernel_inner_loop_runs_log2 = + ceil_log2(cols) - pot_log2(kernel_cols); + int min_rows_of_kernel_inner_loop_runs_log2 = + std::max(0, min_kernel_inner_loop_runs_log2 - + cols_of_kernel_inner_loop_runs_log2); + *rows_rectangularness_log2 = + std::min(floor_log2_quotient(rows, cols), + std::max(0, floor_log2(rows) - pot_log2(kernel_rows) - + min_rows_of_kernel_inner_loop_runs_log2)); + // Sanity check that we did not over-estimate rows_rectangularness_log2. + RUY_DCHECK_GE(rows >> *rows_rectangularness_log2, cols); + } else if (cols > rows) { + int rows_of_kernel_inner_loop_runs_log2 = + ceil_log2(rows) - pot_log2(kernel_rows); + int min_cols_of_kernel_inner_loop_runs_log2 = + std::max(0, min_kernel_inner_loop_runs_log2 - + rows_of_kernel_inner_loop_runs_log2); + *cols_rectangularness_log2 = + std::min(floor_log2_quotient(cols, rows), + std::max(0, floor_log2(cols) - pot_log2(kernel_cols) - + min_cols_of_kernel_inner_loop_runs_log2)); + // Sanity check that we did not over-estimate cols_rectangularness_log2. + RUY_DCHECK_GE(cols >> *cols_rectangularness_log2, rows); + } + RUY_DCHECK(!*rows_rectangularness_log2 || !*cols_rectangularness_log2); +} + +// Computes a 'multithreading score'. When multithreading, we need there to +// be at least as many tiles as there are threads, and hopefully +// substantially more than that, so we benefit from ruy's ability to +// dispatch fine-grained workloads to threads. +int GetMultithreadingScore(int block_size_log2, int rows, int cols, + int tentative_thread_count) { + const int num_full_blocks_of_rows = rows >> block_size_log2; + const int num_full_blocks_of_cols = cols >> block_size_log2; + const int candidate_num_full_blocks_log2 = floor_log2( + std::max(1, num_full_blocks_of_rows * num_full_blocks_of_cols)); + + // The values here have been tuned on ARM Cortex-A55. + // We expect this to have to be tuned differently for other CPUs. + if (tentative_thread_count == 1) { + return 0; + } else { + const int blocks_per_thread_log2 = + candidate_num_full_blocks_log2 - ceil_log2(tentative_thread_count); + if (blocks_per_thread_log2 < 0) { + return -64; + } else if (blocks_per_thread_log2 == 0) { + return -16; + } else if (blocks_per_thread_log2 == 1) { + return -8; + } else if (blocks_per_thread_log2 == 2) { + return 0; + } else if (blocks_per_thread_log2 == 3) { + return 8; + } else { + return 16; + } + } +} + +// Computes a 'cache locality score'. +int GetCacheLocalityScore(int block_size_log2, int rows, int cols, int depth, + int kernel_rows_log2, int kernel_cols_log2, + int lhs_scalar_size, int rhs_scalar_size, Path path, + int local_data_cache_size) { + // In the narrow case (e.g. matrix*vector), each byte of the big operand + // matrix (either LHS or RHS) is traversed only once, so any notion of data + // locality is irrelevant. Ignore the 'cache locality score' by forcing it to + // be 0 in that case. + if (rows <= (1 << kernel_rows_log2) || cols <= (1 << kernel_cols_log2)) { + return 0; + } + const int block_rows = std::min(1 << block_size_log2, rows); + const int block_cols = std::min(1 << block_size_log2, cols); + const int total_read_bytes = + (lhs_scalar_size * block_rows + rhs_scalar_size * block_cols) * depth; + const int total_read_bytes_log2 = ceil_log2(total_read_bytes); + const int nonlocality_log2 = + total_read_bytes_log2 - floor_log2(local_data_cache_size); + // The values here have been tuned on ARM Cortex-A55. 
+ // We expect this to have to be tuned differently for other CPUs. + if (nonlocality_log2 < -1) { + return 64; + } else if (nonlocality_log2 == -1) { + return 56; + } else if (nonlocality_log2 == 0) { + return 48; + } else if (nonlocality_log2 == 1) { + return 32; + } else if (nonlocality_log2 == 2) { + return 16; + } else if (nonlocality_log2 == 3) { + return 0; + } else { + return -64; + } +} + +// Compute a 'kernel amortization score'. This is the notion that very small +// tiles result in more overhead outside of kernels, more complex memory +// access patterns and less benefits from ruy's fat kernels, so we reward +// larger blocks more than smaller ones. +int GetKernelAmortizationScore(int block_size_log2, int rows, int cols, + int kernel_rows_log2, int kernel_cols_log2) { + const int block_rows = std::min(1 << block_size_log2, rows); + const int block_cols = std::min(1 << block_size_log2, cols); + const int kernels_per_block_log2 = + floor_log2(block_rows * block_cols) - kernel_rows_log2 - kernel_cols_log2; + RUY_DCHECK_GE(kernels_per_block_log2, 0); + // The values here have been tuned on ARM Cortex-A55. + // We expect this to have to be tuned differently for other CPUs. + if (kernels_per_block_log2 == 0) { + return 0; + } else if (kernels_per_block_log2 == 1) { + return 8; + } else if (kernels_per_block_log2 == 2) { + return 16; + } else if (kernels_per_block_log2 == 3) { + return 24; + } else if (kernels_per_block_log2 == 4) { + return 32; + } else if (kernels_per_block_log2 == 5) { + return 40; + } else if (kernels_per_block_log2 == 6) { + return 48; + } else if (kernels_per_block_log2 == 7) { + return 56; + } else { + return 64; + } +} + +} // namespace + +void MakeBlockMap(int rows, int cols, int depth, int kernel_rows, + int kernel_cols, int lhs_scalar_size, int rhs_scalar_size, + int tentative_thread_count, Path path, + int local_data_cache_size, int shared_data_cache_size, + BlockMap* block_map) { + profiler::ScopeLabel label("MakeBlockMap"); + +#ifdef RUY_MAKEBLOCKMAP_DEBUG +#if RUY_MAKEBLOCKMAP_DEBUG >= 2 + static constexpr bool debug_everytime = true; +#else + static constexpr bool debug_everytime = false; +#endif + static bool firsttime = true; + if (firsttime || debug_everytime) { + fprintf(stderr, + "MakeBlockMap(rows=%d, cols=%d, depth=%d, kernel_rows=%d, " + "kernel_cols=%d, lhs_scalar_size=%d, rhs_scalar_size=%d, " + "tentative_thread_count=%d)\n", + rows, cols, depth, kernel_rows, kernel_cols, lhs_scalar_size, + rhs_scalar_size, tentative_thread_count); + } +#endif + + RUY_DCHECK_GE(rows, kernel_rows); + RUY_DCHECK_GE(cols, kernel_cols); + RUY_DCHECK_EQ(rows % kernel_rows, 0); + RUY_DCHECK_EQ(cols % kernel_cols, 0); + + block_map->traversal_order = + GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size, + local_data_cache_size, shared_data_cache_size); + + int rows_rectangularness_log2 = 0; + int cols_rectangularness_log2 = 0; + GetRectangularness(rows, cols, kernel_rows, kernel_cols, + &rows_rectangularness_log2, &cols_rectangularness_log2); + + const int kernel_rows_log2 = pot_log2(kernel_rows); + const int kernel_cols_log2 = pot_log2(kernel_cols); + const int kernel_size_log2 = std::max(kernel_cols_log2, kernel_rows_log2); + + const int size = std::min(rows, cols); + const int size_log2 = std::max(kernel_size_log2, floor_log2(size)); + + RUY_DCHECK_GE(size_log2, kernel_size_log2); + + // We are going to try candidate values for block_size_log2 ranging from + // kernel_size_log2 to (kernel_size_log2 + kMaxKernelsPerBlockLog2). 
+ // For each of them we will compute a 'score' by adding individual scores + // for a few different considerations, all of which is entirely empirical. + // The values (and possibly the logic) around here are all subject to tuning + // based on benchmarks on different hardware. The current values are based + // on benchmarking on Qualcomm S855 (big and little cores), arm64, + // kNeonDotprod, 8bit quantized path. Don't read too much into it, go ahead + // and tune this as needed to achieve good performance elsewhere. Use + // the unit test, block_map_test, to encode values that should be preserved + // on specific architectures. Use RUY_MAKEBLOCKMAP_DEBUG to help tuning this. + static constexpr int kMaxKernelsPerBlockLog2 = 6; + const int max_block_size_log2 = + std::min(size_log2, kernel_size_log2 + kMaxKernelsPerBlockLog2); + int best_score = std::numeric_limits::min(); + int best_score_block_size_log2 = -1; + for (int block_size_log2 = kernel_size_log2; + block_size_log2 <= max_block_size_log2; block_size_log2++) { + const int multithreading_score = GetMultithreadingScore( + block_size_log2, rows, cols, tentative_thread_count); + const int cache_locality_score = GetCacheLocalityScore( + block_size_log2, rows, cols, depth, kernel_rows_log2, kernel_cols_log2, + lhs_scalar_size, rhs_scalar_size, path, local_data_cache_size); + const int kernel_amortization_score = GetKernelAmortizationScore( + block_size_log2, rows, cols, kernel_rows_log2, kernel_cols_log2); + const int score = + multithreading_score + cache_locality_score + kernel_amortization_score; +#ifdef RUY_MAKEBLOCKMAP_DEBUG + if (firsttime || debug_everytime) { + fprintf(stderr, + "block_size_log2=%d: score=%d multithreading_score=%d " + "cache_locality_score=%d kernel_amortization_score=%d\n", + block_size_log2, score, multithreading_score, + cache_locality_score, kernel_amortization_score); + } +#endif + if (score >= best_score) { + best_score = score; + best_score_block_size_log2 = block_size_log2; + } + } + +#ifdef RUY_MAKEBLOCKMAP_DEBUG + if (firsttime || debug_everytime) { + fprintf(stderr, "best_score_block_size_log2=%d\n", + best_score_block_size_log2); + } + + static const char* explicit_block_size_log2_env = + getenv("RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2"); + if (explicit_block_size_log2_env) { + best_score_block_size_log2 = std::stoi(explicit_block_size_log2_env); + if (firsttime || debug_everytime) { + fprintf(stderr, "Overridden best_score_block_size_log2=%d\n", + best_score_block_size_log2); + } + } + firsttime = false; +#endif + + int num_blocks_base_log2 = size_log2 - best_score_block_size_log2; + RUY_DCHECK_GE(num_blocks_base_log2, 0); + + const int num_blocks_of_rows_log2 = + num_blocks_base_log2 + rows_rectangularness_log2; + const int num_blocks_of_cols_log2 = + num_blocks_base_log2 + cols_rectangularness_log2; + + const int smallr = + round_down_pot(rows >> num_blocks_of_rows_log2, kernel_rows); + const int smallc = + round_down_pot(cols >> num_blocks_of_cols_log2, kernel_cols); + const int missr = + round_up_pot(rows - (smallr << num_blocks_of_rows_log2), kernel_rows) >> + pot_log2(kernel_rows); + const int missc = + round_up_pot(cols - (smallc << num_blocks_of_cols_log2), kernel_cols) >> + pot_log2(kernel_cols); + + block_map->dims[Side::kLhs] = rows; + block_map->dims[Side::kRhs] = cols; + block_map->kernel_dims[Side::kLhs] = kernel_rows; + block_map->kernel_dims[Side::kRhs] = kernel_cols; + block_map->num_blocks_base_log2 = num_blocks_base_log2; + block_map->rectangularness_log2[Side::kLhs] = 
rows_rectangularness_log2; + block_map->rectangularness_log2[Side::kRhs] = cols_rectangularness_log2; + block_map->small_block_dims[Side::kLhs] = smallr; + block_map->small_block_dims[Side::kRhs] = smallc; + block_map->large_blocks[Side::kLhs] = missr; + block_map->large_blocks[Side::kRhs] = missc; + // Done last: NumBlocks needs some of the block_map fields to be already set. + block_map->thread_count = + std::min(tentative_thread_count, NumBlocks(*block_map)); +} + +void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block, + int* start, int* end) { + profiler::ScopeLabel label("GetBlockMatrixCoords"); + *start = block * block_map.small_block_dims[side] + + std::min(block, block_map.large_blocks[side]) * + block_map.kernel_dims[side]; + *end = + *start + block_map.small_block_dims[side] + + (block < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0); + + RUY_DCHECK_EQ(0, *start % block_map.kernel_dims[side]); + RUY_DCHECK_EQ(0, *end % block_map.kernel_dims[side]); + RUY_DCHECK_LE(*end, block_map.dims[side]); + RUY_DCHECK_LT(*start, *end); + RUY_DCHECK_GE(*start, 0); +} + +void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair& block, + SidePair* start, SidePair* end) { + for (Side side : {Side::kLhs, Side::kRhs}) { + GetBlockMatrixCoords(side, block_map, block[side], &(*start)[side], + &(*end)[side]); + } +} + +} // namespace ruy diff --git a/ruy/block_map.h b/ruy/block_map.h new file mode 100644 index 0000000..5e1cee0 --- /dev/null +++ b/ruy/block_map.h @@ -0,0 +1,161 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_ + +#include "ruy/path.h" +#include "ruy/side_pair.h" + +namespace ruy { + +enum class BlockMapTraversalOrder { + // Plain old row-by-row or column-by-column traversal. + kLinear, + // Fractal Z-order curve, https://en.wikipedia.org/wiki/Z-order_curve + kFractalZ, + // Variant of Z-order doing a U instead of a Z. + kFractalU, + // Hilbert curve, https://en.wikipedia.org/wiki/Hilbert_curve + kFractalHilbert +}; + +// A BlockMap describes a tiling of a matrix, typically the destination matrix +// of a matrix multiplication computation. As is standard in matrix +// multiplication, a tile is called a "block". +// +// Ruy subdivides work by blocks of the destination matrix: each thread fully +// computes a block at once, then moves on to another block; each block is +// produced by a single thread. +// +// This ensures that the workloads for each block are mutually independent, +// which reduces synchronization requirements. +// +// Typically, a matrix multiplication will early on create a BlockMap by +// calling MakeBlockMap. It will then query the number of blocks in that +// BlockMap by calling NumBlocks. 
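+// (A short usage sketch of this sequence is given further below, next to the
+// MakeBlockMap declaration.)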
It will then create a single atomic integer +// counter indexing these blocks, called the 'index', and will distribute +// work to its N threads by ensuring that each thread works on disjoint sets +// of index values. For a given index value, the thread will call +// GetBlockByIndex to get the corresponding block, then GetBlockMatrixCoords +// to find the actual row and column numbers of this block. +// +// There are two nested levels of subdivision. On a local level, the matrix is +// tiled into a square NxN grid where N is a power of two, specifically: +// N = 2^num_blocks_base_log2. +// +// At a larger scale, around these blocks, there may be one further +// level of subdivision, in only one dimension: either along rows or along +// columns. That is used to handle arbitrarily rectangular matrices. The +// aforementioned high-level block grid is square, so it does not readily fit +// well very rectangular matrices. +// +// Taking together these two nested levels of subdivision, the effective +// tiling is by +// 2^(num_blocks_base_log2 + rows_rectangularness_log2) +// blocks in the row dimension, and by +// 2^(num_blocks_base_log2 + cols_rectangularness_log2) +// blocks in the column dimension. See NumBlocksOfRows, NumBlocksOfCols. +// +// Either rows_rectangularness_log2 or cols_rectangularness_log2 must be zero. +// +// Finally, this BlockMap is designed to operate under alignment constraints: +// two fields, kernel_rows and kernel_cols, describe the requested alignment +// of the effective grid in both dimensions. The idea is to feed matrix +// multiplication kernels with tiles that fit their width as much as possible. +// Of course, if rows (resp. cols) is not a multiple of kernel_rows (resp. +// kernel_cols) then some tile will have to have unaligned size. BlockMap +// will only allow that to happen in the last position along each axis, so +// as to minimize the overhead incurred onto the matrix multiplication kernels. +struct BlockMap { + // The number of threads to use (to distribute the blocks to). + int thread_count; + // The order in which to traverse the matrix of which this BlockMap represents + // a tiling (hereafter "the matrix"). + BlockMapTraversalOrder traversal_order; + // The dimensions of the block_map, that is, of the destination + // matrix rounded up to next multiples of kernel_dims. + SidePair dims; + // Log2 of the minimum number of subdivisions of the grid along either axis. + int num_blocks_base_log2; + // Log2 of the additional subdivision of the rows/columns axis. + SidePair rectangularness_log2; + // Requested alignment of the subdivisions of the grid along the rows/columns + // axis. + SidePair kernel_dims; + // Internal helper. Minimum number of rows/columns in each block. + SidePair small_block_dims; + // Internal helper. Number of blocks along each dimension that need to have + // their size in that dimension be given by (small_block_dims + kernel_dims) + // instead of just small_block_dims. + SidePair large_blocks; +}; + +// Returns the traversal order to be used for the given matrix multiplication +// parameters. +BlockMapTraversalOrder GetTraversalOrder(int rows, int cols, int depth, + int lhs_scalar_size, + int rhs_scalar_size, + int local_data_cache_size, + int shared_data_cache_size); + +// Create a BlockMap suitable for tiling the destination matrix in a +// matrix multiplication with the given parameters. 
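+//
+// A minimal sketch of the intended call sequence (single-threaded for
+// clarity; the per-thread distribution of index values described on BlockMap
+// above is omitted):
+//
+//   BlockMap block_map;
+//   MakeBlockMap(rows, cols, depth, kernel_rows, kernel_cols,
+//                lhs_scalar_size, rhs_scalar_size, thread_count, path,
+//                LocalDataCacheSize(path), SharedDataCacheSize(path),
+//                &block_map);
+//   for (int index = 0; index < NumBlocks(block_map); index++) {
+//     SidePair<int> block, start, end;
+//     GetBlockByIndex(block_map, index, &block);
+//     GetBlockMatrixCoords(block_map, block, &start, &end);
+//     // Compute the destination tile covering rows [start[Side::kLhs],
+//     // end[Side::kLhs]) and columns [start[Side::kRhs], end[Side::kRhs]).
+//   }
+//
+// As an example of the small/large block split: if the rows dimension is 72
+// with kernel_rows = 8 and ends up subdivided into 4 blocks, then
+// small_block_dims[Side::kLhs] = 16, large_blocks[Side::kLhs] = 1, and
+// GetBlockMatrixCoords returns the row ranges [0,24), [24,40), [40,56),
+// [56,72).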
+void MakeBlockMap(int rows, int cols, int depth, int kernel_rows, + int kernel_cols, int lhs_scalar_size, int rhs_scalar_size, + int tentative_thread_count, Path path, + int local_data_cache_size, int shared_data_cache_size, + BlockMap* block_map); + +// Maps an integer index to a block position in the grid. +void GetBlockByIndex(const BlockMap& block_map, int index, + SidePair* block); + +// Given a block position in the grid, returns its actual +// position in the matrix that the BlockMap refers to in the dimension +// referred to by `side`: along rows if side==kLhs, along columns if +// side==kRhs. +void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block, + int* start, int* end); + +// Given a block position in the grid, returns its actual +// position in the matrix that the BlockMap refers to in terms of +// actual row/column indices. +void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair& block, + SidePair* start, SidePair* end); + +// Returns the number of grid subdivisions along the rows dimension (if +// side == kLhs) or columns dimension (if side == kRhs). +inline int NumBlocksPerSide(Side side, const BlockMap& block_map) { + return 1 << (block_map.num_blocks_base_log2 + + block_map.rectangularness_log2[side]); +} + +// Returns the overall number of blocks in +// the BlockMap. The valid index values to pass to GetBlockByIndex are the +// integers from 0 to N-1 where N is the value returned here. +// +// Note that it is always true that +// NumBlocks == NumBlocksOfRows * NumBlocksOfCols +// because either rows_rectangularness_log2 or cols_rectangularness_log2 is 0. +inline int NumBlocks(const BlockMap& block_map) { + return 1 << (2 * block_map.num_blocks_base_log2 + + block_map.rectangularness_log2[Side::kLhs] + + block_map.rectangularness_log2[Side::kRhs]); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_ diff --git a/ruy/block_map_test.cc b/ruy/block_map_test.cc new file mode 100644 index 0000000..24646cf --- /dev/null +++ b/ruy/block_map_test.cc @@ -0,0 +1,263 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/block_map.h" + +#include +#include +#include +#include +#include + +#include "testing/base/public/gunit.h" +#include "ruy/cpu_cache_size.h" +#include "ruy/path.h" +#include "ruy/side_pair.h" + +namespace ruy { +namespace { + +#if RUY_PLATFORM(NEON_64) + +// Unless otherwise specified, these tests have been tuned on ARM Cortex-A55. 
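+// The helper below builds a BlockMap using the default cache sizes for the
+// given Path and checks the resulting subdivision. Since at most one of the
+// two rectangularness_log2 values can be nonzero, their min is expected to be
+// 0 and their max to equal the expected rectangularness, on whichever side it
+// occurs.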
+void MakeBlockMapTuningTest(int rows, int cols, int depth, int kernel_rows, + int kernel_cols, int lhs_scalar_size, + int rhs_scalar_size, int tentative_thread_count, + Path path, int expected_num_blocks_base_log2, + int expected_rectangularness_log2) { + BlockMap block_map; + MakeBlockMap(rows, cols, depth, kernel_rows, kernel_cols, lhs_scalar_size, + rhs_scalar_size, tentative_thread_count, path, + LocalDataCacheSize(path), SharedDataCacheSize(path), &block_map); + EXPECT_EQ(block_map.num_blocks_base_log2, expected_num_blocks_base_log2); + EXPECT_EQ(std::min(block_map.rectangularness_log2[Side::kLhs], + block_map.rectangularness_log2[Side::kRhs]), + 0); + EXPECT_EQ(std::max(block_map.rectangularness_log2[Side::kLhs], + block_map.rectangularness_log2[Side::kRhs]), + expected_rectangularness_log2); +} + +TEST(BlockMapTest, MakeBlockMapTuningTest8bitCubicShapesOneThreadNeonDotprod) { + MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 1, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 1, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 1, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 1, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1, + /* tentative_thread_count */ 1, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1, + /* tentative_thread_count */ 1, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1, + /* tentative_thread_count */ 1, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1, + /* tentative_thread_count */ 1, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); +} + +TEST(BlockMapTest, + MakeBlockMapTuningTest8bitCubicShapesFourThreadsNeonDotprod) { + MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 4, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 4, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 4, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 4, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1, + /* tentative_thread_count */ 4, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1, + /* tentative_thread_count */ 4, Path::kNeonDotprod, + /* 
expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1, + /* tentative_thread_count */ 4, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 2, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1, + /* tentative_thread_count */ 4, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 2, + /* expected_rectangularness_log2 */ 0); +} + +TEST(BlockMapTest, MakeBlockMapTuningTest32bit) { + MakeBlockMapTuningTest(256, 256, 256, 8, 8, 4, 4, + /* tentative_thread_count */ 4, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 3, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(4096, 4096, 4096, 8, 8, 4, 4, + /* tentative_thread_count */ 4, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 7, + /* expected_rectangularness_log2 */ 0); +} + +TEST(BlockMapTest, MakeBlockMapTuningTestRectangular) { + MakeBlockMapTuningTest(256, 16, 256, 8, 8, 1, 1, + /* tentative_thread_count */ 1, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 3); + MakeBlockMapTuningTest(24, 2400, 256, 8, 8, 1, 1, + /* tentative_thread_count */ 1, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 6); +} + +#endif + +int L1Distance(const SidePair& a, const SidePair& b) { + return std::abs(a[Side::kLhs] - b[Side::kLhs]) + + std::abs(a[Side::kRhs] - b[Side::kRhs]); +} + +void GetBlockByIndexSquareTest(int num_blocks_base_log2, + BlockMapTraversalOrder traversal_order) { + // Arbitrary, does not affect this test. 3 is just a typical value. + constexpr int kKernelSizeLog2 = 3; + + const int size_log2 = num_blocks_base_log2 + kKernelSizeLog2; + BlockMap block_map; + block_map.thread_count = 1; + block_map.traversal_order = traversal_order; + block_map.num_blocks_base_log2 = num_blocks_base_log2; + for (Side side : {Side::kLhs, Side::kRhs}) { + block_map.dims[side] = 1 << size_log2; + block_map.rectangularness_log2[side] = 0; + block_map.kernel_dims[side] = 1 << kKernelSizeLog2; + block_map.small_block_dims[side] = block_map.kernel_dims[side]; + block_map.large_blocks[side] = 0; + } + + const int num_blocks_per_side = 1 << num_blocks_base_log2; + const int num_blocks = num_blocks_per_side * num_blocks_per_side; + EXPECT_EQ(num_blocks, NumBlocks(block_map)); + + // Perform a full traversal of all blocks, as if computing a whole matrix + // multiplication. + // + // Used to record how many times each block was hit by the traversal. + std::vector block_hit_counts(num_blocks); + // Here we guard an assumption that all traversal orders start at (0, 0). + SidePair previous_block_coords(0, 0); + // Sum of L1 norm of the coordinate change at every step of the traversal. + std::int64_t total_l1_distance = 0; + // Number of jumps i.e. traversal steps with a L1 norm greater than 1. + int discontinuity_count = 0; + for (int block_index = 0; block_index < num_blocks; block_index++) { + SidePair block_coords; + GetBlockByIndex(block_map, block_index, &block_coords); + ++block_hit_counts[block_coords[Side::kLhs] + + num_blocks_per_side * block_coords[Side::kRhs]]; + int distance = L1Distance(block_coords, previous_block_coords); + total_l1_distance += distance; + discontinuity_count += (distance > 1); + previous_block_coords = block_coords; + } + + // Verify that each block was traversed exactly once. 
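+  // (Every hit count must be exactly 1: exactly NumBlocks(block_map) indices
+  //  were visited, so any block visited twice would imply another block
+  //  missed.)
+  // For reference on the expectations further below: with kLinear traversal
+  // of an n x n grid, each of the n lines contributes n - 1 unit steps and
+  // each of the n - 1 line changes is a jump of L1 length n, giving
+  // discontinuity_count = n - 1 and total_l1_distance = 2 * n * (n - 1).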
+ for (int l = 0; l < num_blocks_per_side; l++) { + for (int r = 0; r < num_blocks_per_side; r++) { + EXPECT_EQ(block_hit_counts[l + num_blocks_per_side * r], 1); + } + } + + // Verify that the discontinuity_count and total_l1_distance are as expected + // for the given traversal_order. + switch (traversal_order) { + case BlockMapTraversalOrder::kFractalHilbert: + // No discontinuity at all with this space-filling continuous curve! + EXPECT_EQ(discontinuity_count, 0); + // Therefore, total_l1_distance has to be the number of blocks minus one. + EXPECT_EQ(total_l1_distance, num_blocks - 1); + break; + case BlockMapTraversalOrder::kLinear: + EXPECT_EQ(discontinuity_count, num_blocks_per_side - 1); + EXPECT_EQ(total_l1_distance, + 2 * num_blocks_per_side * (num_blocks_per_side - 1)); + break; + case BlockMapTraversalOrder::kFractalZ: + EXPECT_EQ(discontinuity_count, num_blocks > 1 ? (num_blocks / 2 - 1) : 0); + EXPECT_EQ(total_l1_distance, + 2 * num_blocks_per_side * (num_blocks_per_side - 1)); + break; + case BlockMapTraversalOrder::kFractalU: { + if (num_blocks_base_log2 == 0) { + EXPECT_EQ(discontinuity_count, 0); + EXPECT_EQ(total_l1_distance, 0); + } else { + int expected_discontinuity_count = 0; + int expected_total_l1_distance = 3; + for (int i = 2; i <= num_blocks_base_log2; i++) { + expected_discontinuity_count = 4 * expected_discontinuity_count + 2; + expected_total_l1_distance = + 4 * expected_total_l1_distance + (1 << (i + 1)) - 1; + } + EXPECT_EQ(discontinuity_count, expected_discontinuity_count); + EXPECT_EQ(total_l1_distance, expected_total_l1_distance); + } + break; + } + default: + abort(); + } +} + +TEST(BlockMapTest, GetBlockByIndexSquare) { + for (int num_blocks_base_log2 = 0; num_blocks_base_log2 <= 10; + num_blocks_base_log2++) { + for (BlockMapTraversalOrder traversal_order : + {BlockMapTraversalOrder::kLinear, BlockMapTraversalOrder::kFractalZ, + BlockMapTraversalOrder::kFractalU, + BlockMapTraversalOrder::kFractalHilbert}) { + GetBlockByIndexSquareTest(num_blocks_base_log2, traversal_order); + } + } +} + +} // namespace +} // namespace ruy + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/blocking_counter.cc b/ruy/blocking_counter.cc new file mode 100644 index 0000000..ffa7ac0 --- /dev/null +++ b/ruy/blocking_counter.cc @@ -0,0 +1,49 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/blocking_counter.h" + +#include "ruy/check_macros.h" +#include "ruy/wait.h" + +namespace ruy { + +void BlockingCounter::Reset(int initial_count) { + int old_count_value = count_.load(std::memory_order_relaxed); + RUY_DCHECK_EQ(old_count_value, 0); + (void)old_count_value; + count_.store(initial_count, std::memory_order_release); +} + +bool BlockingCounter::DecrementCount() { + int old_count_value = count_.fetch_sub(1, std::memory_order_acq_rel); + RUY_DCHECK_GT(old_count_value, 0); + int count_value = old_count_value - 1; + bool hit_zero = (count_value == 0); + if (hit_zero) { + std::lock_guard lock(count_mutex_); + count_cond_.notify_all(); + } + return hit_zero; +} + +void BlockingCounter::Wait() { + const auto& condition = [this]() { + return count_.load(std::memory_order_acquire) == 0; + }; + ruy::Wait(condition, &count_cond_, &count_mutex_); +} + +} // namespace ruy diff --git a/ruy/blocking_counter.h b/ruy/blocking_counter.h new file mode 100644 index 0000000..878f0e7 --- /dev/null +++ b/ruy/blocking_counter.h @@ -0,0 +1,62 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_ + +#include +#include // NOLINT(build/c++11) // IWYU pragma: keep +#include // NOLINT(build/c++11) // IWYU pragma: keep + +namespace ruy { + +// A BlockingCounter lets one thread to wait for N events to occur. +// This is how the master thread waits for all the worker threads +// to have finished working. +// The waiting is done using a naive spinlock waiting for the atomic +// count_ to hit the value 0. This is acceptable because in our usage +// pattern, BlockingCounter is used only to synchronize threads after +// short-lived tasks (performing parts of the same GEMM). It is not used +// for synchronizing longer waits (resuming work on the next GEMM). +class BlockingCounter { + public: + BlockingCounter() : count_(0) {} + + // Sets/resets the counter; initial_count is the number of + // decrementing events that the Wait() call will be waiting for. + void Reset(int initial_count); + + // Decrements the counter; if the counter hits zero, signals + // the threads that were waiting for that, and returns true. + // Otherwise (if the decremented count is still nonzero), + // returns false. + bool DecrementCount(); + + // Waits for the N other threads (N having been set by Reset()) + // to hit the BlockingCounter. + void Wait(); + + private: + std::atomic count_; + + // The condition variable and mutex allowing to passively wait for count_ + // to reach the value zero, in the case of longer waits. 
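+  // (Typical sequence, per the class comment above: the main thread calls
+  //  Reset(N) before distributing N tasks and Wait() afterwards; each worker
+  //  calls DecrementCount() once when its task completes, and the decrement
+  //  that brings the count to zero notifies count_cond_ so that Wait()
+  //  returns.)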
+ std::condition_variable count_cond_; + std::mutex count_mutex_; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_ diff --git a/ruy/build_defs.bzl b/ruy/build_defs.bzl new file mode 100644 index 0000000..964ede3 --- /dev/null +++ b/ruy/build_defs.bzl @@ -0,0 +1,54 @@ +"""Build definitions for Ruy. + +In some cases these are used to configure specific targets for +specific platforms, and dispatch is based on runtime capability detection. +""" + +# 1. Enable -mfpu=neon unconditionally on ARM32. If it turns out that we need to support +# ARM32 without NEON then we'll implement runtime detection and dispatch at that point. +# 2. Explicitly pass -O3 on optimization configs where just "-c opt" means "optimize for code size". + +def ruy_copts_base(): + return select({ + ":armeabi-v7a": [ + "-mfpu=neon", + ], + "//conditions:default": [], + }) + select({ + ":optimized": ["-O3"], + "//conditions:default": [], + }) + +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. +def ruy_copts_skylake(): + return select({ + ":x86_64": ["-march=skylake-avx512"], + "//conditions:default": [], + }) + +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. +def ruy_copts_avx2(): + return select({ + ":x86_64": ["-mavx2", "-mfma"], + "//conditions:default": [], + }) + +# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +# Optimization is not finished. In particular the dimensions of the kernel +# blocks can be changed as desired. +# +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. +def ruy_copts_sse42(): + return [] + +# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +# Optimization is not finished. In particular the dimensions of the kernel +# blocks can be changed as desired. +# +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. +def ruy_copts_avxvnni(): + return select({ + # TODO(b/146494398): Reinstate flag, something like "-march=cascadelake". + ":x86_64": [], + "//conditions:default": [], + }) diff --git a/ruy/build_defs.bzl.opensource b/ruy/build_defs.bzl.opensource new file mode 100644 index 0000000..9bccccf --- /dev/null +++ b/ruy/build_defs.bzl.opensource @@ -0,0 +1,40 @@ +"""Build definitions for Ruy.""" + +# 1. Enable -mfpu=neon unconditionally on ARM32. If it turns out that we need to support +# ARM32 without NEON then we'll implement runtime detection and dispatch at that point. +# 2. Explicitly pass -O3 on optimization configs where just "-c opt" means "optimize for code size". + +def ruy_copts_base(): + return select({ + ":armeabi-v7a": [ + "-mfpu=neon", + ], + "//conditions:default": [], + }) + select({ + ":optimized": ["-O3"], + "//conditions:default": [], + }) + +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. +def ruy_copts_skylake(): + return [] + +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. +def ruy_copts_avx2(): + return [] + +# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +# Optimization is not finished. In particular the dimensions of the kernel +# blocks can be changed as desired. +# +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. 
+def ruy_copts_sse42(): + return [] + +# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +# Optimization is not finished. In particular the dimensions of the kernel +# blocks can be changed as desired. +# +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. +def ruy_copts_avxvnni(): + return [] diff --git a/ruy/check_macros.h b/ruy/check_macros.h new file mode 100644 index 0000000..773f37d --- /dev/null +++ b/ruy/check_macros.h @@ -0,0 +1,138 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_ + +#include +#include +#include + +namespace ruy { +namespace check_macros { + +constexpr int kValueBufSize = 32; + +template +struct ToString { + static void Run(const T& value, char* buf) { + snprintf(buf, kValueBufSize, "(?)"); + } +}; + +template <> +struct ToString { + static void Run(float value, char* buf) { + snprintf(buf, kValueBufSize, "%.9g", static_cast(value)); + } +}; + +template <> +struct ToString { + static void Run(double value, char* buf) { + snprintf(buf, kValueBufSize, "%.16g", value); + } +}; + +template +struct ToString::value>::type> { + static void Run(const T& value, char* buf) { + snprintf(buf, kValueBufSize, "%lld", static_cast(value)); + } +}; + +template +struct ToString { + static void Run(T* value, char* buf) { + snprintf(buf, kValueBufSize, "%p", value); + } +}; + +template +struct ToString::value>::type> { + static void Run(const T& value, char* buf) { + snprintf(buf, kValueBufSize, "(enum value %d)", static_cast(value)); + } +}; + +inline void Failure(const char* file, int line, const char* macro, + const char* condition) { + fprintf(stderr, "%s:%d: %s condition not satisfied: %s\n", file, line, macro, + condition); + abort(); +} + +template +inline void Failure(const char* file, int line, const char* macro, + const char* lhs, const LhsType& lhs_value, const char* op, + const char* rhs, const RhsType& rhs_value) { + char lhs_value_buf[kValueBufSize]; + ToString::Run(lhs_value, lhs_value_buf); + char rhs_value_buf[kValueBufSize]; + ToString::Run(rhs_value, rhs_value_buf); + fprintf(stderr, + "%s:%d: %s condition not satisfied: [ %s %s %s ] with values [ " + "%s %s %s ].\n", + file, line, macro, lhs, op, rhs, lhs_value_buf, op, rhs_value_buf); + abort(); +} + +#define RUY_CHECK_IMPL(macro, condition) \ + do { \ + if (!(condition)) { \ + ruy::check_macros::Failure(__FILE__, __LINE__, #macro, #condition); \ + } \ + } while (false) + +#define RUY_CHECK_OP_IMPL(macro, lhs, op, rhs) \ + do { \ + const auto& lhs_value = (lhs); \ + const auto& rhs_value = (rhs); \ + if (!(lhs_value op rhs_value)) { \ + ruy::check_macros::Failure(__FILE__, __LINE__, #macro, #lhs, lhs_value, \ + #op, #rhs, rhs_value); \ + } \ + } while (false) + +#define RUY_CHECK(condition) 
RUY_CHECK_IMPL(RUY_CHECK, condition) +#define RUY_CHECK_EQ(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_EQ, x, ==, y) +#define RUY_CHECK_NE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_NE, x, !=, y) +#define RUY_CHECK_GE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_GE, x, >=, y) +#define RUY_CHECK_GT(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_GT, x, >, y) +#define RUY_CHECK_LE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_LE, x, <=, y) +#define RUY_CHECK_LT(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_LT, x, <, y) + +#ifdef NDEBUG +#define RUY_DCHECK(condition) +#define RUY_DCHECK_EQ(x, y) +#define RUY_DCHECK_NE(x, y) +#define RUY_DCHECK_GE(x, y) +#define RUY_DCHECK_GT(x, y) +#define RUY_DCHECK_LE(x, y) +#define RUY_DCHECK_LT(x, y) +#else +#define RUY_DCHECK(condition) RUY_CHECK(condition) +#define RUY_DCHECK_EQ(x, y) RUY_CHECK_EQ(x, y) +#define RUY_DCHECK_NE(x, y) RUY_CHECK_NE(x, y) +#define RUY_DCHECK_GE(x, y) RUY_CHECK_GE(x, y) +#define RUY_DCHECK_GT(x, y) RUY_CHECK_GT(x, y) +#define RUY_DCHECK_LE(x, y) RUY_CHECK_LE(x, y) +#define RUY_DCHECK_LT(x, y) RUY_CHECK_LT(x, y) +#endif + +} // end namespace check_macros +} // end namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_ diff --git a/ruy/check_macros_test.cc b/ruy/check_macros_test.cc new file mode 100644 index 0000000..7e47e7f --- /dev/null +++ b/ruy/check_macros_test.cc @@ -0,0 +1,153 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/check_macros.h" + +#include "testing/base/public/gunit.h" + +namespace { + +#define TEST_CONDITION_FOR_FAMILY(family, vacuously_succeeds, condition) \ + do { \ + if (vacuously_succeeds || (condition)) { \ + RUY_##family(condition); \ + } \ + } while (false) + +#define TEST_COMPARISON_FOR_FAMILY(family, vacuously_succeeds, op_name, x, op, \ + y) \ + do { \ + if (vacuously_succeeds || ((x)op(y))) { \ + RUY_##family##_##op_name(x, y); \ + } \ + } while (false) + +#ifdef NDEBUG +#define TEST_CONDITION(condition) \ + do { \ + TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \ + } while (false) +#define TEST_COMPARISON(op_name, x, op, y) \ + do { \ + TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \ + } while (false) +#else +#define TEST_CONDITION(condition) \ + do { \ + TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \ + TEST_CONDITION_FOR_FAMILY(DCHECK, false, condition); \ + } while (false) +#define TEST_COMPARISON(op_name, x, op, y) \ + do { \ + TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \ + TEST_COMPARISON_FOR_FAMILY(DCHECK, false, op_name, x, op, y); \ + } while (false) + +#endif + +template +void TestEqualityComparisons(const LhsType& lhs, const RhsType& rhs) { + RUY_CHECK_EQ(lhs, lhs); + TEST_COMPARISON(EQ, lhs, ==, lhs); + RUY_CHECK_EQ(lhs, lhs); + RUY_CHECK_EQ(lhs, lhs); + if (lhs == rhs) { + RUY_CHECK_EQ(lhs, rhs); + } + if (lhs != rhs) { + RUY_CHECK_NE(lhs, rhs); + } +} + +template +void TestComparisons(const LhsType& lhs, const RhsType& rhs) { + TestEqualityComparisons(lhs, rhs); + if (lhs > rhs) { + RUY_CHECK_GT(lhs, rhs); + } + if (lhs >= rhs) { + RUY_CHECK_GE(lhs, rhs); + } + if (lhs < rhs) { + RUY_CHECK_LT(lhs, rhs); + } + if (lhs <= rhs) { + RUY_CHECK_LE(lhs, rhs); + } +} + +TEST(CheckMacrosTest, IntInt) { + TestComparisons(0, 0); + TestComparisons(0, 1); + TestComparisons(1, -1); + TestComparisons(-1, 0); + TestComparisons(123, -456); + TestComparisons(std::numeric_limits::min(), + std::numeric_limits::max()); + TestComparisons(123, std::numeric_limits::max()); + TestComparisons(123, std::numeric_limits::min()); +} + +TEST(CheckMacrosTest, Uint8Uint8) { + TestComparisons(0, 0); + TestComparisons(255, 0); + TestComparisons(0, 255); + TestComparisons(12, 34); +} + +TEST(CheckMacrosTest, Uint8Int) { + TestComparisons(0, std::numeric_limits::min()); + TestComparisons(255, std::numeric_limits::min()); + TestComparisons(0, std::numeric_limits::max()); + TestComparisons(255, std::numeric_limits::max()); +} + +TEST(CheckMacrosTest, FloatFloat) { + TestComparisons(0.f, 0.f); + TestComparisons(0.f, 1.f); + TestComparisons(1.f, -1.f); + TestComparisons(-1.f, 0.f); + TestComparisons(123.f, -456.f); + TestComparisons(std::numeric_limits::lowest(), + std::numeric_limits::max()); + TestComparisons(123.f, std::numeric_limits::max()); + TestComparisons(123.f, std::numeric_limits::lowest()); +} + +TEST(CheckMacrosTest, IntFloat) { + TestComparisons(0, 0.f); + TestComparisons(0, 1.f); + TestComparisons(1, -1.f); + TestComparisons(-1, 0.f); + TestComparisons(123, -456.f); + TestComparisons(std::numeric_limits::lowest(), + std::numeric_limits::max()); + TestComparisons(123, std::numeric_limits::max()); + TestComparisons(123, std::numeric_limits::lowest()); +} + +TEST(CheckMacrosTest, EnumClass) { + enum class SomeEnumClass { kA, kB, kC }; + TestEqualityComparisons(SomeEnumClass::kA, SomeEnumClass::kA); + TestEqualityComparisons(SomeEnumClass::kA, 
SomeEnumClass::kB); + TestEqualityComparisons(SomeEnumClass::kC, SomeEnumClass::kB); +} + +} // namespace + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/common.h b/ruy/common.h new file mode 100644 index 0000000..1cd40fe --- /dev/null +++ b/ruy/common.h @@ -0,0 +1,73 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Miscellaneous helpers internal library. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/path.h" +#include "ruy/platform.h" + +#if RUY_OPT_ENABLED(RUY_OPT_PREFETCH_LOAD) +#define RUY_PREFETCH_LOAD(X) X +#else +#define RUY_PREFETCH_LOAD(X) +#endif + +#if RUY_OPT_ENABLED(RUY_OPT_PREFETCH_STORE) +#define RUY_PREFETCH_STORE(X) X +#else +#define RUY_PREFETCH_STORE(X) +#endif + +#define RUY_STR(s) RUY_STR_UNEXPANDED(s) +#define RUY_STR_UNEXPANDED(s) #s + +namespace ruy { + +// Helper for type-erasing a pointer. +// +// Often inside Ruy, a template parameter holds type information statically, but +// we would like to have a function signature that doesn't depend on the +// template parameters, so that we can dispatch indirectly across multiple +// implementations. This helper is at the core of such type-erasure. +// +// The opposite of this operation is just `static_cast(void_ptr)`. +template +void* ToVoidPtr(T* p) { + return const_cast(static_cast(p)); +} + +template +Scalar SymmetricZeroPoint() { + if (std::is_floating_point::value) { + return 0; + } + if (std::is_signed::value) { + return 0; + } + return std::numeric_limits::max() / 2 + 1; +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_ diff --git a/ruy/context.cc b/ruy/context.cc new file mode 100644 index 0000000..1a70303 --- /dev/null +++ b/ruy/context.cc @@ -0,0 +1,109 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/context.h" + +#include "ruy/check_macros.h" +#include "ruy/detect_arm.h" +#include "ruy/detect_x86.h" +#include "ruy/have_built_path_for.h" +#include "ruy/platform.h" + +namespace ruy { + +void Context::SetRuntimeEnabledPaths(Path paths) { + runtime_enabled_paths_ = paths; +} + +Path Context::GetRuntimeEnabledPaths() { + // This function should always return the same value on a given machine. + // When runtime_enabled_paths_ has its initial value kNone, it performs + // some platform detection to resolve it to specific Path values. + + // Fast path: already resolved. + if (runtime_enabled_paths_ != Path::kNone) { + return runtime_enabled_paths_; + } + + // Need to resolve now. Start by considering all paths enabled. + runtime_enabled_paths_ = kAllPaths; + + // This mechanism is intended to be used for testing and benchmarking. For + // example, one can set RUY_FORCE_DISABLE_PATHS to Path::kAvx512 in order to + // evaluate AVX2 performance on an AVX-512 machine. +#ifdef RUY_FORCE_DISABLE_PATHS + runtime_enabled_paths_ = runtime_enabled_paths_ & ~(RUY_FORCE_DISABLE_PATHS); +#endif + +#if RUY_PLATFORM(ARM) + // Now selectively disable paths that aren't supported on this machine. + if ((runtime_enabled_paths_ & Path::kNeonDotprod) != Path::kNone) { + if (!DetectDotprod()) { + runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kNeonDotprod; + // Sanity check. + RUY_DCHECK((runtime_enabled_paths_ & Path::kNeonDotprod) == Path::kNone); + } + } +#endif // RUY_PLATFORM(ARM) + +#if RUY_PLATFORM(X86) + // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / + // placeholder. Optimization is not finished. In particular the dimensions of + // the kernel blocks can be changed as desired. + // + if ((runtime_enabled_paths_ & Path::kSse42) != Path::kNone) { + if (!(HaveBuiltPathForSse42() && DetectCpuSse42())) { + runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kSse42; + // Sanity check. + RUY_DCHECK((runtime_enabled_paths_ & Path::kSse42) == Path::kNone); + } + } + + if ((runtime_enabled_paths_ & Path::kAvx2) != Path::kNone) { + if (!(HaveBuiltPathForAvx2() && DetectCpuAvx2())) { + runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvx2; + // Sanity check. + RUY_DCHECK((runtime_enabled_paths_ & Path::kAvx2) == Path::kNone); + } + } + + if ((runtime_enabled_paths_ & Path::kAvx512) != Path::kNone) { + if (!(HaveBuiltPathForAvx512() && DetectCpuAvx512())) { + runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvx512; + // Sanity check. + RUY_DCHECK((runtime_enabled_paths_ & Path::kAvx512) == Path::kNone); + } + } + + // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / + // placeholder. Optimization is not finished. In particular the dimensions of + // the kernel blocks can be changed as desired. + // + if ((runtime_enabled_paths_ & Path::kAvxVnni) != Path::kNone) { + if (!(HaveBuiltPathForAvxVnni() && DetectCpuAvxVnni())) { + runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvxVnni; + // Sanity check. + RUY_DCHECK((runtime_enabled_paths_ & Path::kAvxVnni) == Path::kNone); + } + } +#endif // RUY_PLATFORM(X86) + + // Sanity check. We can't possibly have disabled all paths, as some paths + // are universally available (kReference, kStandardCpp). 
+ RUY_DCHECK_NE(runtime_enabled_paths_, Path::kNone); + return runtime_enabled_paths_; +} + +} // namespace ruy diff --git a/ruy/context.h b/ruy/context.h new file mode 100644 index 0000000..330a7e7 --- /dev/null +++ b/ruy/context.h @@ -0,0 +1,109 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_ + +#include +#include +#include + +#include "ruy/allocator.h" +#include "ruy/path.h" +#include "ruy/prepacked_cache.h" +#include "ruy/thread_pool.h" +#include "ruy/trace.h" +#include "ruy/tune.h" + +namespace ruy { + +// The state private to each Ruy thread. +struct PerThreadState { + // Each thread may be running on a different microarchitecture. For example, + // some threads may be on big cores, while others are on little cores. Thus, + // it's best for the tuning to be per-thread. + TuningResolver tuning_resolver; + // Each thread has its own local allocator. + Allocator allocator; +}; + +// A Context holds runtime information used by Ruy. It holds runtime resources +// such as the workers thread pool and the allocator (which holds buffers for +// temporary data), as well as runtime options controlling which Paths are +// enabled (typically based on which instruction sets are detected) and how +// many threads to use. +struct Context final { + Path last_taken_path = Path::kNone; + Tuning explicit_tuning = Tuning::kAuto; + // TODO(benoitjacob) rename that thread_pool. Current name is gemmlowp legacy. + ThreadPool workers_pool; + int max_num_threads = 1; + // State for each thread in the thread pool. Entry 0 is the main thread. + std::vector> per_thread_states; + TracingContext tracing; + CachePolicy cache_policy = CachePolicy::kNoCache; + + Allocator* GetMainAllocator() { + if (!main_allocator_) { + main_allocator_.reset(new Allocator); + } + return main_allocator_.get(); + } + + PrepackedCache* GetPrepackedCache() { + if (!prepacked_cache_) { + prepacked_cache_.reset(new PrepackedCache); + } + return prepacked_cache_.get(); + } + + void ClearPrepackedCache() { prepacked_cache_ = nullptr; } + + void EnsureNPerThreadStates(int thread_count) { + while (per_thread_states.size() < static_cast(thread_count)) { + per_thread_states.emplace_back(new PerThreadState); + } + } + + Tuning GetMainThreadTuning() { + EnsureNPerThreadStates(1); + TuningResolver* tuning_resolver = &per_thread_states[0]->tuning_resolver; + tuning_resolver->SetTuning(explicit_tuning); + return tuning_resolver->Resolve(); + } + + template + Path GetPathToTake() { + last_taken_path = + GetMostSignificantPath(CompiledPaths & GetRuntimeEnabledPaths()); + return last_taken_path; + } + + void SetRuntimeEnabledPaths(Path paths); + Path GetRuntimeEnabledPaths(); + + private: + // Allocator for main thread work before invoking the threadpool. 
+ // Our simple Allocator does not allow reserving/allocating more blocks + // while it's already in committed state, so the main thread needs both + // this allocator, and its per-thread allocator. + std::unique_ptr main_allocator_; + std::unique_ptr prepacked_cache_; + Path runtime_enabled_paths_ = Path::kNone; +}; + +} // end namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_ diff --git a/ruy/context_test.cc b/ruy/context_test.cc new file mode 100644 index 0000000..c189030 --- /dev/null +++ b/ruy/context_test.cc @@ -0,0 +1,63 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/context.h" + +#include "testing/base/public/gunit.h" +#include "ruy/path.h" +#include "ruy/platform.h" + +namespace ruy { +namespace { + +TEST(ContextTest, EnabledPathsGeneral) { + ruy::Context ruy_context; + const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths(); + const auto ruy_paths_repeat = ruy_context.GetRuntimeEnabledPaths(); + ASSERT_EQ(ruy_paths, ruy_paths_repeat); + EXPECT_NE(ruy_paths, Path::kNone); + EXPECT_EQ(ruy_paths & Path::kReference, Path::kReference); + EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kStandardCpp); +} + +#if RUY_PLATFORM(X86) +TEST(ContextTest, EnabledPathsX86) { + ruy::Context ruy_context; + ruy_context.SetRuntimeEnabledPaths(Path::kSse42 | Path::kAvx2 | + Path::kAvx512 | Path::kAvxVnni); + const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths(); + EXPECT_EQ(ruy_paths & Path::kReference, Path::kNone); + EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kNone); +} +#endif // RUY_PLATFORM(X86) + +#if RUY_PLATFORM(ARM) +TEST(ContextTest, EnabledPathsArm) { + ruy::Context ruy_context; + ruy_context.SetRuntimeEnabledPaths(Path::kNeon | Path::kNeonDotprod); + const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths(); + EXPECT_EQ(ruy_paths & Path::kReference, Path::kNone); + EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kNone); + EXPECT_EQ(ruy_paths & Path::kNeon, Path::kNeon); +} +#endif // RUY_PLATFORM(ARM) + +} // namespace +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/cpu_cache_size.h b/ruy/cpu_cache_size.h new file mode 100644 index 0000000..82f41cc --- /dev/null +++ b/ruy/cpu_cache_size.h @@ -0,0 +1,81 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_ + +#include "ruy/path.h" +#include "ruy/platform.h" + +namespace ruy { + +// LocalDataCacheSize returns a sane default size for each CPU core's local +// data cache, i.e. the largest data cache that is local to that CPU core, not +// shared with other cores. That allows coarse tuning of code that aims for +// most of its memory accesses to hit such a typically fast data cache. +// +// SharedDataCacheSize returns a sane default size of the total data cache +// accessible to each CPU, including any shared cache. +// +// For example, if we design tune this code for a ARM Cortex-A55 with a local L1 +// cache of 32k, a local L2 cache of 128k and a shared L3 cache of 1M, +// LocalDataCacheSize should return 128k and SharedDataCacheSize +// should return 1M. +// +// Ideally these values would be queried at runtime, and we should probably +// do that on x86, but that is hard to do on ARM. +#if RUY_PLATFORM(ARM_64) +inline int LocalDataCacheSize() { return 1 << 15; } +inline int SharedDataCacheSize() { return 1 << 19; } +#elif RUY_PLATFORM(ARM_32) +inline int LocalDataCacheSize() { return 1 << 14; } +inline int SharedDataCacheSize() { return 1 << 18; } +#elif RUY_PLATFORM(X86) +inline int LocalDataCacheSize() { return 1 << 17; } +inline int SharedDataCacheSize() { return 1 << 21; } +#else +inline int LocalDataCacheSize() { return 1 << 14; } +inline int SharedDataCacheSize() { return 1 << 18; } +#endif +// Variants taking a Path argument which acts +// as a hint telling whether we're targeting more or less recent/powerful CPUs. +inline int LocalDataCacheSize(Path path) { +#if RUY_PLATFORM(ARM_64) + if (path == Path::kNeonDotprod) { + // At the moment, the smallest CPU with dotprod is probably Cortex-A55 with + // 128k L2 local cache. + return 1 << 17; + } +#else + (void)path; +#endif + return LocalDataCacheSize(); +} +inline int SharedDataCacheSize(Path path) { +#if RUY_PLATFORM(ARM_64) + if (path == Path::kNeonDotprod) { + // At the moment, the smallest CPU with dotprod is probably Cortex-A55 with + // 1M L3 shared cache. + return 1 << 20; + } +#else + (void)path; +#endif + return SharedDataCacheSize(); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_ diff --git a/ruy/detect_arm.cc b/ruy/detect_arm.cc new file mode 100644 index 0000000..85f7156 --- /dev/null +++ b/ruy/detect_arm.cc @@ -0,0 +1,73 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* Detection of dotprod instructions on ARM. + * The current Linux-specific code relies on sufficiently new Linux kernels: + * At least Linux 4.15 in general; on Android, at least Linux 4.14.111 thanks to + * a late backport. 
This was backported just before the Android 10 release, so + * this is leaving out pre-release Android 10 builds as well as earlier Android + * versions. + * + * It is possible to detect instructions in other ways that don't rely on + * an OS-provided feature identification mechanism: + * + * (A) We used to have a SIGILL-handler-based method that worked at least + * on Linux. Its downsides were (1) crashes on a few devices where + * signal handler installation didn't work as intended; (2) additional + * complexity to generalize to other Unix-ish operating systems including + * iOS; (3) source code complexity and fragility of anything installing + * and restoring signal handlers; (4) confusing behavior under a debugger. + * + * (B) We also experimented with a fork-ing approach where a subprocess + * tries the instruction. Compared to (A), this is much simpler and more + * reliable and portable, but also much higher latency on Android where + * an uncaught signal typically causes a 100 ms latency. + * + * Should there be interest in either technique again in the future, + * code implementing both (A) and (B) can be found in earlier revisions of this + * file - in actual code for (A) and in a comment for (B). + */ + +#include "ruy/detect_arm.h" + +#if defined __linux__ && defined __aarch64__ +#include +#endif + +namespace ruy { + +namespace { + +#if defined __linux__ && defined __aarch64__ +bool DetectDotprodByLinuxAuxvMethod() { + // This is the value of HWCAP_ASIMDDP in sufficiently recent Linux headers, + // however we need to support building against older headers for the time + // being. + const int kLocalHwcapAsimddp = 1 << 20; + return getauxval(AT_HWCAP) & kLocalHwcapAsimddp; +} +#endif + +} // namespace + +bool DetectDotprod() { +#if defined __linux__ && defined __aarch64__ + return DetectDotprodByLinuxAuxvMethod(); +#endif + + return false; +} + +} // namespace ruy diff --git a/ruy/detect_arm.h b/ruy/detect_arm.h new file mode 100644 index 0000000..9a1542d --- /dev/null +++ b/ruy/detect_arm.h @@ -0,0 +1,29 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Temporary dotprod-detection code until we can rely on getauxval. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_ + +namespace ruy { + +// On A64, returns true if the dotprod extension is present. +// On other architectures, returns false unconditionally. +bool DetectDotprod(); + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_ diff --git a/ruy/detect_x86.cc b/ruy/detect_x86.cc new file mode 100644 index 0000000..ded37b1 --- /dev/null +++ b/ruy/detect_x86.cc @@ -0,0 +1,101 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/detect_x86.h" + +#include + +#if RUY_PLATFORM(X86) && RUY_PLATFORM(X86_ENHANCEMENTS) +#include // IWYU pragma: keep + +#endif + +namespace ruy { +#if RUY_PLATFORM(X86) && RUY_PLATFORM(X86_ENHANCEMENTS) + +namespace { + +// See Intel docs, such as http://goo.gl/c6IkGX. +inline void RunCpuid(std::uint32_t eax, std::uint32_t ecx, + std::uint32_t abcd[4]) { + std::uint32_t ebx, edx; +#if defined(__i386__) && defined(__PIC__) + /* in case of PIC under 32-bit EBX cannot be clobbered */ + asm volatile("movl %%ebx, %%edi \n\t cpuid \n\t xchgl %%ebx, %%edi" + : "=D"(ebx), +#else + asm volatile("cpuid" + : "+b"(ebx), +#endif + "+a"(eax), "+c"(ecx), "=d"(edx)); + abcd[0] = eax; + abcd[1] = ebx; + abcd[2] = ecx; + abcd[3] = edx; +} + +} // namespace + +bool DetectCpuSse42() { + std::uint32_t abcd[4]; + + constexpr std::uint32_t kEcxSse42 = 1u << 20; + RunCpuid(1, 0, abcd); + const bool has_sse4_2_base = (abcd[2] & kEcxSse42) == kEcxSse42; + +#ifdef RUY_ENABLE_AMD_CPUID_CHECKS + constexpr std::uint32_t kEcxAbm = 1u << 5; + RunCpuid(0x80000001, 0, abcd); + const bool has_extras = (abcd[2] & kEcxAbm) == kEcxAbm; +#else + constexpr std::uint32_t kEcxPopcnt = 1u << 23; + RunCpuid(1, 0, abcd); + const bool has_extras = (abcd[2] & kEcxPopcnt) == kEcxPopcnt; +#endif + + return has_sse4_2_base && has_extras; +} + +bool DetectCpuAvx2() { + constexpr std::uint32_t kEbxAvx2 = 1u << 5; + constexpr std::uint32_t kEcxFma = 1u << 12; + + std::uint32_t abcd[4]; + + RunCpuid(7, 0, abcd); + const bool has_avx2 = (abcd[1] & kEbxAvx2) == kEbxAvx2; + RunCpuid(1, 0, abcd); + const bool has_fma = (abcd[2] & kEcxFma) == kEcxFma; + + return has_avx2 && has_fma; +} + +bool DetectCpuAvx512() { + constexpr std::uint32_t kEbxAvx512F = 1u << 16; + constexpr std::uint32_t kEbxAvx512Dq = 1u << 17; + constexpr std::uint32_t kEbxAvx512Cd = 1u << 28; + constexpr std::uint32_t kEbxAvx512Bw = 1u << 30; + constexpr std::uint32_t kEbxAvx512Vl = 1u << 31; + + constexpr std::uint32_t kEbxAvx512Mask = + kEbxAvx512F | kEbxAvx512Dq | kEbxAvx512Cd | kEbxAvx512Bw | kEbxAvx512Vl; + std::uint32_t abcd[4]; + RunCpuid(7, 0, abcd); + + return (abcd[1] & kEbxAvx512Mask) == kEbxAvx512Mask; +} + +#endif +} // namespace ruy diff --git a/ruy/detect_x86.h b/ruy/detect_x86.h new file mode 100644 index 0000000..fede7c7 --- /dev/null +++ b/ruy/detect_x86.h @@ -0,0 +1,49 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_ + +#include "ruy/platform.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +#if RUY_PLATFORM(X86_ENHANCEMENTS) + +// This also checks ABM support, which implies LZCNT and POPCNT. +bool DetectCpuSse42(); +bool DetectCpuAvx2(); +bool DetectCpuAvx512(); +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// TODO(b/146646451): Introduce and activate. +inline bool DetectCpuAvxVnni() { return false; } + +#else // RUY_PLATFORM(X86_ENHANCEMENTS) + +inline bool DetectCpuSse42() { return false; } +inline bool DetectCpuAvx2() { return false; } +inline bool DetectCpuAvx512() { return false; } +inline bool DetectCpuAvxVnni() { return false; } + +#endif // !RUY_PLATFORM(X86_ENHANCEMENTS) +#endif // RUY_PLATFORM(X86) + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_ diff --git a/ruy/dispatch.h b/ruy/dispatch.h new file mode 100644 index 0000000..2fd50d0 --- /dev/null +++ b/ruy/dispatch.h @@ -0,0 +1,482 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements the translation between Ruy's entry point (ruy::Mul) and +// the internal implementation of matrix multiplication. +// +// The primary elements of this dispatch are: +// - pick suitable gemm kernel and packing routines for the user-specified +// CompiledPaths based on the current CPU. +// - decide on the structure of the packed matrices needed by the internal +// implementation (see pack.h for more information on packing). +// - translate the Mul operation into TrMul (see trmul.h for why that is +// useful). This is done by changing the matrix Layout -- no matrix data is +// actually moved. +// +// This file is also factored to serve as a building block for the advanced API +// as well. +// +// This file also performs some checking of invariants to catch user errors. 
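For intuition, the Mul-to-TrMul translation mentioned in the comment above touches only the LHS's Layout: Transpose() flips layout.order and swaps rows/cols, and no matrix data is moved. Below is a minimal sketch of how DispatchMul (defined further down in this file) performs that translation, assuming a float LHS for concreteness; it omits the HandlePrepackedCaching step that the real code runs in between.

    // Sketch only: mirrors the tail of DispatchMul in this file.
    Matrix<float> transposed_lhs(lhs);   // copies the view (data pointer + Layout), not the data
    Transpose(&transposed_lhs);          // Order::kColMajor <-> kRowMajor, rows <-> cols swapped
    TrMulParams params;
    CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context,
                                          dst, the_path, &params);
    TrMul(&params, context);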
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_ + +#include +#include +#include // IWYU pragma: keep +#include + +#include "ruy/check_macros.h" +#include "ruy/common.h" +#include "ruy/context.h" +#include "ruy/internal_matrix.h" +#include "ruy/kernel.h" +#include "ruy/kernel_common.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/pack.h" +#include "ruy/pack_common.h" +#include "ruy/path.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/side_pair.h" +#include "ruy/size_util.h" +#include "ruy/spec.h" +#include "ruy/trmul.h" +#include "ruy/trmul_params.h" + +namespace ruy { + +// If the Spec's LayoutSupport covers only some special cases, +// this function enforces that the matrix multiplication at hand falls into +// that special case. +template +void EnforceLayoutSupport(const Layout& lhs_layout, const Layout& rhs_layout, + const Layout& dst_layout) { + if (Spec::kLayoutSupport == LayoutSupport::kRCC) { + RUY_DCHECK(IsRowMajor(lhs_layout)); + RUY_DCHECK(IsColMajor(rhs_layout)); + RUY_DCHECK(IsColMajor(dst_layout)); + } +} + +template +bool IsSymmetricZeroPoint(Scalar zero_point) { + return zero_point == SymmetricZeroPoint(); +} + +template +void CheckZeroPoint(Scalar zero_point) { + if (std::is_floating_point::value || + Spec::kZeroPointSupport == ZeroPointSupport::kSymmetric) { + RUY_DCHECK(IsSymmetricZeroPoint(zero_point)); + } +} + +template +void EnforceZeroPointSupport(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point, + DstScalar dst_zero_point) { + // If the Spec's ZeroPointSupport covers only some special cases, + // this function enforces that the matrix multiplication at hand falls into + // that special case. + CheckZeroPoint(lhs_zero_point); + CheckZeroPoint(rhs_zero_point); + CheckZeroPoint(dst_zero_point); + + // Guard against the case when both LHS and RHS zero_point's are equal to + // the minimum representable value. In that case, padding with zero_point + // values will generate the bad case for fast int8 kernels on NEON + // (pre-dotprod) which attempt to multiply-accumulate two pairs of int8 + // into a int16: this is safe except in the bad case -128*-128 + -128*-128. + // See b/131609283. This only affects the kNeon path but we ban this for all + // paths in order for ruy to have the same supported parameter space + // on all paths. + RUY_DCHECK(lhs_zero_point != std::numeric_limits::lowest() || + rhs_zero_point != std::numeric_limits::lowest()); +} + +template +void EnforceDstSpecSupport(const Spec& spec, DstScalar dst_zero_point) { + static_assert(std::is_same::value, ""); + if (!std::is_same::value) return; + + // If user is looking for the raw accumulator, zero_point and all the other + // dequantize fields don't make sense and should not be set. 
+ RUY_DCHECK_EQ(dst_zero_point, 0); + RUY_DCHECK_EQ(spec.clamp_max, std::numeric_limits::max()); + RUY_DCHECK_EQ(spec.clamp_min, std::numeric_limits::min()); + RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0); + RUY_DCHECK_EQ(spec.multiplier_exponent, 0); + RUY_DCHECK_EQ(spec.multiplier_fixedpoint_perchannel, nullptr); + RUY_DCHECK_EQ(spec.multiplier_exponent_perchannel, nullptr); +} + +inline bool IsColMajorTrMul(const TrMulParams& params) { + return IsColMajor(params.src[Side::kLhs].layout) && + IsColMajor(params.src[Side::kRhs].layout) && + IsColMajor(params.dst.layout); +} + +inline void CreatePackedLayout(const Layout& src, const Type& scalar, + const KernelLayout& kernel_layout, + PackedLayout* packed) { + packed->order = Order::kColMajor; + packed->rows = round_up_pot(src.rows, kernel_layout.rows); + packed->cols = round_up_pot(src.cols, kernel_layout.cols); + packed->kernel = kernel_layout; + int inner_size = packed->rows; + if (RUY_OPT_ENABLED(RUY_OPT_AVOID_ALIASING)) { + packed->stride = + (inner_size * scalar.size) % 1024 ? inner_size : inner_size + 64; + } else { + packed->stride = inner_size; + } +} + +template +void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout, + TrMulParams* params) { + // Ruy always uses 32-bit signed accumulators for quantized + // matrix multiplication, so we would like to always use std::int32_t + // unconditionally for SumsType. + // However, for floating point types, we still need a reasonable type here to + // avoid tripping assertions elsewhere in the code. + using SumsType = + typename std::conditional::value, Scalar, + std::int32_t>::type; + + const DMatrix& src = params->src[side]; + PMatrix* packed = ¶ms->packed[side]; + packed->data_type = Type::Create(); + packed->sums_type = Type::Create(); + CreatePackedLayout(src.layout, packed->data_type, kernel_layout, + &packed->layout); + packed->zero_point = Pack(src.zero_point); +} + +template +void PopulateTrMulParams(TrMulParams* params) { + static_assert((ThePath & Path::kReference) == Path::kNone, + "Path::kReference should not do TrMul"); + // The optimized code paths don't handle the full generality of Ruy's API. + // Fall back to Path::kStandardCpp if necessary. + bool fallback_to_standard_cpp = false; + if (ThePath != Path::kStandardCpp) { + // The optimized code paths currently only handle the case of all matrices + // being column major. + if (!IsColMajorTrMul(*params)) { + fallback_to_standard_cpp = true; + } + } + + if (fallback_to_standard_cpp) { + PopulateTrMulParams(params); + return; + } + + using PackedLhsScalar = PackedType; + using PackedRhsScalar = PackedType; + using Kernel = + Kernel; + using LhsKernelLayout = typename Kernel::LhsLayout; + using RhsKernelLayout = typename Kernel::RhsLayout; + + params->path = ThePath; + + params->local_data_cache_size = Spec::local_data_cache_size(); + params->shared_data_cache_size = Spec::shared_data_cache_size(); + + CreatePackedMatrix( + Side::kLhs, ToKernelLayout(), params); + CreatePackedMatrix( + Side::kRhs, ToKernelLayout(), params); + params->run_pack[Side::kLhs] = + &RunPack; + params->run_pack[Side::kRhs] = + &RunPack; + params->run_kernel = + &RunKernel; + + return; +} + +// PopulateTrMulParamsAllCompiledPaths calls into one of multiple +// instantiations of PopulateTrMulParams. For each bit that is set in +// CompiledPaths, it statically instantiates PopulateTrMulParams with a Path +// corresponding to that single bit. 
The call to PopulateTrMulParams is +// guarded by a runtime check that it is in fact the dynamically selected path. +// +// PopulateTrMulParamsAllCompiledPaths is implemented with template +// metaprogramming by mutual recursion between PathSearchCountdown and +// PathSearchCompiledPaths. +// +// PopulateTrMulParamsAllCompiledPaths is logically implementing the following +// computation: +// +// template +// void PopulateTrMulParamsAllCompiledPaths(Path the_path, +// TrMulParams* params) { +// for (int bit = 8 * sizeof(Path) - 1; bit != -1; bit--) { // [1] +// Path current_path = static_cast(1 << bit); +// if ((CompiledPaths & current_path) != Path::kNone) { // [2] +// if (current_path == the_path) { // [3] +// PopulateTrMulParams(the_path, params); +// return; +// } +// } +// } +// } +// +// +// +// [1] - Done by the main definition of PathSearchCountdown. The `bit--` is +// done in the recursion of PathSearchOnlyCompiledPaths. +// [2] - Done by PathSearchOnlyCompiledPaths's partial template +// specialization on InCompiledPaths. This is the check which necessitates +// doing the whole computation at C++ compile time. +// [3] - Done by the `if` in the main definition of +// PathSearchOnlyCompiledPaths. +// +// The template metaprogramming is necessary because: +// - In `PopulateTrMulParams`, current_path must be a C++ +// compile-time constant. +// - PopulateTrMulParamsAllCompiledPaths must not instantiate +// inner loops for paths that are not in CompiledPaths, since that can result in +// bogus instantiations which cause a compile time failure. +template +struct PathSearchCountdown; + +template +struct PathSearchOnlyCompiledPaths { + static constexpr Path kCurrentPath = static_cast(1 << BitNumber); + static void Search(Path the_path, TrMulParams* params) { + if (kCurrentPath == the_path) { + PopulateTrMulParams( + params); + return; + } + PathSearchCountdown::Search(the_path, params); + } +}; + +// Skip this iteration if CompiledPaths doesn't contain the specified path. +template +struct PathSearchOnlyCompiledPaths { + static void Search(Path the_path, TrMulParams* params) { + PathSearchCountdown::Search(the_path, params); + } +}; + +template +struct PathSearchCountdown { + static constexpr Path kCurrentPath = static_cast(1 << BitNumber); + static void Search(Path the_path, TrMulParams* params) { + PathSearchOnlyCompiledPaths< + CompiledPaths, (CompiledPaths & kCurrentPath) != Path::kNone, BitNumber, + LhsScalar, RhsScalar, DstScalar, Spec>::Search(the_path, params); + } +}; + +// Termination of the countdown. If the counter reaches -1, then we haven't +// found the specified path. +template +struct PathSearchCountdown { + static void Search(Path the_path, TrMulParams* params) { RUY_DCHECK(false); } +}; + +template +void PopulateTrMulParamsAllCompiledPaths(Path the_path, TrMulParams* params) { + return PathSearchCountdown::Search(the_path, + params); +} + +template +void CreateTrMulParams(const Matrix& lhs, + const Matrix& rhs, const Spec& spec, + Context* context, Matrix* dst, Path the_path, + TrMulParams* params) { + // Fill in the fields we already know. + params->src[Side::kLhs] = ToDMatrix(lhs); + params->src[Side::kRhs] = ToDMatrix(rhs); + params->dst = ToDMatrix(*dst); + params->spec = ToVoidPtr(&spec); + + // Create inner loops and packed matrices based on the Path. 
+ PopulateTrMulParamsAllCompiledPaths(the_path, params); +} + +template +void ReferenceMul(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, Matrix* dst) { + profiler::ScopeLabel label("ReferenceMul"); + for (int i = 0; i < lhs.layout.rows; i++) { + for (int j = 0; j < rhs.layout.cols; j++) { + using AccumScalar = typename Spec::AccumScalar; + AccumScalar accum = 0; + for (int k = 0; k < lhs.layout.cols; k++) { + AccumScalar lhs_val = Element(lhs, i, k); + AccumScalar rhs_val = Element(rhs, k, j); + accum += (lhs_val - lhs.zero_point) * (rhs_val - rhs.zero_point); + } + if (spec.bias) { + accum += spec.bias[i]; + } + ApplyMultiplier(spec, i, &accum); + accum += dst->zero_point; + accum = std::min(accum, spec.clamp_max); + accum = std::max(accum, spec.clamp_min); + *ElementPtr(dst, i, j) = static_cast(accum); + } + } +} + +// Compile-time dispatch to ReferenceMul. This allows us to statically ensure +// that there is no call to ReferenceMul in the user's binary. +template +struct CompileTimeEnabledReferenceMul { + template + static void Run(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, Matrix* dst) { + ReferenceMul(lhs, rhs, spec, dst); + } +}; + +// When this partial specialization is chosen, it ensures that ReferenceMul +// is never compiled. +template <> +struct CompileTimeEnabledReferenceMul { + template + static void Run(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, Matrix* dst) { + RUY_DCHECK(false); + } +}; + +inline void HandlePrepackedCaching(TrMulParams* params, + const SidePair& cacheable, + Context* context) { + if (context->cache_policy == CachePolicy::kNoCache) { + return; + } + + if (context->cache_policy == CachePolicy::kCacheLHSOnNarrowMul) { + // TODO(b/149304278) Cache on dst.cols <= selected kernel width. + if (!cacheable[Side::kLhs] || params->dst.layout.cols > 4) { + return; + } + PrepackedCache* prepacked_cache = context->GetPrepackedCache(); + auto cache_key = std::make_pair(reinterpret_cast(params->run_kernel), + params->src[Side::kLhs].data); + auto it = prepacked_cache->FindAndUpdate(cache_key); + if (it != prepacked_cache->cend()) { + params->packed[Side::kLhs].data = it->second.first.data; + params->packed[Side::kLhs].sums = it->second.first.sums; + params->is_prepacked[Side::kLhs] = true; + return; + } + + // Allocate the prepacked matrix. 
+ PrepackedMatrix prepacked_lhs; + prepacked_lhs.data_size = DataSize(params->packed[Side::kLhs]); + prepacked_lhs.sums_size = SumsSize(params->packed[Side::kLhs]); + prepacked_cache->AllocatePrepackedMatrix(&prepacked_lhs); + params->packed[Side::kLhs].data = prepacked_lhs.data; + params->packed[Side::kLhs].sums = prepacked_lhs.sums; + params->is_prepacked[Side::kLhs] = true; + Tuning tuning = context->GetMainThreadTuning(); + params->RunPack(Side::kLhs, tuning, 0, + params->packed[Side::kLhs].layout.cols); + prepacked_cache->Insert(cache_key, prepacked_lhs); + return; + } +} + +template +void DispatchMul(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, Context* context, Matrix* dst) { + static_assert(CompiledPaths != Path::kNone, "Must compile at least one Path"); + static_assert((CompiledPaths & ~kAllPaths) == Path::kNone, + "CompiledPaths must be a subset of ruy::kAllPaths"); + + profiler::ScopeLabel mul_label("Mul"); + profiler::ScopeLabel shape_specific_label("matmul shape: %dx%dx%d", + lhs.layout.rows, lhs.layout.cols, + rhs.layout.cols); + + EnforceLayoutSupport(lhs.layout, rhs.layout, dst->layout); + EnforceZeroPointSupport(lhs.zero_point, rhs.zero_point, + dst->zero_point); + EnforceDstSpecSupport(spec, dst->zero_point); + + // This should be a constant, for a given machine and CompiledPaths. + // There is a back door to override it for testing, but in production it will + // always be the "best" Path. I.e. the one with the newest SIMD instructions + // available on the present machine, and avoiding Path::kReference unless + // no other path is compiled. + // + // Unfortunately, it is not a *static* constant, since it depends on runtime + // detection of the available SIMD instructions. + Path the_path = context->GetPathToTake(); + + // Production code should probably never execute Path::kReference. + // Path::kReference implements a Mul, not a TrMul like the rest of Ruy, so if + // that's what we need to do, then get it out of the way before going down the + // TrMul path. + if (the_path == Path::kReference) { + constexpr bool ReferenceMulIsEnabled = + (CompiledPaths & Path::kReference) != Path::kNone; + CompileTimeEnabledReferenceMul::Run(lhs, rhs, spec, + dst); + return; + } + + // As described in the comment at the top of this file, Ruy internally + // converts Mul into TrMul. We handle that here. + // + // This is Ruy's main code path. + constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; + Matrix transposed_lhs(lhs); + Transpose(&transposed_lhs); + TrMulParams params; + CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, + the_path, ¶ms); + SidePair cacheable(lhs.cacheable, rhs.cacheable); + HandlePrepackedCaching(¶ms, cacheable, context); + TrMul(¶ms, context); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_ diff --git a/ruy/example.cc b/ruy/example.cc new file mode 100644 index 0000000..3b42c97 --- /dev/null +++ b/ruy/example.cc @@ -0,0 +1,136 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "ruy/ruy.h" + +void ExampleMulFloat(ruy::Context *context) { + const float lhs_data[] = {1, 2, 3, 4}; + const float rhs_data[] = {1, 2, 3, 4}; + float dst_data[4]; + + ruy::Matrix lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); + lhs.data = lhs_data; + ruy::Matrix rhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); + rhs.data = rhs_data; + ruy::Matrix dst; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); + dst.data = dst_data; + + ruy::BasicSpec spec; + ruy::Mul(lhs, rhs, spec, context, &dst); + + std::cout << "Example Mul, float:\n"; + std::cout << "LHS:\n" << lhs; + std::cout << "RHS:\n" << rhs; + std::cout << "Result:\n" << dst << "\n"; +} + +void ExampleMulFloatWithBiasAddAndClamp(ruy::Context *context) { + const float lhs_data[] = {1, 2, 3, 4}; + const float rhs_data[] = {1, 2, 3, 4}; + const float bias_data[] = {1, 0}; + float dst_data[4]; + + ruy::Matrix lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); + lhs.data = lhs_data; + ruy::Matrix rhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); + rhs.data = rhs_data; + ruy::Matrix dst; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); + dst.data = dst_data; + + ruy::BasicSpec spec; + spec.bias = bias_data; + spec.clamp_min = 0; + spec.clamp_max = 15; + ruy::Mul(lhs, rhs, spec, context, &dst); + + std::cout << "Example Mul, float with bias addition and clamp:\n"; + std::cout << "LHS:\n" << lhs; + std::cout << "RHS:\n" << rhs; + std::cout << "Result:\n" << dst << "\n"; +} + +void ExampleMulUint8AsymmetricQuantized(ruy::Context *context) { + const std::uint8_t lhs_data[] = {124, 125, 126, 127}; + const std::uint8_t rhs_data[] = {129, 130, 131, 132}; + std::uint8_t dst_data[4]; + + ruy::Matrix lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); + lhs.data = lhs_data; + lhs.zero_point = 125; + ruy::Matrix rhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); + rhs.data = rhs_data; + rhs.zero_point = 132; + ruy::Matrix dst; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); + dst.data = dst_data; + dst.zero_point = 129; + + ruy::BasicSpec spec; + spec.multiplier_fixedpoint = 1 << 30; + + spec.multiplier_exponent = 0; + ruy::Mul(lhs, rhs, spec, context, &dst); + + std::cout << "Example Mul, uint8 quantized with asymmetric zero points:\n"; + std::cout << "LHS:\n" << lhs; + std::cout << "RHS:\n" << rhs; + std::cout << "Result:\n" << dst << "\n"; +} +void ExampleMulInt8PerChannelQuantized(ruy::Context *context) { + const std::int8_t lhs_data[] = {1, 2, 3, 4}; + const std::int8_t rhs_data[] = {1, 2, 3, 4}; + const std::int32_t multiplier_data[] = {3 << 28, 5 << 28}; + const int exponent_data[] = {1, -2}; + std::int8_t dst_data[4]; + + ruy::Matrix lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); + lhs.data = lhs_data; + ruy::Matrix rhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); + rhs.data = rhs_data; + ruy::Matrix dst; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); + dst.data = dst_data; + + ruy::BasicSpec spec; + spec.multiplier_fixedpoint_perchannel = multiplier_data; + spec.multiplier_exponent_perchannel = exponent_data; + ruy::Mul(lhs, rhs, spec, context, &dst); + + std::cout << "Example Mul, 
int8 quantized with per-channel multipliers\n"; + std::cout << "LHS:\n" << lhs; + std::cout << "RHS:\n" << rhs; + std::cout << "Result:\n" << dst << "\n"; +} + +int main() { + ruy::Context context; + ExampleMulFloat(&context); + ExampleMulFloatWithBiasAddAndClamp(&context); + ExampleMulUint8AsymmetricQuantized(&context); + ExampleMulInt8PerChannelQuantized(&context); +} diff --git a/ruy/example_advanced.cc b/ruy/example_advanced.cc new file mode 100644 index 0000000..9041bdb --- /dev/null +++ b/ruy/example_advanced.cc @@ -0,0 +1,83 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "ruy/ruy_advanced.h" + +// Simple allocator for allocating pre-packed matrices. +class SimpleAllocator { + public: + void* AllocateBytes(std::size_t num_bytes) { + char* p = new char[num_bytes]; + buffers_.emplace_back(p); + return static_cast(p); + } + + private: + std::vector> buffers_; +}; + +void ExamplePrepack(ruy::Context* context) { + const float lhs_data[] = {1, 2, 3, 4}; + const float rhs_data[] = {1, 2, 3, 4}; + float dst_data[4]; + + // Set up the matrix layouts and spec. + ruy::Matrix lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); + ruy::Matrix rhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); + ruy::Matrix dst; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); + ruy::BasicSpec spec; + + SimpleAllocator allocator; + auto alloc_fn = [&allocator](std::size_t num_bytes) -> void* { + return allocator.AllocateBytes(num_bytes); + }; + + // In this example, we pre-pack only the RHS, but either will work. + // Note that we only need to set the data pointer for the matrix we are + // pre-packing. + ruy::PrepackedMatrix prepacked_rhs; + rhs.data = rhs_data; + ruy::PrePackForMul(lhs, rhs, spec, context, &dst, + /*prepacked_lhs=*/nullptr, &prepacked_rhs, + alloc_fn); + + // No data will be read from the RHS input matrix when using a pre-packed RHS. + rhs.data = nullptr; + lhs.data = lhs_data; + dst.data = dst_data; + ruy::MulWithPrepacked(lhs, rhs, spec, context, &dst, + /*prepacked_lhs=*/nullptr, + &prepacked_rhs); + rhs.data = rhs_data; + + // Print out the results. + std::cout << "Example Mul with pre-packing RHS, float:\n"; + std::cout << "LHS:\n" << lhs; + std::cout << "RHS:\n" << rhs; + std::cout << "Result:\n" << dst << "\n"; +} + +int main() { + ruy::Context context; + ExamplePrepack(&context); +} diff --git a/ruy/have_built_path_for.h b/ruy/have_built_path_for.h new file mode 100644 index 0000000..8913965 --- /dev/null +++ b/ruy/have_built_path_for.h @@ -0,0 +1,32 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_ + +#include "ruy/platform.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +bool HaveBuiltPathForSse42(); +bool HaveBuiltPathForAvx2(); +bool HaveBuiltPathForAvx512(); +bool HaveBuiltPathForAvxVnni(); +#endif // RUY_PLATFORM(X86) + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_ diff --git a/ruy/have_built_path_for_avx2.cc b/ruy/have_built_path_for_avx2.cc new file mode 100644 index 0000000..ceca8a4 --- /dev/null +++ b/ruy/have_built_path_for_avx2.cc @@ -0,0 +1,35 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/have_built_path_for.h" +#include "ruy/opt_set.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +// IMPORTANT: +// These patterns must match those in the pack and kernel cc files. +#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +bool HaveBuiltPathForAvx2() { return false; } + +#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +bool HaveBuiltPathForAvx2() { return true; } + +#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) +#endif // RUY_PLATFORM(X86) + +} // namespace ruy diff --git a/ruy/have_built_path_for_avx512.cc b/ruy/have_built_path_for_avx512.cc new file mode 100644 index 0000000..15fba62 --- /dev/null +++ b/ruy/have_built_path_for_avx512.cc @@ -0,0 +1,35 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/have_built_path_for.h" +#include "ruy/opt_set.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +// IMPORTANT: +// These patterns must match those in the pack and kernel cc files. 
+#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +bool HaveBuiltPathForAvx512() { return false; } + +#else // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +bool HaveBuiltPathForAvx512() { return true; } + +#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) +#endif // RUY_PLATFORM(X86) + +} // namespace ruy diff --git a/ruy/have_built_path_for_avxvnni.cc b/ruy/have_built_path_for_avxvnni.cc new file mode 100644 index 0000000..68ef2a2 --- /dev/null +++ b/ruy/have_built_path_for_avxvnni.cc @@ -0,0 +1,39 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/have_built_path_for.h" +#include "ruy/opt_set.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +// IMPORTANT: +// These patterns must match those in the pack and kernel cc files. +#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +bool HaveBuiltPathForAvxVnni() { return false; } + +#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +bool HaveBuiltPathForAvxVnni() { return true; } + +#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) +#endif // RUY_PLATFORM(X86) + +} // namespace ruy diff --git a/ruy/have_built_path_for_sse42.cc b/ruy/have_built_path_for_sse42.cc new file mode 100644 index 0000000..2141b75 --- /dev/null +++ b/ruy/have_built_path_for_sse42.cc @@ -0,0 +1,39 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/have_built_path_for.h" +#include "ruy/opt_set.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +// IMPORTANT: +// These patterns must match those in the pack and kernel cc files. +#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +bool HaveBuiltPathForSse42() { return false; } + +#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. 
+// +bool HaveBuiltPathForSse42() { return true; } + +#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) +#endif // RUY_PLATFORM(X86) + +} // namespace ruy diff --git a/ruy/internal_matrix.h b/ruy/internal_matrix.h new file mode 100644 index 0000000..7fe13be --- /dev/null +++ b/ruy/internal_matrix.h @@ -0,0 +1,388 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Internal types and helpers for matrices. +// +// Ruy has a couple slightly different notions of matrices, besides the +// Matrix class that we expose to the user-facing API. +// +// TODO(silvasean): Put parts of this architecture description somewhere more +// prominent. +// +// The 4 main matrix types are: +// - Matrix: This is a user-facing type on Ruy's external API boundary. It is +// also used internally. +// - DMatrix: This is a type-erased version of Matrix. "D" = "dynamic". +// - PMatrix: This represents a packed matrix, which requires tracking kernel +// layout and row/column sums for quantization. It is type-erased. +// - PackedMatrix: This is a statically typed variant of PMatrix for +// convenience inside typed routines. +// +// Note that Matrix is *not* implemented in terms of the internal types. It +// is an independent, simple, and user-facing type. +// +// The use of type-erasure might seem surprising for a library like Ruy with a +// heavily-templated entry point, but it is motivated by the desire for most of +// Ruy's "middle-end" to be non-templated. Ruy can be thought of as having 3 +// main parts: +// - "front-end" (dispatch.h) - this is the highly templated ruy::Mul entry +// point, along with routines that select RunKernel and RunPack implementations +// statically based on those template parameters. +// - "back-end" (kernel.h, pack.h)- this consists of the implementations of +// RunKernel and RunPack, often in assembly code, which are the building blocks +// that Ruy calls to perform matrix multiplication. These are templated so that +// only the requested types/Path's are actually emitted by the compiler. +// - "middle-end" (trmul.h) - this is the part of Ruy that orchestrates the +// calls to the "back-end" optimized building blocks. This layer has to deal +// with issues like cache locality and low-overhead multi-threading. +// +// There is a desire for the "middle-end" to be non-templated in order to +// simplify the implementation and reduce code-size. We type-erase when going +// from the "front-end" to the "middle-end", and un-type-erase going from the +// "middle-end" to the "back-end". The un-type-erasure is possible because the +// "front-end" is responsible for instantiating the needed "back-end" templates, +// and thus the static type information is still present. 
+// +// Each layer of Ruy uses matrix types: +// - "front-end": Matrix +// - "middle-end": DMatrix, PMatrix +// - "back-end": Matrix, PackedMatrix +// +// The use of separate types for packed matrices is not essential, but makes it +// obvious at a glance whether a matrix is a packed matrix or not. We would +// reconsider this decision if there was significant duplication between packed +// and unpacked matrices, but that doesn't seem to be the case at the moment. +// +// Another goal is to keep the user-facing Matrix as simple and +// understandable as possible. Ideally, a user should be able to read the struct +// definition for Matrix and see a very simple definition with no internal +// details like sums and kernel block layout. +// +// To present another structured view of our various matrix types, here's a +// table: +// Plain matrices Packed matrices +// +---------------------------------- +// Templated | Matrix PackedMatrix +// Type-erased | DMatrix PMatrix +// +// +// There is 1 additional matrix type not mentioned above, due to its low +// importance: +// - PrepackedMatrix: This is a user-facing version of PMatrix. It has the bare +// minimum of fields needed for representing the raw data and sums buffers of a +// packed matrix for the "advanced" explicit pre-packing API. This type plays no +// role in Ruy's internals and can generally by ignored. The only reason it +// exists is so that PMatrix is not exposed to users -- we prefer to keep the +// internal matrix types hidden, even from "advanced" users. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_ + +#include +#include +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/common.h" +#include "ruy/matrix.h" +#include "ruy/size_util.h" + +namespace ruy { + +// KernelLayout describes small-scale block structure in a packed matrix layout. +// It's a runtime (as opposed to compile-time-constant) version of the +// FixedKernelLayout struct used to declare kernel layouts. +// +// This is is sometimes known as "tiling" in other contexts. +// +// For example, consider a packed matrix in column-major format with a +// column-major KernelLayout. The matrix logically has a shape of +// `[cols, rows]`. However, the matrix is laid out as though it were a 4D array +// of shape `[cols / kcols, rows / krows, kcols, krows]`. +// +// Note that in the case of kcols=1, krows=1, this degenerates to +// `[cols, rows, 1, 1]` which is equivalent to having no small-scale block +// structure. +struct KernelLayout { + Order order = Order::kColMajor; + std::uint8_t rows = 1; + std::uint8_t cols = 1; +}; + +// A packed matrix has a small-scale block structure that is not present in in +// the input matrices. This block structure is necessary for the kernels to +// process data efficiently. +// +// This struct is very similar to Layout, but has the extra KernelLayout field. +struct PackedLayout { + std::int32_t rows = 0; + std::int32_t cols = 0; + // Stride is the offset between two adjacent matrix elements + // in the non-contiguous direction. + std::int32_t stride = 0; + Order order = Order::kColMajor; + // Small scale layout shuffling, potentially departing from + // linear row-major or column-major storage. See KernelLayout. + KernelLayout kernel; +}; + +// Dynamic representation for a type. +// +// The most important field in this struct is the size, which Ruy uses to know +// how much memory to allocate without having to be templated on a type. 
+// Signed-ness and floating-point-ness are mainly present as debugging checks. +// +// Note: Ruy does not use this struct to to dynamically dispatch between +// different typed implementations. As described in the comment at the top of +// this file, Ruy's "front-end", which is templated, instantiates all the +// necessary "back-end" routines with complete static knowledge of all the +// types. +struct Type { + template + static Type Create() { + Type ret; + ret.is_signed = std::is_signed::value; + ret.is_floating_point = std::is_floating_point::value; + ret.size = sizeof(T); + return ret; + } + + template + void AssertIs() const { + RUY_DCHECK_EQ(is_signed, Create().is_signed); + RUY_DCHECK_EQ(is_floating_point, Create().is_floating_point); + RUY_DCHECK_EQ(size, Create().size); + } + + bool is_signed = false; + bool is_floating_point = false; + std::uint8_t size = 0; +}; + +// Type-erased matrix. +struct DMatrix { + Type data_type; + void* data = nullptr; + Layout layout; + std::int32_t zero_point = 0; +}; + +// Type-erased packed matrix. +struct PMatrix { + Type data_type; + void* data = nullptr; + Type sums_type; + void* sums = nullptr; + PackedLayout layout; + std::int32_t zero_point = 0; +}; + +// Convenient typed helper for packed matrices. +template +struct PackedMatrix { + // The row/column sums needed for quantized matrix multiplication when + // the opposite operand of the multiplication uses a non-symmetric zero + // point. + // This member is only relevant for packed matrices. + // Additionally, Ruy always uses 32-bit signed accumulators for quantized + // matrix multiplication. + // For floating point types, there is no quantization, so this pointer + // will always be null. We still need code referencing it to compile + // though, even if it is always branched around. Hence we use Scalar* + // itself as the type in that case. + using SumsType = + typename std::conditional::value, Scalar, + std::int32_t>::type; + + Scalar* data = nullptr; + SumsType* sums = nullptr; + PackedLayout layout; + std::int32_t zero_point = 0; +}; + +template +DMatrix ToDMatrix(const Matrix& matrix) { + DMatrix ret; + ret.data_type = Type::Create(); + ret.data = ToVoidPtr(matrix.data.get()); + ret.layout = matrix.layout; + ret.zero_point = matrix.zero_point; + return ret; +} + +template +Matrix ToMatrix(const DMatrix& dmatrix) { + dmatrix.data_type.AssertIs(); + Matrix ret; + ret.data = static_cast(dmatrix.data); + ret.layout = dmatrix.layout; + ret.zero_point = dmatrix.zero_point; + return ret; +} + +template +PackedMatrix ToPackedMatrix(const PMatrix& pmatrix) { + using SumsType = typename PackedMatrix::SumsType; + pmatrix.data_type.AssertIs(); + pmatrix.sums_type.AssertIs(); + PackedMatrix ret; + ret.data = static_cast(pmatrix.data); + ret.sums = static_cast(pmatrix.sums); + ret.layout = pmatrix.layout; + ret.zero_point = pmatrix.zero_point; + return ret; +} + +// Helpers for Layout / PackedLayout. + +inline bool IsPacked(const Layout& layout) { + if (layout.order == Order::kColMajor) { + return layout.stride == layout.rows; + } else { + return layout.stride == layout.cols; + } +} + +inline bool IsRowMajor(const Layout& layout) { + return layout.order == Order::kRowMajor; +} + +template +inline bool IsColMajor(const LayoutOrPackedLayout& layout) { + return layout.order == Order::kColMajor; +} + +template +inline int FlatSize(const LayoutOrPackedLayout& layout) { + const int outerdim = + layout.order == Order::kColMajor ? 
layout.cols : layout.rows; + return layout.stride * outerdim; +} + +// TODO(b/130417400) add a unit test +inline int Offset(const Layout& layout, int row, int col) { + // TODO(benoitjacob) - should check this but this make the _slow tests take + // 5x longer. Find a mitigation like in Eigen with an 'internal' variant + // bypassing the check? + // RUY_DCHECK_GE(row, 0); + // RUY_DCHECK_GE(col, 0); + // RUY_DCHECK_LT(row, layout.rows); + // RUY_DCHECK_LT(col, layout.cols); + int row_stride = layout.order == Order::kColMajor ? 1 : layout.stride; + int col_stride = layout.order == Order::kRowMajor ? 1 : layout.stride; + return row * row_stride + col * col_stride; +} + +// TODO(b/130417400) add a unit test +inline int Offset(const PackedLayout& layout, int row, int col) { + RUY_DCHECK(is_pot(layout.kernel.rows)); + RUY_DCHECK(is_pot(layout.kernel.cols)); + int row_outer = row & ~(layout.kernel.rows - 1); + int col_outer = col & ~(layout.kernel.cols - 1); + int row_stride_outer = + layout.order == Order::kColMajor ? layout.kernel.cols : layout.stride; + int col_stride_outer = + layout.order == Order::kRowMajor ? layout.kernel.rows : layout.stride; + int offset_outer = + row_outer * row_stride_outer + col_outer * col_stride_outer; + int row_inner = row - row_outer; + int col_inner = col - col_outer; + int row_stride_inner = + layout.kernel.order == Order::kColMajor ? 1 : layout.kernel.cols; + int col_stride_inner = + layout.kernel.order == Order::kRowMajor ? 1 : layout.kernel.rows; + int offset_inner = + row_inner * row_stride_inner + col_inner * col_stride_inner; + return offset_outer + offset_inner; +} + +// Helpers for Matrix. + +template +const Scalar* ElementPtr(const Matrix& mat, int row, int col) { + return mat.data.get() + Offset(mat.layout, row, col); +} + +template +Scalar* ElementPtr(Matrix* mat, int row, int col) { + return mat->data.get() + Offset(mat->layout, row, col); +} + +template +Scalar Element(const Matrix& mat, int row, int col) { + return *ElementPtr(mat, row, col); +} + +// Helpers for PackedMatrix. +// Duplicated from Matrix, but the duplication seems acceptable. + +template +const Scalar* ElementPtr(const PackedMatrix& mat, int row, int col) { + return mat.data + Offset(mat.layout, row, col); +} + +template +Scalar* ElementPtr(PackedMatrix* mat, int row, int col) { + return mat->data + Offset(mat->layout, row, col); +} + +template +Scalar Element(const PackedMatrix& mat, int row, int col) { + return *ElementPtr(mat, row, col); +} + +// Helpers for PMatrix. + +inline std::size_t DataSize(const PMatrix& packed) { + return FlatSize(packed.layout) * packed.data_type.size; +} + +inline std::size_t SumsSize(const PMatrix& packed) { + // Packed matrices are only relevant for Ruy's TrMul implementations. For + // TrMul, the number of sums is always equal to the number of columns. + return packed.layout.cols * packed.sums_type.size; +} + +// Transpose helpers. + +inline void Transpose(Order* order) { + *order = *order == Order::kColMajor ? Order::kRowMajor : Order::kColMajor; +} + +inline void Transpose(Layout* layout) { + Transpose(&layout->order); + std::swap(layout->rows, layout->cols); +} + +template +inline void Transpose(Matrix* matrix) { + Transpose(&matrix->layout); +} + +// Helpers for KernelLayout. 
+ +template +KernelLayout ToKernelLayout() { + KernelLayout ret; + ret.order = FixedKernelLayout::kOrder; + ret.rows = FixedKernelLayout::kRows; + ret.cols = FixedKernelLayout::kCols; + return ret; +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_ diff --git a/ruy/kernel.h b/ruy/kernel.h new file mode 100644 index 0000000..d7930b4 --- /dev/null +++ b/ruy/kernel.h @@ -0,0 +1,31 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_ + +#include "ruy/platform.h" + +// IWYU pragma: begin_exports +#if RUY_PLATFORM(NEON) +#include "ruy/kernel_arm.h" +#elif RUY_PLATFORM(X86) +#include "ruy/kernel_x86.h" +#else +#include "ruy/kernel_common.h" +#endif +// IWYU pragma: end_exports + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_ diff --git a/ruy/kernel_arm.h b/ruy/kernel_arm.h new file mode 100644 index 0000000..408c23a --- /dev/null +++ b/ruy/kernel_arm.h @@ -0,0 +1,211 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_ + +#include +#include + +#include "ruy/common.h" +#include "ruy/internal_matrix.h" +#include "ruy/kernel_common.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/side_pair.h" +#include "ruy/size_util.h" +#include "ruy/spec.h" +#include "ruy/tune.h" + +namespace ruy { + +#if RUY_PLATFORM(NEON) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if RUY_PLATFORM(NEON_64) +void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params); +void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 4>& params); +#elif RUY_PLATFORM(NEON_32) +void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params); +void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 2>& params); +#endif +void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params); +void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params); +void Kernel8bitNeonDotprodOutOfOrder1Col(const KernelParams8bit<8, 8>& params); +void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params); + +#if RUY_PLATFORM(NEON_64) +template +struct Kernel> { + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + Tuning tuning = Tuning::kAuto; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, + const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, + int start_col, int end_row, int end_col, + Matrix* dst) const { + KernelParams8bit params; + MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, + dst, ¶ms); + if (dst->layout.cols == 1) { + Kernel8bitNeonOutOfOrder1Col(params); + return; + } + if (__builtin_expect(tuning == Tuning::kInOrder, true)) { + Kernel8bitNeonInOrder(params); + } else { + Kernel8bitNeonOutOfOrder(params); + } + } +}; +#endif + +#if RUY_PLATFORM(NEON_32) +template +struct Kernel> { + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + Tuning tuning = Tuning::kAuto; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, + const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, + int start_col, int end_row, int end_col, + Matrix* dst) const { + KernelParams8bit params; + MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, + dst, ¶ms); + if (dst->layout.cols == 1) { + Kernel8bitNeonOutOfOrder1Col(params); + return; + } + Kernel8bitNeonOutOfOrder(params); + } +}; +#endif + +#if RUY_PLATFORM(NEON_64) +template +struct Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, + const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, + int start_col, int end_row, int end_col, + Matrix* dst) const { + KernelParams8bit params; + MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, + dst, ¶ms); + if (dst->layout.cols == 1) { + Kernel8bitNeonDotprodOutOfOrder1Col(params); + } else if (__builtin_expect(tuning == Tuning::kInOrder, true)) { + Kernel8bitNeonDotprodInOrder(params); + } else { + Kernel8bitNeonDotprodOutOfOrder(params); + } + } +}; +#endif + +void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params); +void 
KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params); +void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params); +void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params); + +#if RUY_PLATFORM(NEON_64) +// A Float kernel for ARM64 Neon. +template <> +struct Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, int start_col, + int end_row, int end_col, Matrix* dst) const { + KernelParamsFloat params; + MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (__builtin_expect(tuning == Tuning::kInOrder, true)) { + KernelFloatNeonInOrder(params); + } else { + KernelFloatNeonOutOfOrder(params); + } + } +}; +#endif + +#if RUY_PLATFORM(NEON_32) +// A Float kernel for ARM32 Neon. +template <> +struct Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, int start_col, + int end_row, int end_col, Matrix* dst) const { + KernelParamsFloat<8, 4> params; + + MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, + end_col, dst, ¶ms); + + KernelFloat32NeonOutOfOrder(params); + } +}; +#endif + +// While the dotprod NEON extension does not concern floating-point arithmetic, +// its presence allows us to distinguish, in the in-order tuning case, between +// A53 and A55r1. TODO: should this be folded into tuning? +template <> +struct Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + using Base = + Kernel>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, int start_col, + int end_row, int end_col, Matrix* dst) const { + KernelParamsFloat params; + MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (__builtin_expect(tuning == Tuning::kInOrder, true)) { + KernelFloatNeonDotprodInOrder(params); + } else { + KernelFloatNeonOutOfOrder(params); + } + } +}; + +#endif // RUY_PLATFORM(NEON) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_ diff --git a/ruy/kernel_arm32.cc b/ruy/kernel_arm32.cc new file mode 100644 index 0000000..d537cfe --- /dev/null +++ b/ruy/kernel_arm32.cc @@ -0,0 +1,2499 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/kernel.h" +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +namespace ruy { + +#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +#define RUY_ASM_LABEL_STORE_UINT8 91 +#define RUY_ASM_LABEL_STORE_INT8 92 +#define RUY_ASM_LABEL_STORE_INT16 93 +#define RUY_ASM_LABEL_STORE_INT32 94 +#define RUY_ASM_LABEL_AFTER_STORE 99 + +#define RUY_OFFSET_LHS_BASE_PTR 0 +#define RUY_OFFSET_RHS_BASE_PTR 4 +#define RUY_OFFSET_DST_BASE_PTR 8 +#define RUY_OFFSET_BIAS 12 +#define RUY_OFFSET_START_ROW 16 +#define RUY_OFFSET_START_COL 20 +#define RUY_OFFSET_LAST_ROW 24 +#define RUY_OFFSET_LAST_COL 28 +#define RUY_OFFSET_DST_ROWS 32 +#define RUY_OFFSET_DST_COLS 36 +#define RUY_OFFSET_LHS_STRIDE 40 +#define RUY_OFFSET_RHS_STRIDE 44 +#define RUY_OFFSET_DST_STRIDE 48 +#define RUY_OFFSET_DEPTH 52 +#define RUY_OFFSET_CLAMP_MIN 56 +#define RUY_OFFSET_CLAMP_MAX 60 +#define RUY_OFFSET_FLAGS 64 + +#define RUY_STACK_OFFSET_SIZE 96 +#define RUY_STACK_OFFSET_DST_COL_PTR 0 +#define RUY_STACK_OFFSET_DST_PTR 16 +#define RUY_STACK_OFFSET_ROW 32 +#define RUY_STACK_OFFSET_COL 48 +#define RUY_STACK_OFFSET_LHS_COL_PTR 64 +#define RUY_STACK_OFFSET_RHS_COL_PTR 80 + +template +void CheckOffsetsInKernelParamsFloat32(const Params&) { + static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); + static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, ""); + static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, ""); + static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); + static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); + static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, ""); + static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); + static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); + static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, ""); + static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); + static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); + static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); + static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); + static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); + static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); + static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); +} + +// Float kernel for ARM32 out-of-order cores. +// Just like Float 64 version, except accumulate in to 8x4 block to only +// use 16 128-bit NEON registers. This is a "first pass" kernel and not +// tuned. It is meant to run on out-of-order CPUs like the Krait 400 or A9. +void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params) { + CheckOffsetsInKernelParamsFloat32(params); + profiler::ScopeLabel label( + "Kernel (kNeon, optimized for out-of-order cores)"); + + const float* lhs_ptr = params.lhs_base_ptr; + const float* rhs_ptr = params.rhs_base_ptr; + // In ARM32 NEON, there are 16 128-bit "q" registers. These registers are + // each composed of two 64-bit "d" registers. The asm kernel below has the + // following NEON register allocation: + // Registers q3 -- q10 are accumulators. During accumulation, + // q0 -- q2 (d0 -- d5) are used to load data from LHS and RHS. 
q0 and q1 + // are used to load a 8x1 block of LHS, and q2 is used to load a 1x4 block + // of RHS, like this: + + // Register layout in "q" registers: + // RHS 1x4 block + // /--------------------------\ + // |q2.s[0] ... q2.s[3] | + // \--------------------------/ + // LHS 8x1 block + // /---------------------\ /--------------------- \ + // | q0.s[0] | | q3.s[0] ... q9.s[0] | + // | ... | | ... ... | + // | q0.s[3] | | q3.s[3] q9.s[3] | + // | q1.s[0] | | q4.s[0] q10.s[0] | + // | ... | | ... ... ... | + // | q1.s[3] | | q4.s[3] .. q10.s[3] | + // \---------------------/ \--------------------------/ + // accumulators 8x4 block + // q11, q14, q15 currently unused. q12 and q13 are used to load + // parameters used for the post-accumulation part of the kernel. + // For completeness, here is the register layout in "d" registers: + // RHS 1x4 block + // /--------------------------\ + // |d4[0] ... d5[1] | + // \--------------------------/ + // LHS 8x1 block + // /---------------------\ /--------------------------\ + // | d0[0] | | d6[0] ... d18[0] | + // | ... | | ... ... | + // | d1[1] | | d7[1] d19[1] | + // | d2[0] | | d8[0] d20[0] | + // | ... | | ... ... ... | + // | d3[1] | | d9[1] ... d21[1] | + // \---------------------/ \--------------------------/ + // accumulators 8x4 block + asm volatile( +#define RUY_MAKE_ZERO(reg) "vmov.f32 " #reg ", #0.0\n" + + // clang-format off + + // Load the first 32 bytes of LHS and RHS data. + // Load q0, q1 + "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" + "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + // Load q2 + "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" + "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + // Clear accumulators. + RUY_MAKE_ZERO(q3) + RUY_MAKE_ZERO(q4) + RUY_MAKE_ZERO(q5) + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + + // r1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 1. + "mov r1, #1\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. 
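For reference, the accumulation loop below is a plain outer-product update: at each depth level it multiplies the 8x1 LHS slice held in q0/q1 by the 1x4 RHS slice held in q2 and adds the result into the 8x4 accumulator block q3 -- q10. The following minimal scalar C++ sketch shows that computation; the function and parameter names are illustrative, not ruy API.

    // Scalar model of the 8x4 float micro-kernel's inner loop: one rank-1
    // update of the 8x4 accumulator block per level of depth.
    void ReferenceAccumulate8x4(const float* lhs_panel,  // 8 floats per depth level
                                const float* rhs_panel,  // 4 floats per depth level
                                int depth, float acc[8][4]) {
      for (int d = 0; d < depth; ++d) {
        const float* lhs = lhs_panel + 8 * d;  // the 8x1 LHS slice (q0, q1)
        const float* rhs = rhs_panel + 4 * d;  // the 1x4 RHS slice (q2)
        for (int r = 0; r < 8; ++r) {
          for (int c = 0; c < 4; ++c) {
            acc[r][c] += lhs[r] * rhs[c];  // one vmla.f32 lane
          }
        }
      }
    }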
+ "1:\n" + + // Accumulation loop + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + "cmp r1, r2\n" + "beq 79f\n" + + "2:\n" + + "vmla.f32 q3, q0, d4[0]\n" + "vmla.f32 q5, q0, d4[1]\n" + "vmla.f32 q7, q0, d5[0]\n" + "vmla.f32 q9, q0, d5[1]\n" + "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS + + "vmla.f32 q4, q1, d4[0]\n" + "vmla.f32 q6, q1, d4[1]\n" + "vmla.f32 q8, q1, d5[0]\n" + "vmla.f32 q10, q1, d5[1]\n" + "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + "add r1, r1, #1\n" + "cmp r1, r2\n" + + "blt 2b\n" + + "79:\n" + + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last level of depth, for which the LHS + // and RHS data is already loaded. + + "vmla.f32 q3, q0, d4[0]\n" + "vmla.f32 q5, q0, d4[1]\n" + "vmla.f32 q7, q0, d5[0]\n" + "vmla.f32 q9, q0, d5[1]\n" + + "vmla.f32 q4, q1, d4[0]\n" + "vmla.f32 q6, q1, d4[1]\n" + "vmla.f32 q8, q1, d5[0]\n" + "vmla.f32 q10, q1, d5[1]\n" + + // End of accumulation. The registers q3 -- q10 contain the final + // float32 accumulator values of the current 8x8 destination block. + // We now have to compute the final values from these accumulators + // and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r1, r3\n" // Have we finished the last row? + + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "add r4, r4, r1, lsl #3\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "b 5f\n" + "4:\n" // Finished last row... + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + // Go back to first row + "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. 
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "add r10, r10, r1, lsl #2\n" + "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "mov %[lhs_ptr], r4\n" + "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "mov %[rhs_ptr], r5\n" + + // Load some parameters needed for the end work on current block. + "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Offset these base pointers as needed given the current row, col. + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "add r5, r1, r8, lsl #2\n" + + "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "it ne\n" + "movne r1, r5\n" + + // Load 8 bias values. + "vld1.32 {d24, d25, d26, d27}, [r1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into q0 -- q2, as we don't need q0 -- q2 anymore + // in the rest of the work on the current block. + // Load q0, q1 + "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + // Load q2 + "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "vadd.f32 q3, q3, q12\n" + "vadd.f32 q4, q4, q13\n" + "vadd.f32 q5, q5, q12\n" + "vadd.f32 q6, q6, q13\n" + "vadd.f32 q7, q7, q12\n" + "vadd.f32 q8, q8, q13\n" + "vadd.f32 q9, q9, q12\n" + "vadd.f32 q10, q10, q13\n" + + // Load the clamp_min, clamp_max bounds + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.32 q12, r2\n" // clamp_min + "vdup.32 q13, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.f32 q3, q3, q12\n" + "vmax.f32 q4, q4, q12\n" + "vmax.f32 q5, q5, q12\n" + "vmax.f32 q6, q6, q12\n" + "vmax.f32 q7, q7, q12\n" + "vmax.f32 q8, q8, q12\n" + "vmax.f32 q9, q9, q12\n" + "vmax.f32 q10, q10, q12\n" + + // Apply the clamp_max bound + "vmin.f32 q3, q3, q13\n" + "vmin.f32 q4, q4, q13\n" + "vmin.f32 q5, q5, q13\n" + "vmin.f32 q6, q6, q13\n" + "vmin.f32 q7, q7, q13\n" + "vmin.f32 q8, q8, q13\n" + "vmin.f32 q9, q9, q13\n" + "vmin.f32 q10, q10, q13\n" + + // Compute how much of the 8x4 block of destination values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x4, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #8\n" + "mov r5, #4\n" + "cmp r1, #8\n" + // Compute r1 = how many rows of the 8x4 block fit + "it gt\n" + "movgt r1, r3\n" + "cmp r2, #4\n" + // Compute r2 = how many cols of the 8x4 block fit + "it gt\n" + "movgt r2, r5\n" + + // Test if r1==8 && r2 == 4, i.e. if all of the 8x4 block fits. + "cmp r1, r3\n" + "it eq\n" + "cmpeq r2, r5\n" + // Yes, all of the 8x4 block fits, go to fast path. 
+ "beq 30f\n" + // Not all of the 8x4 block fits. + // Set (r3 address, r4 stride) to write to dst_tmp_buf + "mov r3, %[dst_tmp_buf]\n" + "mov r4, #32\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x4 block fits. + // Set (r3 address, r4 stride) to write directly to destination matrix. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r5\n" + "31:\n" + + // Write our float values to the destination described by + // (r3 address, r4 stride) + "vst1.32 {d6, d7, d8, d9}, [r3]\n" + "add r3, r3, r4\n" + RUY_MAKE_ZERO(q3) + RUY_MAKE_ZERO(q4) + "vst1.32 {d10, d11, d12, d13}, [r3]\n" + "add r3, r3, r4\n" + RUY_MAKE_ZERO(q5) + RUY_MAKE_ZERO(q6) + "vst1.32 {d14, d15, d16, d17}, [r3]\n" + "add r3, r3, r4\n" + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + "vst1.32 {d18, d19, d20, d21}, [r3]\n" + "add r3, r3, r4\n" + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + + // If all of the 8x4 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "mov r3, %[dst_tmp_buf]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r6, #0\n" + "50:\n" + "mov r5, #0\n" + "51:\n" + "ldr r10, [r3, r5, lsl #2]\n" + "str r10, [r4, r5, lsl #2]\n" + "add r5, r5, #1\n" + "cmp r5, r1\n" + "blt 51b\n" + "add r6, r6, #1\n" + "add r3, r3, #32\n" + "add r4, r4, r8\n" + // r2 = how many cols of the 8x4 block fit + "cmp r6, r2\n" + "blt 50b\n" + "41:\n" + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #32\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + // Reload some params --- we had used r3, r5, r10 for a few other things + // since the last time we had loaded them. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r8, r3\n" + + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add r8, r8, #8\n" + // Store new value of row + "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "b 21f\n" + "20:\n" + // Was already at end row. + // Move back to first row. + "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + // Move to the next column. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "add r4, r4, #4\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + // Increment dst_col_ptr by 4 * dst_stride (i.e. 
4 columns) + "add r1, r1, r8, lsl #2\n" + // Store dst_col_ptr + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + // Store dst_ptr + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" + + // r1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 1. + "mov r1, #1\n" + + "ble 1b\n" + + // Restore stack pointer. + "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + // clang-format on + : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) + : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) + // Clobber list must specify q registers (and not their constituent + // d registers). There is a (currently unexplained) slowdown if + // d registers are listed in the clobbers list. + : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", + "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q12", "q13"); +} + +#undef RUY_MAKE_ZERO +#undef RUY_STACK_OFFSET_SIZE +#undef RUY_STACK_OFFSET_DST_COL_PTR +#undef RUY_STACK_OFFSET_DST_PTR +#undef RUY_STACK_OFFSET_ROW +#undef RUY_STACK_OFFSET_COL +#undef RUY_STACK_OFFSET_LHS_COL_PTR +#undef RUY_STACK_OFFSET_RHS_COL_PTR + +#undef RUY_OFFSET_LHS_BASE_PTR +#undef RUY_OFFSET_RHS_BASE_PTR +#undef RUY_OFFSET_DST_BASE_PTR +#undef RUY_OFFSET_BIAS +#undef RUY_OFFSET_START_ROW +#undef RUY_OFFSET_START_COL +#undef RUY_OFFSET_LAST_ROW +#undef RUY_OFFSET_LAST_COL +#undef RUY_OFFSET_DST_ROWS +#undef RUY_OFFSET_DST_COLS +#undef RUY_OFFSET_LHS_STRIDE +#undef RUY_OFFSET_RHS_STRIDE +#undef RUY_OFFSET_DST_STRIDE +#undef RUY_OFFSET_DEPTH +#undef RUY_OFFSET_CLAMP_MIN +#undef RUY_OFFSET_CLAMP_MAX +#undef RUY_OFFSET_FLAGS + +#define RUY_OFFSET_BIAS 0 +#define RUY_OFFSET_LHS_SUMS 4 +#define RUY_OFFSET_RHS_SUMS 8 +#define RUY_OFFSET_LHS_BASE_PTR 12 +#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 16 +#define RUY_OFFSET_MULTIPLIER_EXPONENT 20 +#define RUY_OFFSET_RHS_BASE_PTR 24 +#define RUY_OFFSET_DST_BASE_PTR 28 +#define RUY_OFFSET_LHS_ZERO_POINT 32 +#define RUY_OFFSET_RHS_ZERO_POINT 36 +#define RUY_OFFSET_DST_ZERO_POINT 40 +#define RUY_OFFSET_PROD_ZP_DEPTH 44 +#define RUY_OFFSET_START_ROW 48 +#define RUY_OFFSET_START_COL 52 +#define RUY_OFFSET_LAST_ROW 56 +#define RUY_OFFSET_LAST_COL 60 +#define RUY_OFFSET_DST_ROWS 64 +#define RUY_OFFSET_DST_COLS 68 +#define RUY_OFFSET_LHS_STRIDE 72 +#define RUY_OFFSET_RHS_STRIDE 76 +#define RUY_OFFSET_DST_STRIDE 80 +#define RUY_OFFSET_DEPTH 84 +#define RUY_OFFSET_CLAMP_MIN 88 +#define RUY_OFFSET_CLAMP_MAX 92 +#define RUY_OFFSET_FLAGS 96 +#define RUY_OFFSET_DST_TYPE_ID 97 + +#define RUY_STACK_OFFSET_SIZE 96 +#define RUY_STACK_OFFSET_DST_COL_PTR 0 +#define RUY_STACK_OFFSET_DST_PTR 16 +#define RUY_STACK_OFFSET_ROW 32 +#define RUY_STACK_OFFSET_COL 48 +#define RUY_STACK_OFFSET_LHS_COL_PTR 64 +#define RUY_STACK_OFFSET_RHS_COL_PTR 80 + +template +void CheckOffsetsInKernelParams8bit(const Params&) { + static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT, + ""); + static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT, + ""); + static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT, + ""); + static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH, + ""); + static_assert(offsetof(Params, multiplier_fixedpoint) 
== + RUY_OFFSET_MULTIPLIER_FIXEDPOINT, + ""); + static_assert( + offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT, + ""); + static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); + static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); + static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); + static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, ""); + static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, ""); + static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); + static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); + static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); + static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); + static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); + static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); + static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); + static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); + static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); +} + +// Fast-int8 kernel, ported from ARM 64 version. +// Relevant target CPUs for this kernel include Krait 400 and A9, +// since these are 32-bit, out-of-order CPUs. +void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params) { + profiler::ScopeLabel label( + "Kernel (kNeon, optimized for out-of-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + + // The asm kernel below has the following NEON register allocation: + // + // q6 - q13 are 128-bit (4x32b) accumulators. + // During accumulation, d0 -- d7 are used to load int8 data from LHS and + // d8 -- d11 from RHS: + // int8 RHS 16x2 block + // /-----------------------------\ + // |d8.b[0-7] ..... d10.b[0-7]| + // | ... ... | + // |d9.b[0-7] ..... d11.b[0-7]| + // \-----------------------------/ + // int8 LHS 4x16 block + // /------------------------\ /-----------------------------\ + // |d0.b[0-7] ... d1.b[0-7] | | q6 ..... q10 | + // |d2.b[0-7] ... d3.b[0-7] | | q7 ..... q11 | + // (Reload d0, d1, d2, d3) + // |d0.b[0-7] ... d1.b[0-7] | | q8 ..... q12 | + // |d2.b[0-7] ... d3.b[0-7] | | q9 ..... q13 | + // \------------------------/ \-----------------------------/ + // 128-bit accumulators 4x2 block + // + // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING + // optimization for this kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n" + + // clang-format off + + // Load the first 64 bytes of LHS and RHS data. + "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" + // Clear accumulators. 
+ RUY_MAKE_ZERO(q6) + "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" + RUY_MAKE_ZERO(q11) + "vld1.8 {d10, d11}, [%[rhs_ptr]]!\n" + + "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + RUY_MAKE_ZERO(q12) + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + RUY_MAKE_ZERO(q13) + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + RUY_MAKE_ZERO(q14) + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" + RUY_MAKE_ZERO(q15) + "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + + + // r1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov r1, #16\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // r1 is how many levels of depth we have already loaded + // data for, r10 is the total depth. + "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + "cmp r1, r10\n" + "beq 79f\n" + + "2:\n" + + // Mult, mult-acc in to q14, q15, q2, q3 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q2, d0, d10\n" + + "vmull.s8 q15, d2, d8\n" + "vmull.s8 q3, d2, d10\n" + + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q2, d1, d11\n" + "vmlal.s8 q15, d3, d9\n" + "vmlal.s8 q3, d3, d11\n" + "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS + + // Then pairwise accumulate in to q6, q7, q10, q11 + "vpadal.s16 q6, q14\n" + "vpadal.s16 q7, q15\n" + "vpadal.s16 q10, q2\n" + "vpadal.s16 q11, q3\n" + + // Mult, mult-acc in to q14, q15, q2, q3 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q2, d0, d10\n" + + "vmull.s8 q15, d2, d8\n" + "vmull.s8 q3, d2, d10\n" + + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q2, d1, d11\n" + "vmlal.s8 q15, d3, d9\n" + "vmlal.s8 q3, d3, d11\n" + "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS + + // Then pairwise accumulate in to q8, q9, q12, q13 + "vpadal.s16 q8, q14\n" + "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" + "vpadal.s16 q9, q15\n" + "vpadal.s16 q12, q2\n" + "vpadal.s16 q13, q3\n" + + // Prefetch the next 64 bytes of LHS and RHS data. + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + // Each iteration of this loop advances by 16 levels of depth. 
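For reference, each vmull.s8 / vmlal.s8 / vpadal.s16 group above computes, per destination entry, an ordinary dot product of int8 values accumulated in int32; the 16-wide grouping and the later vpadd.i32 reductions only change the order of the additions. A scalar sketch, with illustrative names (not ruy API):

    #include <cstdint>

    // Scalar model of one accumulator of the 4x2 int8 kernel: a dot product
    // over 'depth' int8 values, accumulated in 32 bits. In the asm, 16 depth
    // levels are handled per iteration via widening multiplies (vmull.s8,
    // vmlal.s8) and pairwise accumulation (vpadal.s16).
    std::int32_t ReferenceDotProductInt8(const std::int8_t* lhs,
                                         const std::int8_t* rhs, int depth) {
      std::int32_t acc = 0;
      for (int d = 0; d < depth; ++d) {
        acc += static_cast<std::int32_t>(lhs[d]) *
               static_cast<std::int32_t>(rhs[d]);
      }
      return acc;
    }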
+ "add r1, r1, #16\n" + + // Loop termination condition + "cmp r1, r10\n" + + "blt 2b\n" + + "79:\n" + + // Mult, mult-acc in to q14, q15, q2, q3 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q2, d0, d10\n" + + "vmull.s8 q15, d2, d8\n" + "vmull.s8 q3, d2, d10\n" + + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q2, d1, d11\n" + "vmlal.s8 q15, d3, d9\n" + "vmlal.s8 q3, d3, d11\n" + "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS + + // Then pairwise accumulate in to q6, q7, q10, q11 + "vpadal.s16 q6, q14\n" + "vpadal.s16 q7, q15\n" + "vpadal.s16 q10, q2\n" + "vpadal.s16 q11, q3\n" + + // Mult, mult-acc in to q14, q15, q2, q3 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q2, d0, d10\n" + + "vmull.s8 q15, d2, d8\n" + "vmull.s8 q3, d2, d10\n" + + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q2, d1, d11\n" + "vmlal.s8 q15, d3, d9\n" + "vmlal.s8 q3, d3, d11\n" + + // Then pairwise accumulate in to q8, q9, q12, q13 + "vpadal.s16 q8, q14\n" + "vpadal.s16 q9, q15\n" + "vpadal.s16 q12, q2\n" + "vpadal.s16 q13, q3\n" + + + // All accumulation over depth done. q6 - q13 contain the 4x32b + // accumulators for the 4x2 final matrix. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 4x2 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // q6-q13 now contain 4 x 32b + "vpadd.i32 d0, d12, d13\n" + "vpadd.i32 d1, d14, d15\n" + "vpadd.i32 d2, d16, d17\n" + "vpadd.i32 d3, d18, d19\n" + "vpadd.i32 d4, d20, d21\n" + "vpadd.i32 d5, d22, d23\n" + "vpadd.i32 d6, d24, d25\n" + "vpadd.i32 d7, d26, d27\n" + + // d0-d7 each contain 2 x 32b accumulators. + // Need to add pairwise to get 1 x 32b for each of the 4x2 entries + // of destination, (Four 'd' registers total) + "vpadd.i32 d28, d0, d1\n" + "vpadd.i32 d29, d2, d3\n" + "vpadd.i32 d30, d4, d5\n" + "vpadd.i32 d31, d6, d7\n" + + //Now d28 - d31 have the 1 x 32b accumulators for the 4x2 entries + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r1, r3\n" // Have we finished the last row? + + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "add r4, r4, r1, lsl #2\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "b 5f\n" + "4:\n" // Finished last row... + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + // Go back to first row + "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. 
The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "add r10, r10, r1, lsl #1\n" + "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "mov %[lhs_ptr], r4\n" + "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "mov %[rhs_ptr], r5\n" + + // Now we load: bias data, LHS sums data, RHS sums data. + + // First, load the base pointers from the params. + "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Offset these base pointers as needed given the current row, col. + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "add r5, r1, r8, lsl #2\n" + + "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "it ne\n" + "movne r1, r5\n" + + // Load 4 bias values. + "vld1.32 {d24, d25}, [r1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + // Add to the bias values the product + // (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in + // https://arxiv.org/pdf/1712.05877.pdf + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "vdup.32 q9, r3\n" + "vadd.i32 q12, q12, q9\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
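For reference, the bias addition here together with the rhs_sums/lhs_sums subtractions applied just below implements the zero-point correction of equation (7) in the paper cited above. In scalar form it amounts to the following sketch; the names are illustrative, not ruy API.

    #include <cstdint>

    // Turns a raw int8 dot-product accumulator acc = sum_d lhs[d] * rhs[d]
    // into bias + sum_d (lhs[d] - lhs_zero_point) * (rhs[d] - rhs_zero_point),
    // per equation (7) in https://arxiv.org/pdf/1712.05877.pdf.
    // lhs_row_sum / rhs_col_sum are the precomputed sums over depth of this
    // row of LHS / this column of RHS.
    std::int32_t ApplyZeroPointCorrection(std::int32_t acc, std::int32_t bias,
                                          std::int32_t depth,
                                          std::int32_t lhs_zero_point,
                                          std::int32_t rhs_zero_point,
                                          std::int32_t lhs_row_sum,
                                          std::int32_t rhs_col_sum) {
      acc += bias;
      acc += depth * lhs_zero_point * rhs_zero_point;  // prod_zp_depth term
      acc -= lhs_zero_point * rhs_col_sum;             // the vmls.i32 with rhs_sums
      acc -= rhs_zero_point * lhs_row_sum;             // the vsub.s32 with lhs_sums
      return acc;
    }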
+ "vadd.i32 q14, q14, q12\n" + "vadd.i32 q15, q15, q12\n" + + // LHS/RHS zero points + // Has RHS sums + "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + // Offset by current col * number of bytes per value + "add r3, r3, r4, lsl #2\n" + "vld1.32 { d12 }, [r3]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "vdup.32 q10, r5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "vmls.i32 q14, q10, d12[0]\n" + "vmls.i32 q15, q10, d12[1]\n" + "401:\n" + + // Has LHS sums + "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + // Offset by current row * number of bytes per value + "add r2, r2, r4, lsl #2\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + + // Load 4 lhs_sums values. + "vld1.32 {d22, d23}, [r2]\n" + "vdup.32 d13, r5\n" // rhs_zero_point + + // Compute lhs_sums * rhs_zero_point. + "vmul.i32 q11, q11, d13[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "vsub.s32 q14, q14, q11\n" + "vsub.s32 q15, q15, q11\n" + + // If the destination is int32, it means the user asks for the raw + // accumulators, no need for us to downquantize the value. + "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "add r5, r1, r4, lsl #2\n" + "it ne\n" + "movne r1, r5\n" + + "vld1.32 {q10}, [r1]\n" + + RUY_MAKE_ZERO(q8) + "vmax.s32 q12, q10, q8\n" + + "vshl.s32 q14, q14, q12\n" + "vshl.s32 q15, q15, q12\n" + + "vmin.s32 q12, q10, q8\n" + + // Load fixed point part of the multiplier + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + // r6 has flags, r4 has row + "add r5, r1, r4, lsl #2\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "it ne\n" + "movne r1, r5\n" + "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint + + // Apply the fixed-point part of the multiplier. + "vqrdmulh.s32 q14, q14, q10\n" + "vqrdmulh.s32 q15, q15, q10\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. 
They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. + "vand q8, q14, q12\n" + "vand q9, q15, q12\n" + "vshr.s32 q8, q8, #31\n" + "vshr.s32 q9, q9, #31\n" + "vqadd.s32 q14, q14, q8\n" + "vqadd.s34 q15, q15, q9\n" + +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. + "vrshl.s32 q14, q14, q12\n" + "vrshl.s32 q15, q15, q12\n" + + "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + // Store uint8 values: + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in q14. + "vqmovn.s32 d28, q14\n" + "vqmovn.s32 d29, q15\n" + + // At this point, d12 -- d26, d30, d31 aren't used anymore for the + // current block, so we can start clearing these accumulators for the + // next block (next iteration of the main loop). + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q15) + + // Load the destination zero point into each of the 8 16-bit slots + // in a q register. + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.16 q13, r4\n" // dst_zero_point + + // Add the destination zero point + "vadd.i16 q14, q14, q13\n" + + // Cast-and-saturate from int16 to uint8 + // Now all 8 1-byte values are in d30. + "vqmovun.s16 d30, q14\n" + + // Load the clamp_min, clamp_max bounds + "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.8 d28, r2\n" // clamp_min + "vdup.8 d29, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.u8 d30, d30, d28\n" + // Apply the clamp_max bound + "vmin.u8 d30, d30, d29\n" + + // Compute how much of the 4x2 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x2 blocks along the boundaries that do + // not fit entirely. 
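For reference, this fit computation and the store sequence that follows it (repeated below for the int8, int16 and int32 destination types) have the following scalar shape; a sketch with illustrative names, not ruy API.

    #include <algorithm>
    #include <cstdint>
    #include <cstring>

    // Stores a computed 4x2 block of uint8 values into the destination,
    // handling blocks that straddle the right/bottom edges: the fast path
    // writes whole columns directly, the slow path goes through a temporary
    // buffer and copies only the rows/columns that fit.
    void StoreBlock4x2(const std::uint8_t block[4 * 2],  // column-major 4x2 block
                       std::uint8_t* dst, int dst_stride,
                       int rows_left, int cols_left) {
      const int rows_fit = std::min(rows_left, 4);  // r1 in the asm
      const int cols_fit = std::min(cols_left, 2);  // r2 in the asm
      if (rows_fit == 4 && cols_fit == 2) {
        // Fast path: the whole block fits; write it column by column.
        for (int c = 0; c < 2; ++c) {
          std::memcpy(dst + c * dst_stride, block + c * 4, 4);
        }
      } else {
        // Slow path: copy only the part that fits, element by element
        // (the asm stages the block in dst_tmp_buf first).
        for (int c = 0; c < cols_fit; ++c) {
          for (int r = 0; r < rows_fit; ++r) {
            dst[c * dst_stride + r] = block[c * 4 + r];
          }
        }
      }
    }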
+ + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + "cmp r2, #2\n" + // Compute r2 = how many cols of the 4x2 block fit + "it gt\n" + "movgt r2, r5\n" + + // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. + "cmp r1, r3\n" + "it eq\n" + "cmpeq r2, r5\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x2 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x2 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.8 {d30}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "mov r6, #0\n" + "50:\n" + "mov r8, #0\n" + "51:\n" + "ldrb r10, [r3, r8]\n" + "strb r10, [r4, r8]\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "add r6, r6, #1\n" + "add r3, r3, #4\n" + "add r4, r4, r5\n" + "cmp r6, r2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x2 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #1\n" + + "vst1.32 {d30[0]}, [r3]\n" + "add r4, r4, r5\n" + "mov r3, r4\n" + "vst1.32 {d30[1]}, [r3]\n" + + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #4\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q14) + RUY_MAKE_ZERO(q15) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + // Store int8 values: + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in q14. + "vqmovn.s32 d28, q14\n" + "vqmovn.s32 d29, q15\n" + + // At this point, d12 -- d26, d30, d31 aren't used anymore for the + // current block, so we can start clearing these accumulators for the + // next block (next iteration of the main loop). + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q15) + + // Load the destination zero point into each of the 8 16-bit slots + // in a q register. + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.16 q13, r4\n" // dst_zero_point + + // Add the destination zero point + "vadd.i16 q14, q14, q13\n" + + // Cast-and-saturate from int16 to int8 + // Now all 8 1-byte values are in d30. + "vqmovn.s16 d30, q14\n" + + // Load the clamp_min, clamp_max bounds + "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.8 d28, r2\n" // clamp_min + "vdup.8 d29, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.s8 d30, d30, d28\n" + // Apply the clamp_max bound + "vmin.s8 d30, d30, d29\n" + + // Compute how much of the 4x2 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x2 blocks along the boundaries that do + // not fit entirely. 
+ + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + "cmp r2, #2\n" + // Compute r2 = how many cols of the 4x2 block fit + "it gt\n" + "movgt r2, r5\n" + + // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. + "cmp r1, r3\n" + "it eq\n" + "cmpeq r2, r5\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x2 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x2 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.8 {d30}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "mov r6, #0\n" + "50:\n" + "mov r8, #0\n" + "51:\n" + "ldrb r10, [r3, r8]\n" + "strb r10, [r4, r8]\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "add r6, r6, #1\n" + "add r3, r3, #4\n" + "add r4, r4, r5\n" + "cmp r6, r2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x2 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #1\n" + + "vst1.32 {d30[0]}, [r3]\n" + "add r4, r4, r5\n" + "mov r3, r4\n" + "vst1.32 {d30[1]}, [r3]\n" + + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #4\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q14) + RUY_MAKE_ZERO(q15) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Load the destination zero point into each of the 4 32-bit slots + // in a q register. + "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.32 q13, r4\n" // dst_zero_point + // Add the destination zero point + "vadd.s32 q14, q14, q13\n" + "vadd.s32 q15, q15, q13\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in q14. + "vqmovn.s32 d28, q14\n" + "vqmovn.s32 d29, q15\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q15) + + // Load the clamp_min, clamp_max bounds + "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.16 q12, r2\n" // clamp_min + "vdup.16 q13, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.s16 q14, q14, q12\n" + // Apply the clamp_max bound + "vmin.s16 q14, q14, q13\n" + + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + + // Compute how much of the 4x2 block of destination 16-bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x2 blocks along the boundaries that do + // not fit entirely. 
+ + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + "cmp r2, #2\n" + // Compute r2 = how many cols of the 4x2 block fit + "it gt\n" + "movgt r2, r5\n" + + // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. + "cmp r1, r3\n" + "it eq\n" + "cmpeq r2, r5\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x2 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x2 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.16 {q14}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "mov r6, #0\n" + "50:\n" + "mov r8, #0\n" + "51:\n" + // Shift of offset register for half-word loads not allowed in A32, + // so we shift, load/store, then shift back r8. + "lsl r8, r8, #1\n" + "ldrh r10, [r3, r8]\n" + "strh r10, [r4, r8]\n" + "lsr r8, r8, #1\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "add r6, r6, #1\n" + "add r3, r3, #8\n" + "add r4, r4, r5\n" + "cmp r6, r2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x2 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #2\n" + + "vst1.16 {d28[0]}, [r3], r6\n" + "add r4, r4, r5\n" + "vst1.16 {d28[1]}, [r3], r6\n" + "vst1.16 {d28[2]}, [r3], r6\n" + "vst1.16 {d28[3]}, [r3], r6\n" + "mov r3, r4\n" + "vst1.16 {d29[0]}, [r3], r6\n" + "vst1.16 {d29[1]}, [r3], r6\n" + "vst1.16 {d29[2]}, [r3], r6\n" + "vst1.16 {d29[3]}, [r3], r6\n" + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #8\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q14) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // At this point, v20 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + // Clear accumulators. + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + + // Compute how much of the 4x2 block of destination 32 bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x4 blocks along the boundaries that do + // not fit entirely. 
+ + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + "cmp r2, #2\n" + // Compute r2 = how many cols of the 4x2 block fit + "it gt\n" + "movgt r2, r5\n" + + // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. + "cmp r1, r3\n" + "it eq\n" + "cmpeq r2, r5\n" + // Yes, all of the 4x2 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x2 block fits. + // Set (r3 address, r4 stride) to write to dst_tmp_buf + "mov r3, %[dst_tmp_buf]\n" + "mov r4, #16\n" + "b 31f\n" + + "30:\n" + // Yes, all of the 4x2 block fits. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // r3 address, r4 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r5\n" + + "31:\n" + + "vst1.32 {d28, d29}, [r3]\n" + "add r3, r3, r4\n" + "vst1.32 {d30, d31}, [r3]\n" + + // If all of the 4x2 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 4x2 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "mov r3, %[dst_tmp_buf]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r6, #0\n" + "50:\n" + "mov r5, #0\n" + "51:\n" + "ldr r10, [r3, r5, lsl #2]\n" + "str r10, [r4, r5, lsl #2]\n" + "add r5, r5, #1\n" + "cmp r5, r1\n" + "blt 51b\n" + "add r6, r6, #1\n" + "add r3, r3, #16\n" + "add r4, r4, r8\n" + // r2 = how many cols of the 8x4 block fit + "cmp r6, r2\n" + "blt 50b\n" + + "41:\n" + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #16\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r8, r3\n" + + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add r8, r8, #4\n" + // Store new value of row + "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "b 21f\n" + "20:\n" + // Was already at end row. + // Move back to first row. + "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + // Move to the next column. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "add r4, r4, #2\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + // Increment dst_col_ptr by 2 * dst_stride (i.e. 
2 columns) + "add r1, r1, r8, lsl #1\n" + // Store dst_col_ptr + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + // Store dst_ptr + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov r1, #16\n" + + "ble 1b\n" + + // Restore stack pointer. + "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + // clang-format on + + : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) + : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) + : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", + // Clobber list must specify q registers (and not their constituent + // d registers). There is a (currently unexplained) slowdown if + // d registers are listed in the clobbers list. + "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q12", "q13", "q14", "q15"); +} + +// Fast-int8 true "GEMV" kernel (RHS has 1 column). We assume the RHS +// is still packed as if it has two columns +void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 2>& params) { + profiler::ScopeLabel label( + "Kernel (kNeon, optimized for out-of-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + + // The asm kernel below has the following NEON register allocation: + // + // q6 - q13 are 128-bit (4x32b) accumulators. + // During accumulation, d0 -- d7 are used to load int8 data from LHS and + // d8 -- d11 from RHS: + // int8 RHS 16x1 block + // /------------\ + // | d8.b[0] | + // | ... | + // | d8.b[7] | + // | d9.b[0] | + // | ... | + // | d9.b[7] | + // \------------/ + // int8 LHS 4x16 block + // /-----------------------------------------\ /------------\ + // |d0.b[0] ... d0.b[7] d1.b[0] ... d1.b[7] | | q6 | + // |d2.b[0] ... d2.b[7] d3.b[0] ... d3.b[7] | | q7 | + // |d4.b[0] ... d4.b[7] d5.b[0] ... d5.b[7] | | q8 | + // |d6.b[0] ... d6.b[7] d7.b[0] ... d7.b[7] | | q9 | + // \-----------------------------------------/ \------------/ + // 128-bit accumulators 4x1 block + // + // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING + // optimization for this kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n" + + // clang-format off + + // Load the first 64 bytes of LHS and RHS data. + "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" + "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" + "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" + "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" + "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" + // Skip the other column and advance the pointer. 
+ "add %[rhs_ptr], %[rhs_ptr], #16\n" + + "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" + "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + + // Clear accumulators. + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q14) + RUY_MAKE_ZERO(q15) + + // r1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov r1, #16\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // r1 is how many levels of depth we have already loaded + // data for, r10 is the total depth. + "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + "cmp r1, r10\n" + "beq 79f\n" + + "2:\n" + + // Mult, mult-acc in to q14, q15 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q15, d2, d8\n" + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q15, d3, d9\n" + + // Then pairwise accumulate in to q6, q7 + "vpadal.s16 q6, q14\n" + "vpadal.s16 q7, q15\n" + + // Mult, mult-acc in to q14, q15 + "vmull.s8 q14, d4, d8\n" + "vmull.s8 q15, d6, d8\n" + "vmlal.s8 q14, d5, d9\n" + "vmlal.s8 q15, d7, d9\n" + + // Then pairwise accumulate in to q8, q9 + "vpadal.s16 q8, q14\n" + "vpadal.s16 q9, q15\n" + + + // Load the next 64 bytes of LHS and RHS data. + "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" + "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" + "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" + "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" + // Skip the other column and advance the pointer. + "add %[rhs_ptr], %[rhs_ptr], #16\n" + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + // Each iteration of this loop advances by 16 levels of depth. + "add r1, r1, #16\n" + + // Loop termination condition + "cmp r1, r10\n" + + "blt 2b\n" + + "79:\n" + + // Mult, mult-acc in to q14, q15 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q15, d2, d8\n" + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q15, d3, d9\n" + + // Then pairwise accumulate in to q6, q7 + "vpadal.s16 q6, q14\n" + "vpadal.s16 q7, q15\n" + + // Mult, mult-acc in to q14, q15 + "vmull.s8 q14, d4, d8\n" + "vmull.s8 q15, d6, d8\n" + "vmlal.s8 q14, d5, d9\n" + "vmlal.s8 q15, d7, d9\n" + + // Then pairwise accumulate in to q8, q9 + "vpadal.s16 q8, q14\n" + "vpadal.s16 q9, q15\n" + + // All accumulation over depth done. q6 - q9 contain the 4x32b + // accumulators for the 4x1 final matrix. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 4x2 block. 
We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // q6-q9 now contain 4 x 32b + "vpadd.i32 d0, d12, d13\n" + "vpadd.i32 d1, d14, d15\n" + "vpadd.i32 d2, d16, d17\n" + "vpadd.i32 d3, d18, d19\n" + + // d0-d4 each contain 2 x 32b accumulators. + // Need to add pairwise to get 1 x 32b for each of the 4x1 entries + // of destination, (Four 'd' registers total) + "vpadd.i32 d28, d0, d1\n" + "vpadd.i32 d29, d2, d3\n" + + // Now d28,d29 have the 1 x 32b accumulators for the 4x1 entries. + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r1, r3\n" // Have we finished the last row? + + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "add r4, r4, r1, lsl #2\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "b 5f\n" + "4:\n" // Finished last row... + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + // Go back to first row + "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "add r10, r10, r1, lsl #1\n" + "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "mov %[lhs_ptr], r4\n" + "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "mov %[rhs_ptr], r5\n" + + // Now we load: bias data, LHS sums data, RHS sums data. + + // First, load the base pointers from the params. + "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Offset these base pointers as needed given the current row, col. + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "add r5, r1, r8, lsl #2\n" + + "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "it ne\n" + "movne r1, r5\n" + + // Load 4 bias values. 
+ "vld1.32 {d24, d25}, [r1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" + "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" + "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" + "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" + // Skip the other column and advance the pointer. + "add %[rhs_ptr], %[rhs_ptr], #16\n" + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + // Add to the bias values the product + // (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in + // https://arxiv.org/pdf/1712.05877.pdf + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "vdup.32 q9, r3\n" + "vadd.i32 q12, q12, q9\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "vadd.i32 q14, q14, q12\n" + + // LHS/RHS zero points + // Has RHS sums + "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + // Offset by current col * number of bytes per value + "add r3, r3, r4, lsl #2\n" + "vld1.32 { d12 }, [r3]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "vdup.32 q10, r5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "vmls.i32 q14, q10, d12[0]\n" + "401:\n" + + // Has LHS sums + "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + // Offset by current row * number of bytes per value + "add r2, r2, r4, lsl #2\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + + // Load 4 lhs_sums values. + "vld1.32 {d22, d23}, [r2]\n" + "vdup.32 d13, r5\n" // rhs_zero_point + + // Compute lhs_sums * rhs_zero_point. + "vmul.i32 q11, q11, d13[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "vsub.s32 q14, q14, q11\n" + + // If the destination is int32, it means the user asks for the raw + // accumulators, no need for us to downquantize the value. + "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. 
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "add r5, r1, r4, lsl #2\n" + "it ne\n" + "movne r1, r5\n" + + "vld1.32 {q10}, [r1]\n" + + RUY_MAKE_ZERO(q8) + "vmax.s32 q12, q10, q8\n" + + "vshl.s32 q14, q14, q12\n" + + "vmin.s32 q12, q10, q8\n" + + // Load fixed point part of the multiplier + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + // r6 has flags, r4 has row + "add r5, r1, r4, lsl #2\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "it ne\n" + "movne r1, r5\n" + "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint + + // Apply the fixed-point part of the multiplier. + "vqrdmulh.s32 q14, q14, q10\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. + "vand q8, q14, q12\n" + "vshr.s32 q8, q8, #31\n" + "vqadd.s32 q14, q14, q8\n" + +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. + "vrshl.s32 q14, q14, q12\n" + + "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + // Store uint8 values: + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in d28. + "vqmovn.s32 d28, q14\n" + + // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the + // current block, so we can start clearing these accumulators for the + // next block (next iteration of the main loop). + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q15) + + // Load the destination zero point into each of the 8 16-bit slots + // in a q register. 
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.16 q13, r4\n" // dst_zero_point + + // Add the destination zero point + "vadd.i16 q14, q14, q13\n" + + // Cast-and-saturate from int16 to uint8 + "vqmovun.s16 d30, q14\n" + // At this point, we only need 4 8-bit values in the lower half + // of d30. + + + // Load the clamp_min, clamp_max bounds + "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.8 d28, r2\n" // clamp_min + "vdup.8 d29, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.u8 d30, d30, d28\n" + // Apply the clamp_max bound + "vmin.u8 d30, d30, d29\n" + + // Compute how much of the 4x1 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x1, there are some 4x1 blocks along the boundaries that do + // not fit entirely. + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x1 block fit + "it gt\n" + "movgt r1, r3\n" + + // Test if r1==4, i.e. if all of the 4x1 block fits. + "cmp r1, r3\n" + + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x1 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.8 {d30}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "50:\n" + "mov r8, #0\n" + "51:\n" + "ldrb r10, [r3, r8]\n" + "strb r10, [r4, r8]\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x1 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #1\n" + + "vst1.8 {d30[0]}, [r3], r6\n" + "vst1.8 {d30[1]}, [r3], r6\n" + "vst1.8 {d30[2]}, [r3], r6\n" + "vst1.8 {d30[3]}, [r3], r6\n" + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #4\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q14) + RUY_MAKE_ZERO(q15) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + // Store int8 values: + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in d28. + "vqmovn.s32 d28, q14\n" + + // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the + // current block, so we can start clearing these accumulators for the + // next block (next iteration of the main loop). + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q15) + + // Load the destination zero point into each of the 8 16-bit slots + // in a q register. 
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.16 q13, r4\n" // dst_zero_point + + // Add the destination zero point + "vadd.i16 q14, q14, q13\n" + + // Cast-and-saturate from int16 to int8 + "vqmovn.s16 d30, q14\n" + // At this point, we only need 4 8-bit values in the lower half + // of d30. + + // Load the clamp_min, clamp_max bounds + "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.8 d28, r2\n" // clamp_min + "vdup.8 d29, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.s8 d30, d30, d28\n" + // Apply the clamp_max bound + "vmin.s8 d30, d30, d29\n" + + // Compute how much of the 4x1 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x2 blocks along the boundaries that do + // not fit entirely. + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + // Test if r1==4 i.e. if all of the 4x1 block fits. + "cmp r1, r3\n" + + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x2 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x2 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.8 {d30}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "50:\n" + "mov r8, #0\n" + "51:\n" + "ldrb r10, [r3, r8]\n" + "strb r10, [r4, r8]\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x1 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #1\n" + + "vst1.8 {d30[0]}, [r3], r6\n" + "vst1.8 {d30[1]}, [r3], r6\n" + "vst1.8 {d30[2]}, [r3], r6\n" + "vst1.8 {d30[3]}, [r3], r6\n" + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #4\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q14) + RUY_MAKE_ZERO(q15) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Load the destination zero point into each of the 4 32-bit slots + // in a q register. + "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.32 q13, r4\n" // dst_zero_point + // Add the destination zero point + "vadd.s32 q14, q14, q13\n" + //"vadd.s32 q15, q15, q13\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in d28. + "vqmovn.s32 d28, q14\n" + + // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q15) + + // Load the clamp_min, clamp_max bounds + "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.16 d24, r2\n" // clamp_min + "vdup.16 d26, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.s16 d28, d28, d24\n" + // Apply the clamp_max bound + "vmin.s16 d28, d28, d26\n" + + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + + // Compute how much of the 4x1 block of destination 16-bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x1, there are some 4x1 blocks along the boundaries that do + // not fit entirely. + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x1 block fit + "it gt\n" + "movgt r1, r3\n" + + // Test if r1==4, i.e. if all of the 4x1 block fits. + "cmp r1, r3\n" + + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x1 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.16 {d28}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "50:\n" + "mov r8, #0\n" + "51:\n" + // Shift of offset register for half-word loads not allowed in A32, + // so we shift, load/store, then shift back r8. + "lsl r8, r8, #1\n" + "ldrh r10, [r3, r8]\n" + "strh r10, [r4, r8]\n" + "lsr r8, r8, #1\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x1 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #2\n" + + "vst1.16 {d28[0]}, [r3], r6\n" + "vst1.16 {d28[1]}, [r3], r6\n" + "vst1.16 {d28[2]}, [r3], r6\n" + "vst1.16 {d28[3]}, [r3], r6\n" + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #8\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q14) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // At this point, v20 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + // Clear accumulators. + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + + // Compute how much of the 4x1 block of destination 32 bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x4 blocks along the boundaries that do + // not fit entirely. 
+ + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + // Test if r1==4, i.e. if all of the 4x1 block fits. + "cmp r1, r3\n" + + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x1 block fits. + // Set (r3 address, r4 stride) to write to dst_tmp_buf + "mov r3, %[dst_tmp_buf]\n" + "mov r4, #16\n" + "b 31f\n" + + "30:\n" + // Yes, all of the 4x1 block fits. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // r3 address, r4 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r5\n" + + "31:\n" + + "vst1.32 {d28, d29}, [r3]\n" + + // If all of the 4x1 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 4x1 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "mov r3, %[dst_tmp_buf]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "50:\n" + "mov r5, #0\n" + "51:\n" + "ldr r10, [r3, r5, lsl #2]\n" + "str r10, [r4, r5, lsl #2]\n" + "add r5, r5, #1\n" + "cmp r5, r1\n" + "blt 51b\n" + + "41:\n" + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #16\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r8, r3\n" + + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add r8, r8, #4\n" + // Store new value of row + "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "b 21f\n" + "20:\n" + // Was already at end row. + // Move back to first row. + "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + // Move to the next column. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "add r4, r4, #2\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + // Increment dst_col_ptr by dst_stride (i.e. 1 column) + "add r1, r1, r8\n" + // Store dst_col_ptr + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + // Store dst_ptr + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? 
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov r1, #16\n" + + "ble 1b\n" + + // Restore stack pointer. + "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + // clang-format on + + : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) + : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) + : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", + // Clobber list must specify q registers (and not their constituent + // d registers). There is a (currently unexplained) slowdown if + // d registers are listed in the clobbers list. + "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q12", "q13", "q14", "q15"); +} + +#undef RUY_OFFSET_BIAS +#undef RUY_OFFSET_LHS_SUMS +#undef RUY_OFFSET_RHS_SUMS +#undef RUY_OFFSET_LHS_BASE_PTR +#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT +#undef RUY_OFFSET_MULTIPLIER_EXPONENT +#undef RUY_OFFSET_RHS_BASE_PTR +#undef RUY_OFFSET_DST_BASE_PTR +#undef RUY_OFFSET_LHS_ZERO_POINT +#undef RUY_OFFSET_RHS_ZERO_POINT +#undef RUY_OFFSET_DST_ZERO_POINT +#undef RUY_OFFSET_PROD_ZP_DEPTH +#undef RUY_OFFSET_START_ROW +#undef RUY_OFFSET_START_COL +#undef RUY_OFFSET_LAST_ROW +#undef RUY_OFFSET_LAST_COL +#undef RUY_OFFSET_DST_ROWS +#undef RUY_OFFSET_DST_COLS +#undef RUY_OFFSET_LHS_STRIDE +#undef RUY_OFFSET_RHS_STRIDE +#undef RUY_OFFSET_DST_STRIDE +#undef RUY_OFFSET_DEPTH +#undef RUY_OFFSET_CLAMP_MIN +#undef RUY_OFFSET_CLAMP_MAX +#undef RUY_OFFSET_FLAGS +#undef RUY_OFFSET_DST_TYPE_ID + +#undef RUY_STACK_OFFSET_SIZE +#undef RUY_STACK_OFFSET_DST_COL_PTR +#undef RUY_STACK_OFFSET_DST_PTR +#undef RUY_STACK_OFFSET_ROW +#undef RUY_STACK_OFFSET_COL +#undef RUY_STACK_OFFSET_LHS_COL_PTR +#undef RUY_STACK_OFFSET_RHS_COL_PTR + +#endif // RUY_PLATFORM(NEON_32) && (RUY_OPT_ENABLED(RUY_OPT_ASM) +} // namespace ruy diff --git a/ruy/kernel_arm64.cc b/ruy/kernel_arm64.cc new file mode 100644 index 0000000..38af032 --- /dev/null +++ b/ruy/kernel_arm64.cc @@ -0,0 +1,7835 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <cstdint>
+
+#include "ruy/common.h"
+#include "ruy/kernel.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+namespace ruy {
+
+#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+#define RUY_ASM_LABEL_STORE_UINT8 91
+#define RUY_ASM_LABEL_STORE_INT8 92
+#define RUY_ASM_LABEL_STORE_INT16 93
+#define RUY_ASM_LABEL_STORE_INT32 94
+#define RUY_ASM_LABEL_AFTER_STORE 99
+
+#define RUY_OFFSET_BIAS 0
+#define RUY_OFFSET_LHS_SUMS 8
+#define RUY_OFFSET_RHS_SUMS 16
+#define RUY_OFFSET_LHS_BASE_PTR 24
+#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 32
+#define RUY_OFFSET_MULTIPLIER_EXPONENT 40
+#define RUY_OFFSET_RHS_BASE_PTR 48
+#define RUY_OFFSET_DST_BASE_PTR 56
+#define RUY_OFFSET_LHS_ZERO_POINT 64
+#define RUY_OFFSET_RHS_ZERO_POINT 68
+#define RUY_OFFSET_DST_ZERO_POINT 72
+#define RUY_OFFSET_PROD_ZP_DEPTH 76
+#define RUY_OFFSET_START_ROW 80
+#define RUY_OFFSET_START_COL 84
+#define RUY_OFFSET_LAST_ROW 88
+#define RUY_OFFSET_LAST_COL 92
+#define RUY_OFFSET_DST_ROWS 96
+#define RUY_OFFSET_DST_COLS 100
+#define RUY_OFFSET_LHS_STRIDE 104
+#define RUY_OFFSET_RHS_STRIDE 108
+#define RUY_OFFSET_DST_STRIDE 112
+#define RUY_OFFSET_DEPTH 116
+#define RUY_OFFSET_CLAMP_MIN 120
+#define RUY_OFFSET_CLAMP_MAX 124
+#define RUY_OFFSET_FLAGS 128
+
+template <typename Params>
+void CheckOffsetsInKernelParams8bit(const Params&) {
+  static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT,
+                "");
+  static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT,
+                "");
+  static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT,
+                "");
+  static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH,
+                "");
+  static_assert(offsetof(Params, multiplier_fixedpoint) ==
+                    RUY_OFFSET_MULTIPLIER_FIXEDPOINT,
+                "");
+  static_assert(
+      offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT,
+      "");
+  static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, "");
+  static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, "");
+  static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, "");
+  static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, "");
+  static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, "");
+  static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, "");
+  static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, "");
+  static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, "");
+  static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, "");
+  static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, "");
+  static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, "");
+  static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, "");
+  static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, "");
+  static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, "");
+}
+
+// Fast-int8-trick kernel, similar to this production gemmlowp kernel:
+// NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L2296
+//
+// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
+// since these are 64-bit, out-of-order and without dotprod support.
+void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params) { + profiler::ScopeLabel label( + "Kernel (kNeon, optimized for out-of-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are int32 accumulators. + // During accumulation, v0 -- v3 are used to load int8 data from LHS and + // v4 -- v7 from RHS: + // + // int8 RHS 16x4 block + // /-----------------------------------------\ + // |v4.b[0] ... v7.b[0] | + // | ... ... | + // |v4.b[15] ... v7.b[15] | + // \-----------------------------------------/ + // int8 LHS 4x16 block + // /---------------------\ /-----------------------------------------\ + // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s | + // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s | + // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s | + // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s | + // \---------------------/ \-----------------------------------------/ + // int32 accumulators 4x4 block + // + // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING + // optimization for this kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 64 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov w1, #16\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + "smull v12.8h, v0.8b, v5.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "smull v14.8h, v2.8b, v5.8b\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. 
This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Some multiplications and 16-bit accumulation were already done above, + // so we start right away in the middle. + "sadalp v16.4s, v8.8h\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "smull v8.8h, v0.8b, v6.8b\n" + "sadalp v17.4s, v9.8h\n" + "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" + "smull v9.8h, v1.8b, v6.8b\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "smull v11.8h, v3.8b, v6.8b\n" + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v0.8b, v7.8b\n" + "sadalp v21.4s, v13.8h\n" + "smull v13.8h, v1.8b, v7.8b\n" + "sadalp v22.4s, v14.8h\n" + "smull v14.8h, v2.8b, v7.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + + "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" + + "smlal2 v12.8h, v0.16b, v7.16b\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "sadalp v24.4s, v8.8h\n" + "smull v8.8h, v0.8b, v4.8b\n" + "sadalp v25.4s, v9.8h\n" + "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" + "smull v9.8h, v1.8b, v4.8b\n" + "sadalp v26.4s, v10.8h\n" + "smull v10.8h, v2.8b, v4.8b\n" + "sadalp v27.4s, v11.8h\n" + "smull v11.8h, v3.8b, v4.8b\n" + "sadalp v28.4s, v12.8h\n" + "smull v12.8h, v0.8b, v5.8b\n" + "sadalp v29.4s, v13.8h\n" + "smull v13.8h, v1.8b, v5.8b\n" + "sadalp v30.4s, v14.8h\n" + "smull v14.8h, v2.8b, v5.8b\n" + "sadalp v31.4s, v15.8h\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + + + // Each iteration of this loop advances by 16 levels of depth. 
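+        // For reference, the arithmetic contributed by each group of 16
+        // levels of depth is, in scalar form (a sketch; i indexes the 4 LHS
+        // rows, j the 4 RHS columns):
+        //
+        //   for (int i = 0; i < 4; i++) {
+        //     for (int j = 0; j < 4; j++) {
+        //       for (int k = 0; k < 8; k++) {
+        //         // Each int16 lane of the smull/smlal2 pairs above holds
+        //         // exactly two int8*int8 products before sadalp widens it
+        //         // into the int32 accumulators.
+        //         std::int16_t pair = lhs[i][k] * rhs[j][k]
+        //                           + lhs[i][k + 8] * rhs[j][k + 8];
+        //         acc[i][j] += pair;
+        //       }
+        //     }
+        //   }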
+ "add w1, w1, #16\n" + + // Loop termination condition + "cmp w1, w12\n" + + "blt 2b\n" + + "79:\n" + + "sadalp v16.4s, v8.8h\n" + "smull v8.8h, v0.8b, v6.8b\n" + "sadalp v17.4s, v9.8h\n" + "smull v9.8h, v1.8b, v6.8b\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "smull v11.8h, v3.8b, v6.8b\n" + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v0.8b, v7.8b\n" + "sadalp v21.4s, v13.8h\n" + "smull v13.8h, v1.8b, v7.8b\n" + "sadalp v22.4s, v14.8h\n" + "smull v14.8h, v2.8b, v7.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + + "smlal2 v12.8h, v0.16b, v7.16b\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + + "sadalp v24.4s, v8.8h\n" + "sadalp v25.4s, v9.8h\n" + "sadalp v26.4s, v10.8h\n" + "sadalp v27.4s, v11.8h\n" + "sadalp v28.4s, v12.8h\n" + "sadalp v29.4s, v13.8h\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 4x4 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 4x4 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Reduce 32bit accumulators horizontally. + "addp v16.4s, v16.4s, v17.4s\n" + "addp v18.4s, v18.4s, v19.4s\n" + "addp v20.4s, v20.4s, v21.4s\n" + "addp v22.4s, v22.4s, v23.4s\n" + "addp v24.4s, v24.4s, v25.4s\n" + "addp v26.4s, v26.4s, v27.4s\n" + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise). + "addp v16.4s, v16.4s, v18.4s\n" + "addp v17.4s, v20.4s, v22.4s\n" + "addp v18.4s, v24.4s, v26.4s\n" + "addp v19.4s, v28.4s, v30.4s\n" + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. 
The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + RUY_MAKE_ZERO(v8) + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + "add x5, x4, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + + "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint + + // Now we load: bias data, LHS sums data, RHS sums data. + + // First, load the base pointers from the params. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + "add x5, x1, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 4 bias values. + "ld1 {v14.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "add v16.4s, v16.4s, v14.4s\n" + "add v17.4s, v17.4s, v14.4s\n" + "add v18.4s, v18.4s, v14.4s\n" + "add v19.4s, v19.4s, v14.4s\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[1]\n" + "mls v18.4s, v10.4s, v14.s[2]\n" + "mls v19.4s, v10.4s, v14.s[3]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. 
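+        // Taken together, the bias and zero-point work around this point
+        // adjusts each accumulator as in equation (7) of the paper cited
+        // above; in scalar form (a sketch, indices relative to the block):
+        //
+        //   acc[i][j] += bias[row + i]
+        //                + depth * lhs_zero_point * rhs_zero_point
+        //                - lhs_zero_point * rhs_sums[col + j]
+        //                - rhs_zero_point * lhs_sums[row + i];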
+ "mul v11.4s, v11.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v11.4s\n" + "sub v18.4s, v18.4s, v11.4s\n" + "sub v19.4s, v19.4s, v11.4s\n" + + // If the destination is int32, it means the user asks for the raw + // accumulators, no need for us to downquantize the value. + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, %x[row], lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ld1 {v14.4s}, [x1]\n" + + "smax v12.4s, v14.4s, v8.4s\n" + + "sshl v16.4s, v16.4s, v12.4s\n" + "sshl v17.4s, v17.4s, v12.4s\n" + "sshl v18.4s, v18.4s, v12.4s\n" + "sshl v19.4s, v19.4s, v12.4s\n" + + "smin v12.4s, v14.4s, v8.4s\n" + + // Apply the fixed-point part of the multiplier. + "sqrdmulh v16.4s, v16.4s, v15.4s\n" + "sqrdmulh v17.4s, v17.4s, v15.4s\n" + "sqrdmulh v18.4s, v18.4s, v15.4s\n" + "sqrdmulh v19.4s, v19.4s, v15.4s\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. + "and v8.16b, v16.16b, v12.16b\n" + "and v9.16b, v17.16b, v12.16b\n" + "and v14.16b, v18.16b, v12.16b\n" + "and v15.16b, v19.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v16.4s, v16.4s, v8.4s\n" + "sqadd v17.4s, v17.4s, v9.4s\n" + "sqadd v18.4s, v18.4s, v14.4s\n" + "sqadd v19.4s, v19.4s, v15.4s\n" +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. 
+ "srshl v16.4s, v16.4s, v12.4s\n" + "srshl v17.4s, v17.4s, v12.4s\n" + "srshl v18.4s, v18.4s, v12.4s\n" + "srshl v19.4s, v19.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + "sqxtun v16.8b, v16.8h\n" + "sqxtun2 v16.16b, v17.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #4\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. 
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[4], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[5], [x3], #1\n" + "st1 {v16.b}[6], [x3], #1\n" + "st1 {v16.b}[7], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[8], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[9], [x3], #1\n" + "st1 {v16.b}[10], [x3], #1\n" + "st1 {v16.b}[11], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[12], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[13], [x3], #1\n" + "st1 {v16.b}[14], [x3], #1\n" + "st1 {v16.b}[15], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + + // Cast-and-saturate from int16 to int8 + "sqxtn v16.8b, v16.8h\n" + "sqxtn2 v16.16b, v17.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. 
+ "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #4\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[4], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[5], [x3], #1\n" + "st1 {v16.b}[6], [x3], #1\n" + "st1 {v16.b}[7], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[8], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[9], [x3], #1\n" + "st1 {v16.b}[10], [x3], #1\n" + "st1 {v16.b}[11], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[12], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[13], [x3], #1\n" + "st1 {v16.b}[14], [x3], #1\n" + "st1 {v16.b}[15], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.4h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + "saddw v18.4s, v18.4s, v14.4h\n" + "saddw v19.4s, v19.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Load the clamp_min, clamp_max bounds + "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.8h, w2\n" // clamp_min + "dup v15.8h, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + "smax v17.8h, v17.8h, v14.8h\n" + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + "smin v17.8h, v17.8h, v15.8h\n" + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. 
+ // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + "str q17, [%[dst_tmp_buf], #16]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.h}[0], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v16.h}[1], [x3], #2\n" + "st1 {v16.h}[2], [x3], #2\n" + "st1 {v16.h}[3], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.h}[4], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v16.h}[5], [x3], #2\n" + "st1 {v16.h}[6], [x3], #2\n" + "st1 {v16.h}[7], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.h}[0], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v17.h}[1], [x3], #2\n" + "st1 {v17.h}[2], [x3], #2\n" + "st1 {v17.h}[3], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.h}[4], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v17.h}[5], [x3], #2\n" + "st1 {v17.h}[6], [x3], #2\n" + "st1 {v17.h}[7], [x3], #2\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #8\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // At this point, v20 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + "str q17, [%[dst_tmp_buf], #16]\n" + "str q18, [%[dst_tmp_buf], #32]\n" + "str q19, [%[dst_tmp_buf], #48]\n" + // Slow loop copying from dst_tmp_buf to dst. 
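+        // In scalar form, the copy below is roughly (a sketch; w1 and w2 are
+        // the numbers of rows and columns of the 4x4 block that actually fit,
+        // and the destination column stride x11 is in bytes):
+        //
+        //   const std::int32_t* tmp = dst_tmp_buf;
+        //   char* out = static_cast<char*>(dst_ptr);
+        //   for (int j = 0; j < w2; j++) {
+        //     std::memcpy(out, tmp, w1 * sizeof(std::int32_t));
+        //     tmp += 4;           // next column of the block in the buffer
+        //     out += dst_stride;  // next destination column
+        //   }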
+ "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #16\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v16.s}[1], [x3], #4\n" + "st1 {v16.s}[2], [x3], #4\n" + "st1 {v16.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v17.s}[1], [x3], #4\n" + "st1 {v17.s}[2], [x3], #4\n" + "st1 {v17.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v18.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v18.s}[1], [x3], #4\n" + "st1 {v18.s}[2], [x3], #4\n" + "st1 {v18.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v19.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v19.s}[1], [x3], #4\n" + "st1 {v19.s}[2], [x3], #4\n" + "st1 {v19.s}[3], [x3], #4\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #16\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + "smull v12.8h, v0.8b, v5.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "smull v14.8h, v2.8b, v5.8b\n" + "smull v15.8h, v3.8b, v5.8b\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #4\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #4\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. 
+ "mov w1, #16\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} + +// Similar to existing Kernel8bitNeonOutOfOrder but specialized for the case of +// RHS cols == 1. +// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75, +// since these are 64-bit, out-of-order and without dotprod support. +void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 4>& params) { + profiler::ScopeLabel label( + "Kernel (kNeon, optimized for out-of-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v19 are int32 accumulators. + // During accumulation, v0 -- v3 are used to load int8 data from LHS and + // v4 from RHS: + // + // int8 RHS 16x1 block + // /-----------\ + // |v4.b[0] | + // | ... | + // |v4.b[15] | + // \-----------/ + // int8 LHS 4x16 block + // /---------------------\ /-----------\ + // |v0.b[0] ... v0.b[15] | |v16.4s | + // |v1.b[0] ... v1.b[15] | |v17.4s | + // |v2.b[0] ... v2.b[15] | |v18.4s | + // |v3.b[0] ... v3.b[15] | |v19.4s | + // \---------------------/ \-----------/ + // int32 accumulators 4x1 block + // + // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING + // optimization for this kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 64 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "add %[rhs_ptr], %[rhs_ptr], #48\n" + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. 
+ "mov w1, #16\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Some multiplications and 16-bit accumulation were already done above, + // so we start right away in the middle. + "sadalp v16.4s, v8.8h\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "add %[rhs_ptr], %[rhs_ptr], #48\n" + "sadalp v17.4s, v9.8h\n" + "sadalp v18.4s, v10.8h\n" + "sadalp v19.4s, v11.8h\n" + + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + // Each iteration of this loop advances by 16 levels of depth. + "add w1, w1, #16\n" + + // Loop termination condition + "cmp w1, w12\n" + + "blt 2b\n" + + "79:\n" + + "sadalp v16.4s, v8.8h\n" + "sadalp v17.4s, v9.8h\n" + "sadalp v18.4s, v10.8h\n" + "sadalp v19.4s, v11.8h\n" + + // End of accumulation. The registers v16 -- v19 contain the final + // int32 accumulator values of the current 4x1 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 4x1 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Reduce 32bit accumulators horizontally. + "addp v16.4s, v16.4s, v17.4s\n" + "addp v18.4s, v18.4s, v19.4s\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise). + "addp v16.4s, v16.4s, v18.4s\n" + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. 
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + // (still multiply column stride by 4 due to packing) + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + RUY_MAKE_ZERO(v8) + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + "add x5, x4, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + + "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint + + // Now we load: bias data, LHS sums data, RHS sums data. + + // First, load the base pointers from the params. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + "add x5, x1, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 4 bias values. + "ld1 {v14.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "add %[rhs_ptr], %[rhs_ptr], #48\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
+ // (all four 32-bit accumulators are in v16 at this point) + "add v16.4s, v16.4s, v14.4s\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. + "mul v11.4s, v11.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + + // If the destination is int32, it means the user asks for the raw + // accumulators, no need for us to downquantize the value. + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, %x[row], lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ld1 {v14.4s}, [x1]\n" + + "smax v12.4s, v14.4s, v8.4s\n" + + "sshl v16.4s, v16.4s, v12.4s\n" + + "smin v12.4s, v14.4s, v8.4s\n" + + // Apply the fixed-point part of the multiplier. + "sqrdmulh v16.4s, v16.4s, v15.4s\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. 
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. + "and v8.16b, v16.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sqadd v16.4s, v16.4s, v8.4s\n" +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. + "srshl v16.4s, v16.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this instruction, all data is in lower half (64-bits) of v16 + "sqxtn v16.4h, v16.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + // Now all data is in the first 32-bits of v16 + "sqxtun v16.8b, v16.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + + // Compute how much of the 4x1 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x1, there are some 4x1 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x1 block fit + "csel w1, w1, w3, le\n" + + // Test if w1==4, i.e. if all of the 4x1 block fits. + "cmp w1, w3\n" + + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x1 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x1 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in the lower half (64 bits) of v16. 
+ "sqxtn v16.4h, v16.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + + // Cast-and-saturate from int16 to int8 + "sqxtn v16.8b, v16.8h\n" + // At this point, we only need 4 lowest 8-bit values in v16. + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x1 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + + // Test if w1==4, i.e. if all of the 4x1 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.4h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + // After this instruction, all data is in lower half of v16. + "sqxtn v16.4h, v16.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + // Load the clamp_min, clamp_max bounds + "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.8h, w2\n" // clamp_min + "dup v15.8h, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. 
+ "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.h}[0], [x3], #2\n" + "st1 {v16.h}[1], [x3], #2\n" + "st1 {v16.h}[2], [x3], #2\n" + "st1 {v16.h}[3], [x3], #2\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #8\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + + // Test if w1==4 i.e. if all of the 4x1 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.s}[0], [x3], #4\n" + "st1 {v16.s}[1], [x3], #4\n" + "st1 {v16.s}[2], [x3], #4\n" + "st1 {v16.s}[3], [x3], #4\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #16\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. 
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #4\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #4\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov w1, #16\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19"); +} + +// Variant of the above Kernel8bitNeonOutOfOrder, tuned for in-order CPUs. +// Specifically here, the relevant in-order CPUs are ARM Cortex-A53 and +// the original Cortex-A55, since these are 64-bit and do not support dotprod. +// +// While this kernel does not have a direct equivalent in gemmlowp, it was +// developed based on insights that David Mansell at ARM shared with their +// contribution of gemmlowp kernels tuned for Cortex-A53, with very helpful +// comments. Specifically, see this comment about tuning for Cortex-A53: +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215 +void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params) { + profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are int32 accumulators. + // During accumulation, v0 -- v3 are used to load int8 data from LHS and + // v4 -- v7 from RHS: + // + // int8 RHS 16x4 block + // /-----------------------------------------\ + // |v4.b[0] ... v7.b[0] | + // | ... ... | + // |v4.b[15] ... v7.b[15] | + // \-----------------------------------------/ + // int8 LHS 4x16 block + // /---------------------\ /-----------------------------------------\ + // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s | + // |v1.b[0] ... v1.b[15] | |v17.4s ... 
v29.4s | + // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s | + // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s | + // \---------------------/ \-----------------------------------------/ + // int32 accumulators 4x4 block + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + RUY_MAKE_ZERO(v16) + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + RUY_MAKE_ZERO(v17) + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + RUY_MAKE_ZERO(v18) + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + RUY_MAKE_ZERO(v19) + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + RUY_MAKE_ZERO(v20) + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + RUY_MAKE_ZERO(v21) + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + RUY_MAKE_ZERO(v22) + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + RUY_MAKE_ZERO(v23) + + // Load the first 64 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v24) + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v25) + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v26) + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v27) + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v28) + "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v29) + "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v30) + "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v31) + + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov w1, #16\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + "smull v12.8h, v0.8b, v5.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "smull v14.8h, v2.8b, v5.8b\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Some multiplications and 16-bit accumulation were already done above, + // so we start right away in the middle. + "sadalp v16.4s, v8.8h\n" + "ldr d4, [%[rhs_ptr], #0]\n" + "smull v8.8h, v0.8b, v6.8b\n" + "ldr x7, [%[rhs_ptr], #8]\n" + "sadalp v17.4s, v9.8h\n" + "ldr d5, [%[rhs_ptr], #16]\n" + "smull v9.8h, v1.8b, v6.8b\n" + "ldr x8, [%[rhs_ptr], #24]\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "add %[lhs_ptr], %[lhs_ptr], #64\n" + "smull v11.8h, v3.8b, v6.8b\n" + "add %[rhs_ptr], %[rhs_ptr], #64\n" + "sadalp v20.4s, v12.8h\n" + // Each iteration of this loop advances by 16 levels of depth. 
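+        // (The interleaved 'ldr d' / 'ldr x' / 'ins v.d[1]' sequences here
+        //  rebuild full 128-bit registers from two 64-bit loads; this load
+        //  splitting follows the Cortex-A53-tuned gemmlowp kernel referenced
+        //  above, presumably to issue more smoothly on in-order cores.)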
+ "add w1, w1, #16\n" + "smull v12.8h, v0.8b, v7.8b\n" + // Loop termination condition + "cmp w1, w12\n" + "sadalp v21.4s, v13.8h\n" + "ldr x3, [%[lhs_ptr], #-56]\n" + "smull v13.8h, v1.8b, v7.8b\n" + "ldr x4, [%[lhs_ptr], #-40]\n" + "sadalp v22.4s, v14.8h\n" + "ldr x5, [%[lhs_ptr], #-24]\n" + "smull v14.8h, v2.8b, v7.8b\n" + "ldr x6, [%[lhs_ptr], #-8]\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "ldr x9, [%[rhs_ptr], #-24]\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + "ldr d6, [%[rhs_ptr], #-32]\n" + "smlal2 v12.8h, v0.16b, v7.16b\n" + "ldr d0, [%[lhs_ptr], #-64]\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "ldr d1, [%[lhs_ptr], #-48]\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "ins v4.d[1], x7\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + "ins v5.d[1], x8\n" + + "ldr d2, [%[lhs_ptr], #-32]\n" + "ins v0.d[1], x3\n" + "sadalp v24.4s, v8.8h\n" + "ldr d3, [%[lhs_ptr], #-16]\n" + "ins v1.d[1], x4\n" + "smull v8.8h, v0.8b, v4.8b\n" + "ins v2.d[1], x5\n" + "sadalp v25.4s, v9.8h\n" + "ins v3.d[1], x6\n" + "smull v9.8h, v1.8b, v4.8b\n" + "ldr d7, [%[rhs_ptr], #-16]\n" + "sadalp v26.4s, v10.8h\n" + "ldr x10, [%[rhs_ptr], #-8]\n" + "smull v10.8h, v2.8b, v4.8b\n" + "sadalp v27.4s, v11.8h\n" + "smull v11.8h, v3.8b, v4.8b\n" + "sadalp v28.4s, v12.8h\n" + "smull v12.8h, v0.8b, v5.8b\n" + "sadalp v29.4s, v13.8h\n" + "smull v13.8h, v1.8b, v5.8b\n" + "sadalp v30.4s, v14.8h\n" + "smull v14.8h, v2.8b, v5.8b\n" + "sadalp v31.4s, v15.8h\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "ins v6.d[1], x9\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "ins v7.d[1], x10\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + "blt 2b\n" + + "79:\n" + + "sadalp v16.4s, v8.8h\n" + "smull v8.8h, v0.8b, v6.8b\n" + "sadalp v17.4s, v9.8h\n" + "smull v9.8h, v1.8b, v6.8b\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "smull v11.8h, v3.8b, v6.8b\n" + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v0.8b, v7.8b\n" + "sadalp v21.4s, v13.8h\n" + "smull v13.8h, v1.8b, v7.8b\n" + "sadalp v22.4s, v14.8h\n" + "smull v14.8h, v2.8b, v7.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. 
+ "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + + "smlal2 v12.8h, v0.16b, v7.16b\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + + "sadalp v24.4s, v8.8h\n" + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "sadalp v25.4s, v9.8h\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "sadalp v26.4s, v10.8h\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "sadalp v27.4s, v11.8h\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "sadalp v28.4s, v12.8h\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "sadalp v29.4s, v13.8h\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 4x4 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 4x4 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Reduce 32bit accumulators horizontally. + "addp v16.4s, v16.4s, v17.4s\n" + "addp v18.4s, v18.4s, v19.4s\n" + "addp v20.4s, v20.4s, v21.4s\n" + "addp v22.4s, v22.4s, v23.4s\n" + "addp v24.4s, v24.4s, v25.4s\n" + "addp v26.4s, v26.4s, v27.4s\n" + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise). + "addp v16.4s, v16.4s, v18.4s\n" + "addp v17.4s, v20.4s, v22.4s\n" + "addp v18.4s, v24.4s, v26.4s\n" + "addp v19.4s, v28.4s, v30.4s\n" + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. 
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + RUY_MAKE_ZERO(v8) + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + "add x5, x4, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + + "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + "add x5, x1, %x[row], lsl #2\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 4 bias values. + "ld1 {v14.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + "ldr d0, [%[lhs_ptr], #0]\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "add v16.4s, v16.4s, v14.4s\n" + "ldr d1, [%[lhs_ptr], #16]\n" + "add v17.4s, v17.4s, v14.4s\n" + "ldr d2, [%[lhs_ptr], #32]\n" + "add v18.4s, v18.4s, v14.4s\n" + "ldr d3, [%[lhs_ptr], #48]\n" + "add v19.4s, v19.4s, v14.4s\n" + "ldr d4, [%[rhs_ptr], #0]\n" + "ldr d5, [%[rhs_ptr], #16]\n" + "ldr d6, [%[rhs_ptr], #32]\n" + "ldr d7, [%[rhs_ptr], #48]\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[1]\n" + "mls v18.4s, v10.4s, v14.s[2]\n" + "mls v19.4s, v10.4s, v14.s[3]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. + "mul v11.4s, v11.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v11.4s\n" + "sub v18.4s, v18.4s, v11.4s\n" + "sub v19.4s, v19.4s, v11.4s\n" + + // If the destination is int32, it means the user asks for the raw + // accumulators, no need for us to downquantize the value. + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. 
+ + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, %x[row], lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ld1 {v14.4s}, [x1]\n" + + "smax v12.4s, v14.4s, v8.4s\n" + "ldr x1, [%[lhs_ptr], #8]\n" + + "sshl v16.4s, v16.4s, v12.4s\n" + "ldr x2, [%[lhs_ptr], #24]\n" + "sshl v17.4s, v17.4s, v12.4s\n" + "ldr x3, [%[lhs_ptr], #40]\n" + "sshl v18.4s, v18.4s, v12.4s\n" + "ldr x4, [%[lhs_ptr], #56]\n" + "sshl v19.4s, v19.4s, v12.4s\n" + + "smin v12.4s, v14.4s, v8.4s\n" + + // Apply the fixed-point part of the multiplier. + "ins v0.d[1], x1\n" + "ldr x1, [%[rhs_ptr], #8]\n" + "sqrdmulh v16.4s, v16.4s, v15.4s\n" + "ins v1.d[1], x2\n" + "ldr x2, [%[rhs_ptr], #24]\n" + "sqrdmulh v17.4s, v17.4s, v15.4s\n" + "ins v2.d[1], x3\n" + "ldr x3, [%[rhs_ptr], #40]\n" + "sqrdmulh v18.4s, v18.4s, v15.4s\n" + "ins v3.d[1], x4\n" + "ldr x4, [%[rhs_ptr], #56]\n" + "sqrdmulh v19.4s, v19.4s, v15.4s\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. + "and v8.16b, v16.16b, v12.16b\n" + "and v9.16b, v17.16b, v12.16b\n" + "and v14.16b, v18.16b, v12.16b\n" + "and v15.16b, v19.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v16.4s, v16.4s, v8.4s\n" + "sqadd v17.4s, v17.4s, v9.4s\n" + "sqadd v18.4s, v18.4s, v14.4s\n" + "sqadd v19.4s, v19.4s, v15.4s\n" +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. 
+ "srshl v16.4s, v16.4s, v12.4s\n" + "srshl v17.4s, v17.4s, v12.4s\n" + "srshl v18.4s, v18.4s, v12.4s\n" + "srshl v19.4s, v19.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + "ins v4.d[1], x1\n" + "sqxtn v16.4h, v16.4s\n" + "ins v5.d[1], x2\n" + "sqxtn2 v16.8h, v17.4s\n" + "ins v6.d[1], x3\n" + "sqxtn v17.4h, v18.4s\n" + "ins v7.d[1], x4\n" + RUY_MAKE_ZERO(v18) + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v19) + + // Add the destination zero point + "add %[lhs_ptr], %[lhs_ptr], #64\n" + "dup v14.8h, v13.h[4]\n" + RUY_MAKE_ZERO(v20) + "add %[rhs_ptr], %[rhs_ptr], #64\n" + "add v16.8h, v16.8h, v14.8h\n" + RUY_MAKE_ZERO(v21) + "add v17.8h, v17.8h, v14.8h\n" + RUY_MAKE_ZERO(v22) + + // Cast-and-saturate from int16 to uint8 + "sqxtun v16.8b, v16.8h\n" + RUY_MAKE_ZERO(v23) + "sqxtun2 v16.16b, v17.8h\n" + RUY_MAKE_ZERO(v24) + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + RUY_MAKE_ZERO(v25) + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + RUY_MAKE_ZERO(v26) + "dup v14.16b, w2\n" // clamp_min + RUY_MAKE_ZERO(v27) + "dup v15.16b, w3\n" // clamp_max + RUY_MAKE_ZERO(v28) + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + RUY_MAKE_ZERO(v29) + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + RUY_MAKE_ZERO(v30) + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + RUY_MAKE_ZERO(v31) + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #4\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. 
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[4], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[5], [x3], #1\n" + "st1 {v16.b}[6], [x3], #1\n" + "st1 {v16.b}[7], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[8], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[9], [x3], #1\n" + "st1 {v16.b}[10], [x3], #1\n" + "st1 {v16.b}[11], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[12], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[13], [x3], #1\n" + "st1 {v16.b}[14], [x3], #1\n" + "st1 {v16.b}[15], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + "ins v4.d[1], x1\n" + "sqxtn v16.4h, v16.4s\n" + "ins v5.d[1], x2\n" + "sqxtn2 v16.8h, v17.4s\n" + "ins v6.d[1], x3\n" + "sqxtn v17.4h, v18.4s\n" + "ins v7.d[1], x4\n" + RUY_MAKE_ZERO(v18) + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v19) + + // Add the destination zero point + "add %[lhs_ptr], %[lhs_ptr], #64\n" + "dup v14.8h, v13.h[4]\n" + RUY_MAKE_ZERO(v20) + "add %[rhs_ptr], %[rhs_ptr], #64\n" + "add v16.8h, v16.8h, v14.8h\n" + RUY_MAKE_ZERO(v21) + "add v17.8h, v17.8h, v14.8h\n" + RUY_MAKE_ZERO(v22) + + // Cast-and-saturate from int16 to uint8 + "sqxtn v16.8b, v16.8h\n" + RUY_MAKE_ZERO(v23) + "sqxtn2 v16.16b, v17.8h\n" + RUY_MAKE_ZERO(v24) + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + RUY_MAKE_ZERO(v25) + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + RUY_MAKE_ZERO(v26) + "dup v14.16b, w2\n" // clamp_min + RUY_MAKE_ZERO(v27) + "dup v15.16b, w3\n" // clamp_max + RUY_MAKE_ZERO(v28) + + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + RUY_MAKE_ZERO(v29) + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + RUY_MAKE_ZERO(v30) + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + RUY_MAKE_ZERO(v31) + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. 
+ "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #4\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[4], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[5], [x3], #1\n" + "st1 {v16.b}[6], [x3], #1\n" + "st1 {v16.b}[7], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[8], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[9], [x3], #1\n" + "st1 {v16.b}[10], [x3], #1\n" + "st1 {v16.b}[11], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[12], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[13], [x3], #1\n" + "st1 {v16.b}[14], [x3], #1\n" + "st1 {v16.b}[15], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.4h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + "saddw v18.4s, v18.4s, v14.4h\n" + "saddw v19.4s, v19.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "ins v4.d[1], x1\n" + "sqxtn v16.4h, v16.4s\n" + "ins v5.d[1], x2\n" + "sqxtn2 v16.8h, v17.4s\n" + "ins v6.d[1], x3\n" + "sqxtn v17.4h, v18.4s\n" + "ins v7.d[1], x4\n" + RUY_MAKE_ZERO(v18) + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v19) + + "add %[lhs_ptr], %[lhs_ptr], #64\n" + RUY_MAKE_ZERO(v20) + "add %[rhs_ptr], %[rhs_ptr], #64\n" + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + + // Load the clamp_min, clamp_max bounds + "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + RUY_MAKE_ZERO(v25) + "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + RUY_MAKE_ZERO(v26) + "dup v14.8h, w2\n" // clamp_min + RUY_MAKE_ZERO(v27) + "dup v15.8h, w3\n" // clamp_max + RUY_MAKE_ZERO(v28) + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + "smax v17.8h, v17.8h, v14.8h\n" + RUY_MAKE_ZERO(v29) + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + "smin v17.8h, v17.8h, v15.8h\n" + RUY_MAKE_ZERO(v30) + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + RUY_MAKE_ZERO(v31) + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. 
+ "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + "str q17, [%[dst_tmp_buf], #16]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.h}[0], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v16.h}[1], [x3], #2\n" + "st1 {v16.h}[2], [x3], #2\n" + "st1 {v16.h}[3], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.h}[4], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v16.h}[5], [x3], #2\n" + "st1 {v16.h}[6], [x3], #2\n" + "st1 {v16.h}[7], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.h}[0], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v17.h}[1], [x3], #2\n" + "st1 {v17.h}[2], [x3], #2\n" + "st1 {v17.h}[3], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.h}[4], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v17.h}[5], [x3], #2\n" + "st1 {v17.h}[6], [x3], #2\n" + "st1 {v17.h}[7], [x3], #2\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #8\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + "ldr x1, [%[lhs_ptr], #8]\n" + "ldr x2, [%[lhs_ptr], #24]\n" + "ldr x3, [%[lhs_ptr], #40]\n" + "ldr x4, [%[lhs_ptr], #56]\n" + + "ins v0.d[1], x1\n" + "ldr x1, [%[rhs_ptr], #8]\n" + "ins v1.d[1], x2\n" + "ldr x2, [%[rhs_ptr], #24]\n" + "ins v2.d[1], x3\n" + "ldr x3, [%[rhs_ptr], #40]\n" + "ins v3.d[1], x4\n" + "ldr x4, [%[rhs_ptr], #56]\n" + "ins v4.d[1], x1\n" + "ins v5.d[1], x2\n" + "ins v6.d[1], x3\n" + "ins v7.d[1], x4\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // At this point, v20 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + + RUY_MAKE_ZERO(v20) + "add %[lhs_ptr], %[lhs_ptr], #64\n" + RUY_MAKE_ZERO(v21) + "add %[rhs_ptr], %[rhs_ptr], #64\n" + RUY_MAKE_ZERO(v22) + + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + RUY_MAKE_ZERO(v31) + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. 
+ "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + "str q17, [%[dst_tmp_buf], #16]\n" + "str q18, [%[dst_tmp_buf], #32]\n" + "str q19, [%[dst_tmp_buf], #48]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #16\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v16.s}[1], [x3], #4\n" + "st1 {v16.s}[2], [x3], #4\n" + "st1 {v16.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v17.s}[1], [x3], #4\n" + "st1 {v17.s}[2], [x3], #4\n" + "st1 {v17.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v18.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v18.s}[1], [x3], #4\n" + "st1 {v18.s}[2], [x3], #4\n" + "st1 {v18.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v19.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v19.s}[1], [x3], #4\n" + "st1 {v19.s}[2], [x3], #4\n" + "st1 {v19.s}[3], [x3], #4\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #16\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "smull v11.8h, v3.8b, v4.8b\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "smull v12.8h, v0.8b, v5.8b\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "smull v13.8h, v1.8b, v5.8b\n" + "smull v14.8h, v2.8b, v5.8b\n" + "smull v15.8h, v3.8b, v5.8b\n" + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #4\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #4\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. 
Corresponding to the initial ld1 instructions + // above, this is currently 4. + "mov w1, #16\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms),[dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} + +// Kernel taking advantage of the optional dotprod instruction. +// This is very similar to (and directly inspired by) this gemmlowp kernel +// which was contributed by David Mansell at ARM: +// NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3391 +// +// Besides the ruy-ification, the main difference here is that we use a 8x8 +// instead of 12x8 width, so as to stick to power-of-two widths. This slightly +// narrower kernel layout is still wide enough to achieve high performance +// although we haven't actually performed a real comparison to know exactly +// how this compares to ARM's aforementioned kernel. +// +// Relevant target CPUs for this kernel include ARM Cortex-A76, +// since these are 64-bit, out-of-order and with dotprod support. +void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label( + "Kernel (kNeonDotprod, optimized for out-of-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are int32 accumulators. + // During accumulation, v0 -- v15 are used to load int8 data from LHS and + // RHS. At least v0 and v1 are used to load a 8x4 block of LHS, and v2 and + // v3 are used to load a 4x8 block of RHS, like this: + // + // int8 RHS 4x8 block + // /-----------------------------------------\ + // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| + // | ... ... | + // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| + // \-----------------------------------------/ + // int8 LHS 8x4 block + // /---------------------\ /-----------------------------------------\ + // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| + // | ... ... | | ... ... | + // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]| + // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| + // | ... ... | | ... ... | + // |v1.b[12] ... v1.b[15]| |v17.s[3] ... 
v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // int32 accumulators 8x8 block + // + // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step + // is repeated 4 times, using 4x more registers for LHS and RHS, so that + // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. + // + // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are + // unused, and v8 -- v15 are used for loading parameters used for the + // post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. + "mov w1, #4\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Optional, maximally-streaming, partial-unrolling (4x unrolled) + // optimization of the kernel inner loop (over depth). For more + // comments, see the non-unrolled loop below after the #endif. 
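Each group of sixteen sdot instructions below consumes one 4-deep slice of the 8x4 LHS / 4x8 RHS registers described above. As a reference for what a single slice contributes to the 8x8 accumulators, here is an illustrative scalar sketch (names are hypothetical):

    #include <cstdint>

    // One depth-slice of the 8x8 dotprod micro-kernel: lhs_slice is 8 rows x
    // 4 depth (v0/v1), rhs_slice is 4 depth x 8 cols (v2/v3), acc is the 8x8
    // int32 accumulator block held in v16..v31. A single sdot instruction
    // computes four of these 4-element dot products at once.
    inline void DotprodStep(const std::int8_t lhs_slice[8][4],
                            const std::int8_t rhs_slice[4][8],
                            std::int32_t acc[8][8]) {
      for (int r = 0; r < 8; ++r) {
        for (int c = 0; c < 8; ++c) {
          for (int d = 0; d < 4; ++d) {
            acc[r][c] += static_cast<std::int32_t>(lhs_slice[r][d]) *
                         static_cast<std::int32_t>(rhs_slice[d][c]);
          }
        }
      }
    }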
+#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) + "cmp w12, #32\n" + "blt 78f\n" + + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v8.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v9.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v12.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v13.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v14.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v15.16b}, [%[rhs_ptr]], #16\n" + "mov w1, #16\n" + + "and w3, w12, #-16\n" + "81:\n" + "add w1, w1, #16\n" + + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + "ldr q0, [%[lhs_ptr], #0]\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + "ldr q2, [%[rhs_ptr], #0]\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + "ldr q1, [%[lhs_ptr], #16]\n" + + ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n" + ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n" + "ldr q3, [%[rhs_ptr], #16]\n" + ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n" + ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n" + ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n" + ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n" + ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n" + ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n" + ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n" + ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n" + ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n" + ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n" + "ldr q5, [%[lhs_ptr], #48]\n" + ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n" + ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n" + "ldr q7, [%[rhs_ptr], #48]\n" + ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n" + ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n" + "ldr q4, [%[lhs_ptr], #32]\n" + + ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n" + ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n" + "ldr q6, [%[rhs_ptr], #32]\n" + ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n" + ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n" + ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n" + ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n" + ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n" + ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n" + ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n" + ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n" + ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n" + ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n" + "ldr q9, [%[lhs_ptr], #80]\n" + ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n" + ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n" + "ldr q11, [%[rhs_ptr], #80]\n" + ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n" + ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n" + "ldr q8, [%[lhs_ptr], #64]\n" + + ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n" + ".word 0x4fafe19a // sdot v26.4s, 
v12.16b, v15.4b[1]\n" + "ldr q10, [%[rhs_ptr], #64]\n" + ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n" + ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n" + "add %[lhs_ptr], %[lhs_ptr], #128\n" + ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n" + ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n" + "add %[rhs_ptr], %[rhs_ptr], #128\n" + ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n" + ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n" + ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n" + ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n" + "cmp w1, w3\n" + ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n" + ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n" + "ldr q13, [%[lhs_ptr], #-16]\n" + ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n" + ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n" + "ldr q15, [%[rhs_ptr], #-16]\n" + ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n" + ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n" + "ldr q12, [%[lhs_ptr], #-32]\n" + + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + "ldr q14, [%[rhs_ptr], #-32]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + "blt 81b\n" + + ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n" + ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n" + ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n" + ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n" + ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n" + ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n" + ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n" + ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n" + ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n" + ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n" + ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n" + ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n" + ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n" + ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n" + ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n" + ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n" + + ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n" + ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n" + ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n" + ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n" + ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n" + ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n" + ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n" + ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n" + ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n" + ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n" + ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n" + ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n" + ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n" + ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n" + ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n" + ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n" + + ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n" + ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n" + ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n" + ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n" + ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n" + ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n" + ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, 
v14.4b[2]\n" + ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n" + ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n" + ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n" + ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n" + ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n" + ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n" + ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n" + ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n" + ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n" + + "78:\n" + +#endif // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) + + // Ordinary kernel inner loop (over depth), the simpler loop that the + // above was an equivalent 4x-partially-unrolled version of. + + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Because of the data that we have already loaded, we can start the + // loop body right away with some multiply-adds. + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + // Each iteration of this loop advances by 4 levels of depth. + "add w1, w1, #4\n" + ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + // Loop termination condition. + "cmp w1, w12\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + + "blt 2b\n" + + "79:\n" + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last 4 levels of depth, for which the LHS + // and RHS data is already loaded. + + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. 
We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + RUY_MAKE_ZERO(v8) + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + "add x5, x4, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + "add x5, x1, %x[row], lsl #2\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + "add v15.4s, v15.4s, v9.4s\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
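The tst/csel pairs above implement a simple pointer selection: per-channel quantities, and the bias when present, are indexed by the current block's row, while otherwise the base pointer is used unchanged. A hypothetical scalar equivalent:

    #include <cstdint>

    // Used for the bias pointer (RUY_ASM_FLAG_HAS_BIAS) and for the
    // per-channel multiplier pointers (RUY_ASM_FLAG_HAS_PERCHANNEL).
    inline const std::int32_t* SelectPerRowPointer(const std::int32_t* base,
                                                   int row, bool per_row) {
      // "add x5, base, row, lsl #2" followed by "csel ..., eq".
      return per_row ? base + row : base;
    }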
+ "add v16.4s, v16.4s, v14.4s\n" + "add v17.4s, v17.4s, v15.4s\n" + "add v18.4s, v18.4s, v14.4s\n" + "add v19.4s, v19.4s, v15.4s\n" + "add v20.4s, v20.4s, v14.4s\n" + "add v21.4s, v21.4s, v15.4s\n" + "add v22.4s, v22.4s, v14.4s\n" + "add v23.4s, v23.4s, v15.4s\n" + "add v24.4s, v24.4s, v14.4s\n" + "add v25.4s, v25.4s, v15.4s\n" + "add v26.4s, v26.4s, v14.4s\n" + "add v27.4s, v27.4s, v15.4s\n" + "add v28.4s, v28.4s, v14.4s\n" + "add v29.4s, v29.4s, v15.4s\n" + "add v30.4s, v30.4s, v14.4s\n" + "add v31.4s, v31.4s, v15.4s\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3], #16\n" + "ld1 {v15.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[0]\n" + "mls v18.4s, v10.4s, v14.s[1]\n" + "mls v19.4s, v10.4s, v14.s[1]\n" + "mls v20.4s, v10.4s, v14.s[2]\n" + "mls v21.4s, v10.4s, v14.s[2]\n" + "mls v22.4s, v10.4s, v14.s[3]\n" + "mls v23.4s, v10.4s, v14.s[3]\n" + "mls v24.4s, v10.4s, v15.s[0]\n" + "mls v25.4s, v10.4s, v15.s[0]\n" + "mls v26.4s, v10.4s, v15.s[1]\n" + "mls v27.4s, v10.4s, v15.s[1]\n" + "mls v28.4s, v10.4s, v15.s[2]\n" + "mls v29.4s, v10.4s, v15.s[2]\n" + "mls v30.4s, v10.4s, v15.s[3]\n" + "mls v31.4s, v10.4s, v15.s[3]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2], #16\n" + "ld1 {v12.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. + "mul v11.4s, v11.4s, v13.s[1]\n" + "mul v12.4s, v12.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v12.4s\n" + "sub v18.4s, v18.4s, v11.4s\n" + "sub v19.4s, v19.4s, v12.4s\n" + "sub v20.4s, v20.4s, v11.4s\n" + "sub v21.4s, v21.4s, v12.4s\n" + "sub v22.4s, v22.4s, v11.4s\n" + "sub v23.4s, v23.4s, v12.4s\n" + "sub v24.4s, v24.4s, v11.4s\n" + "sub v25.4s, v25.4s, v12.4s\n" + "sub v26.4s, v26.4s, v11.4s\n" + "sub v27.4s, v27.4s, v12.4s\n" + "sub v28.4s, v28.4s, v11.4s\n" + "sub v29.4s, v29.4s, v12.4s\n" + "sub v30.4s, v30.4s, v11.4s\n" + "sub v31.4s, v31.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. 
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, %x[row], lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ldr q9, [x1]\n" + "ldr q10, [x1, #16]\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n" + "beq 403f\n" + "smax v11.4s, v9.4s, v8.4s\n" + "smax v12.4s, v10.4s, v8.4s\n" + "sshl v16.4s, v16.4s, v11.4s\n" + "sshl v17.4s, v17.4s, v12.4s\n" + "sshl v18.4s, v18.4s, v11.4s\n" + "sshl v19.4s, v19.4s, v12.4s\n" + "sshl v20.4s, v20.4s, v11.4s\n" + "sshl v21.4s, v21.4s, v12.4s\n" + "sshl v22.4s, v22.4s, v11.4s\n" + "sshl v23.4s, v23.4s, v12.4s\n" + "sshl v24.4s, v24.4s, v11.4s\n" + "sshl v25.4s, v25.4s, v12.4s\n" + "sshl v26.4s, v26.4s, v11.4s\n" + "sshl v27.4s, v27.4s, v12.4s\n" + "sshl v28.4s, v28.4s, v11.4s\n" + "sshl v29.4s, v29.4s, v12.4s\n" + "sshl v30.4s, v30.4s, v11.4s\n" + "sshl v31.4s, v31.4s, v12.4s\n" + "403:\n" + + "ldr q14, [x4]\n" // multiplier_fixedpoint + "ldr q15, [x4, #16]\n" // multiplier_fixedpoint + + "smin v11.4s, v9.4s, v8.4s\n" + "smin v12.4s, v10.4s, v8.4s\n" + + // Apply the fixed-point part of the multiplier. + "sqrdmulh v16.4s, v16.4s, v14.4s\n" + "sqrdmulh v17.4s, v17.4s, v15.4s\n" + "sqrdmulh v18.4s, v18.4s, v14.4s\n" + "sqrdmulh v19.4s, v19.4s, v15.4s\n" + "sqrdmulh v20.4s, v20.4s, v14.4s\n" + "sqrdmulh v21.4s, v21.4s, v15.4s\n" + "sqrdmulh v22.4s, v22.4s, v14.4s\n" + "sqrdmulh v23.4s, v23.4s, v15.4s\n" + "sqrdmulh v24.4s, v24.4s, v14.4s\n" + "sqrdmulh v25.4s, v25.4s, v15.4s\n" + "sqrdmulh v26.4s, v26.4s, v14.4s\n" + "sqrdmulh v27.4s, v27.4s, v15.4s\n" + "sqrdmulh v28.4s, v28.4s, v14.4s\n" + "sqrdmulh v29.4s, v29.4s, v15.4s\n" + "sqrdmulh v30.4s, v30.4s, v14.4s\n" + "sqrdmulh v31.4s, v31.4s, v15.4s\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. 
+ "and v8.16b, v16.16b, v11.16b\n" + "and v9.16b, v17.16b, v12.16b\n" + "and v14.16b, v18.16b, v11.16b\n" + "and v15.16b, v19.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v16.4s, v16.4s, v8.4s\n" + "sqadd v17.4s, v17.4s, v9.4s\n" + "sqadd v18.4s, v18.4s, v14.4s\n" + "sqadd v19.4s, v19.4s, v15.4s\n" + "and v8.16b, v20.16b, v11.16b\n" + "and v9.16b, v21.16b, v12.16b\n" + "and v14.16b, v22.16b, v11.16b\n" + "and v15.16b, v23.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v20.4s, v20.4s, v8.4s\n" + "sqadd v21.4s, v21.4s, v9.4s\n" + "sqadd v22.4s, v22.4s, v14.4s\n" + "sqadd v23.4s, v23.4s, v15.4s\n" + "and v8.16b, v24.16b, v11.16b\n" + "and v9.16b, v25.16b, v12.16b\n" + "and v14.16b, v26.16b, v11.16b\n" + "and v15.16b, v27.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v24.4s, v24.4s, v8.4s\n" + "sqadd v25.4s, v25.4s, v9.4s\n" + "sqadd v26.4s, v26.4s, v14.4s\n" + "sqadd v27.4s, v27.4s, v15.4s\n" + "and v8.16b, v28.16b, v11.16b\n" + "and v9.16b, v29.16b, v12.16b\n" + "and v14.16b, v30.16b, v11.16b\n" + "and v15.16b, v31.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v28.4s, v28.4s, v8.4s\n" + "sqadd v29.4s, v29.4s, v9.4s\n" + "sqadd v30.4s, v30.4s, v14.4s\n" + "sqadd v31.4s, v31.4s, v15.4s\n" +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. + "srshl v16.4s, v16.4s, v11.4s\n" + "srshl v17.4s, v17.4s, v12.4s\n" + "srshl v18.4s, v18.4s, v11.4s\n" + "srshl v19.4s, v19.4s, v12.4s\n" + "srshl v20.4s, v20.4s, v11.4s\n" + "srshl v21.4s, v21.4s, v12.4s\n" + "srshl v22.4s, v22.4s, v11.4s\n" + "srshl v23.4s, v23.4s, v12.4s\n" + "srshl v24.4s, v24.4s, v11.4s\n" + "srshl v25.4s, v25.4s, v12.4s\n" + "srshl v26.4s, v26.4s, v11.4s\n" + "srshl v27.4s, v27.4s, v12.4s\n" + "srshl v28.4s, v28.4s, v11.4s\n" + "srshl v29.4s, v29.4s, v12.4s\n" + "srshl v30.4s, v30.4s, v11.4s\n" + "srshl v31.4s, v31.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + "add v18.8h, v18.8h, v14.8h\n" + "add v19.8h, v19.8h, v14.8h\n" + "add v20.8h, v20.8h, v14.8h\n" + "add v21.8h, v21.8h, v14.8h\n" + "add v22.8h, v22.8h, v14.8h\n" + "add v23.8h, v23.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + "sqxtun v16.8b, v16.8h\n" + "sqxtun2 v16.16b, v17.8h\n" + "sqxtun v17.8b, v18.8h\n" + "sqxtun2 v17.16b, v19.8h\n" + "sqxtun v18.8b, v20.8h\n" + "sqxtun2 v18.16b, v21.8h\n" + "sqxtun v19.8b, v22.8h\n" + "sqxtun2 v19.16b, v23.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + "umax v17.16b, v17.16b, v14.16b\n" + "umax v18.16b, v18.16b, v14.16b\n" + "umax v19.16b, v19.16b, v14.16b\n" + + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + "umin v17.16b, v17.16b, v15.16b\n" + "umin v18.16b, v18.16b, v15.16b\n" + "umin v19.16b, v19.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + "dup d21, v17.d[1]\n" + "dup d22, v18.d[1]\n" + "dup d23, v19.d[1]\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). 
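For the uint8 path above, the sequence is: saturate the scaled value to int16, add the destination zero point in 16 bits, saturate to uint8, then clamp. A scalar sketch for one value (it ignores the 16-bit wrap-around that the vector add could exhibit in extreme corner cases):

    #include <algorithm>
    #include <cstdint>

    inline std::uint8_t QuantizeToUint8(std::int32_t scaled,
                                        std::int32_t dst_zero_point,
                                        std::uint8_t clamp_min,
                                        std::uint8_t clamp_max) {
      std::int32_t v = std::min(std::max(scaled, -32768), 32767);  // sqxtn
      v += dst_zero_point;                       // "add v16.8h, ..., v14.8h"
      v = std::min(std::max(v, 0), 255);         // sqxtun
      v = std::max<std::int32_t>(v, clamp_min);  // umax, clamp_min
      v = std::min<std::int32_t>(v, clamp_max);  // umin, clamp_max
      return static_cast<std::uint8_t>(v);
    }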
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + "add v18.8h, v18.8h, v14.8h\n" + "add v19.8h, v19.8h, v14.8h\n" + "add v20.8h, v20.8h, v14.8h\n" + "add v21.8h, v21.8h, v14.8h\n" + "add v22.8h, v22.8h, v14.8h\n" + "add v23.8h, v23.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + "sqxtn v16.8b, v16.8h\n" + "sqxtn2 v16.16b, v17.8h\n" + "sqxtn v17.8b, v18.8h\n" + "sqxtn2 v17.16b, v19.8h\n" + "sqxtn v18.8b, v20.8h\n" + "sqxtn2 v18.16b, v21.8h\n" + "sqxtn v19.8b, v22.8h\n" + "sqxtn2 v19.16b, v23.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + "smax v17.16b, v17.16b, v14.16b\n" + "smax v18.16b, v18.16b, v14.16b\n" + "smax v19.16b, v19.16b, v14.16b\n" + + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + "smin v17.16b, v17.16b, v15.16b\n" + "smin v18.16b, v18.16b, v15.16b\n" + "smin v19.16b, v19.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + "dup d21, v17.d[1]\n" + "dup d22, v18.d[1]\n" + "dup d23, v19.d[1]\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 130f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 131f\n" + "130:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "131:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). 
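The 130:/131: selection above is the trick that lets one store sequence serve both cases: it parameterizes the stores by a (base address, column stride) pair, pointing either straight into the destination or into dst_tmp_buf with a fixed stride. An illustrative sketch, not code from this patch:

    #include <cstddef>
    #include <cstdint>

    struct StoreTarget {
      std::uint8_t* base;
      std::ptrdiff_t col_stride;  // bytes between consecutive columns
    };

    // For the 8x8 8-bit block: direct stores use the real destination stride;
    // the spill path uses dst_tmp_buf with a fixed 8-byte column stride and
    // the partial block is copied out afterwards by the scalar loop.
    inline StoreTarget SelectStoreTarget(bool full_block_fits,
                                         std::uint8_t* dst_ptr,
                                         std::uint8_t* dst_tmp_buf,
                                         std::ptrdiff_t dst_stride) {
      StoreTarget t;
      t.base = full_block_fits ? dst_ptr : dst_tmp_buf;  // "mov x3, ..."
      t.col_stride = full_block_fits ? dst_stride : 8;   // "mov x4, ..."
      return t;
    }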
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 141f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "150:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "151:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 151b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 150b\n" + "141:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + "saddw v18.4s, v18.4s, v14.4h\n" + "saddw v19.4s, v19.4s, v14.4h\n" + "saddw v20.4s, v20.4s, v14.4h\n" + "saddw v21.4s, v21.4s, v14.4h\n" + "saddw v22.4s, v22.4s, v14.4h\n" + "saddw v23.4s, v23.4s, v14.4h\n" + "saddw v24.4s, v24.4s, v14.4h\n" + "saddw v25.4s, v25.4s, v14.4h\n" + "saddw v26.4s, v26.4s, v14.4h\n" + "saddw v27.4s, v27.4s, v14.4h\n" + "saddw v28.4s, v28.4s, v14.4h\n" + "saddw v29.4s, v29.4s, v14.4h\n" + "saddw v30.4s, v30.4s, v14.4h\n" + "saddw v31.4s, v31.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Load the clamp_min, clamp_max bounds + "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.8h, w2\n" // clamp_min + "dup v15.8h, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + "smax v17.8h, v17.8h, v14.8h\n" + "smax v18.8h, v18.8h, v14.8h\n" + "smax v19.8h, v19.8h, v14.8h\n" + "smax v20.8h, v20.8h, v14.8h\n" + "smax v21.8h, v21.8h, v14.8h\n" + "smax v22.8h, v22.8h, v14.8h\n" + "smax v23.8h, v23.8h, v14.8h\n" + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + "smin v17.8h, v17.8h, v15.8h\n" + "smin v18.8h, v18.8h, v15.8h\n" + "smin v19.8h, v19.8h, v15.8h\n" + "smin v20.8h, v20.8h, v15.8h\n" + "smin v21.8h, v21.8h, v15.8h\n" + "smin v22.8h, v22.8h, v15.8h\n" + "smin v23.8h, v23.8h, v15.8h\n" + + // Compute how much of the 8x8 block of destination 16bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 230f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #16\n" + "b 231f\n" + "230:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "231:\n" + + // Write our 16bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 241f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. 
Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "250:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "251:\n" + "ldrsh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 251b\n" + "add w6, w6, #1\n" + "add x3, x3, #16\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 250b\n" + "241:\n" + "add %[dst_ptr], %[dst_ptr], #16\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // Compute how much of the 8x8 block of destination 32it values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 330f\n" + // Not all of the 8x8 block fits. + // Write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "st1 {v16.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v16) + "st1 {v17.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v17) + "st1 {v18.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v18) + "st1 {v19.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v19) + "st1 {v20.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v20) + "st1 {v21.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v21) + "st1 {v22.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v22) + "st1 {v23.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v23) + "st1 {v24.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v24) + "st1 {v25.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v25) + "st1 {v26.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v26) + "st1 {v27.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v27) + "st1 {v28.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v28) + "st1 {v29.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v29) + "st1 {v30.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v30) + "st1 {v31.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v31) + + "b 331f\n" + + "330:\n" + // Yes, all of the 8x8 block fits. 
+ "mov x4, %[dst_ptr]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.4s, v17.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v18.4s, v19.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v20.4s, v21.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v22.4s, v23.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v24.4s, v25.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v26.4s, v27.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v28.4s, v29.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v30.4s, v31.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + "331:\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 341f\n" + + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "350:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "351:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 351b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 350b\n" + "341:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? 
+ "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. + "mov w1, #4\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} + +// Similar to the above 8-bit dotprod kernel, but specialized for the case of +// RHS cols == 1. +// Relevant target CPUs for this kernel include ARM Cortex-A76, +// since these are 64-bit, out-of-order and with dotprod support. +void Kernel8bitNeonDotprodOutOfOrder1Col(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label( + "Kernel (kNeonDotprod, optimized for out-of-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are int32 accumulators. + // During accumulation, v0 -- v15 are used to load int8 data from LHS and + // RHS. At least v0 and v1 are used to load a 8x4 block of LHS, and v2 and + // v3 are used to load a 4x8 block of RHS, like this: + // + // int8 RHS 4x1 block + // /-------\ + // |v2.b[0]| + // | ... | + // |v2.b[3]| + // \-------/ + // int8 LHS 8x4 block + // /---------------------\ /--------\ + // |v0.b[0] ... v0.b[3] | |v16.s[0]| + // | ... ... | | ... | + // |v0.b[12] ... v0.b[15]| |v16.s[3]| + // |v1.b[0] ... v1.b[3] | |v17.s[0]| + // | ... ... | | ... | + // |v1.b[12] ... v1.b[15]| |v17.s[3]| + // \---------------------/ \--------/ + // int32 accumulators 8x1 block + // + // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step + // is repeated 4 times, using 4x more registers for LHS and RHS, so that + // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. + // + // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are + // unused, and v8 -- v15 are used for loading parameters used for the + // post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. 
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.8b}, [%[rhs_ptr]]\n" + "add %[rhs_ptr], %[rhs_ptr], #32\n" + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. + "mov w1, #4\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Ordinary kernel inner loop (over depth), the simpler loop that the + // above was an equivalent 4x-partially-unrolled version of. + + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Because of the data that we have already loaded, we can start the + // loop body right away with some multiply-adds. + // Each iteration of this loop advances by 4 levels of depth. + "add w1, w1, #4\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + // Loop termination condition. + "cmp w1, w12\n" + "ld1 {v2.8b}, [%[rhs_ptr]]\n" + "add %[rhs_ptr], %[rhs_ptr], #32\n" + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + + "blt 2b\n" + + "79:\n" + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last 4 levels of depth, for which the LHS + // and RHS data is already loaded. + + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. 
If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + RUY_MAKE_ZERO(v8) + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + "add x5, x4, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + "add x5, x1, %x[row], lsl #2\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.8b}, [%[rhs_ptr]]\n" + "add %[rhs_ptr], %[rhs_ptr], #32\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + "add v15.4s, v15.4s, v9.4s\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "add v16.4s, v16.4s, v14.4s\n" + "add v17.4s, v17.4s, v15.4s\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3], #16\n" + "ld1 {v15.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[0]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2], #16\n" + "ld1 {v12.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. 
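// Illustrative scalar sketch (hypothetical helper, not ruy's API) of the
// zero-point corrections being applied around here, i.e. the expansion of
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf:
//   sum_k (lhs[r,k] - lhs_zp) * (rhs[k,c] - rhs_zp)
//     = sum_k lhs[r,k] * rhs[k,c]
//       - lhs_zp * rhs_sums[c] - rhs_zp * lhs_sums[r] + depth * lhs_zp * rhs_zp

#include <cstdint>

std::int32_t ApplyZeroPointCorrections(std::int32_t raw_accum, std::int32_t bias,
                                       std::int32_t lhs_sum_for_row,
                                       std::int32_t rhs_sum_for_col,
                                       std::int32_t lhs_zp, std::int32_t rhs_zp,
                                       std::int32_t depth) {
  std::int32_t accum = raw_accum + bias;
  accum += depth * lhs_zp * rhs_zp;   // prod_zp_depth, folded into the bias add above
  accum -= lhs_zp * rhs_sum_for_col;  // the 'mls ..., lhs_zero_point_vec' step
  accum -= rhs_zp * lhs_sum_for_row;  // the 'mul'/'sub' on lhs_sums that follows
  return accum;
}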
+ "mul v11.4s, v11.4s, v13.s[1]\n" + "mul v12.4s, v12.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, %x[row], lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ldr q9, [x1]\n" + "ldr q10, [x1, #16]\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n" + "beq 403f\n" + "smax v11.4s, v9.4s, v8.4s\n" + "smax v12.4s, v10.4s, v8.4s\n" + "sshl v16.4s, v16.4s, v11.4s\n" + "sshl v17.4s, v17.4s, v12.4s\n" + "403:\n" + + "ldr q14, [x4]\n" // multiplier_fixedpoint + "ldr q15, [x4, #16]\n" // multiplier_fixedpoint + + "smin v11.4s, v9.4s, v8.4s\n" + "smin v12.4s, v10.4s, v8.4s\n" + + // Apply the fixed-point part of the multiplier. + "sqrdmulh v16.4s, v16.4s, v14.4s\n" + "sqrdmulh v17.4s, v17.4s, v15.4s\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. + "and v8.16b, v16.16b, v11.16b\n" + "and v9.16b, v17.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sqadd v16.4s, v16.4s, v8.4s\n" + "sqadd v17.4s, v17.4s, v9.4s\n" + +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. 
+ "srshl v16.4s, v16.4s, v11.4s\n" + "srshl v17.4s, v17.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + // All data in v16 at this point. + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8, leaving all data in the + // lower half of v16. + "sqxtun v16.8b, v16.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + + // Compute how much of the 8x1 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x1, there are some 8x1 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x1 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + + // Test if w1==8, i.e. if all of the 8x1 block fits. + "cmp w1, w3\n" + // Yes, all of the 8x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v16.8b}, [x3]\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. 
+ + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + "sqxtn v16.8b, v16.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + + // Compute how much of the 8x1 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x1 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + + // Test if w1==8, i.e. if all of the 8x1 block fits. + "cmp w1, w3\n" + // Yes, all of the 8x1 block fits, go to fast path. + "beq 130f\n" + // Not all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 131f\n" + "130:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "131:\n" + + // Write our 8bit values to the destination + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v16.8b}, [x3]\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 141f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "150:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "151:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 151b\n" + "141:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. 
+ + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + + // Load the clamp_min, clamp_max bounds + "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.8h, w2\n" // clamp_min + "dup v15.8h, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + + // Compute how much of the 8x1 block of destination 16bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x1 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x1 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + + // Test if w1==8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + // Yes, all of the 8x1 block fits, go to fast path. + "beq 230f\n" + // Not all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #16\n" + "b 231f\n" + "230:\n" + // Yes, all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "231:\n" + + // Write our 16bit values to the destination + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v16.8h}, [x3]\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + + // If all of the 8x1 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 241f\n" + // Not all of the 8x1 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "250:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "251:\n" + "ldrsh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 251b\n" + "241:\n" + "add %[dst_ptr], %[dst_ptr], #16\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // Compute how much of the 8x1 block of destination 32 bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x1, there are some 8x1 blocks along the boundaries that do + // not fit entirely. 
+ "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x1 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + // Yes, all of the 8x1 block fits, go to fast path. + "beq 330f\n" + // Not all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #16\n" + + // Write our 32bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.4s}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.4s}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + + "b 331f\n" + + "330:\n" + // Yes, all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x4, %[dst_ptr]\n" + "mov x3, x4\n" + + // Write our 32bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.4s, v17.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "331:\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 341f\n" + + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "350:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "mov w5, #0\n" + "351:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 351b\n" + "341:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. 
+ "mov w1, #4\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17"); +} + +// Variant of the above Kernel8bitNeonDotprodOutOfOrder, tuned for in-order +// CPUs. Specifically here, the relevant in-order CPUs are ARM Cortex-A55r1, +// since these are 64-bit and support dotprod. +// +// While this kernel does not have a direct equivalent in gemmlowp, it was +// developed based on insights that David Mansell at ARM shared with their +// contribution of gemmlowp kernels tuned for Cortex-A55r1, with very helpful +// comments. Specifically, see this comment about tuning for Cortex-A55r1: +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412 +void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label( + "Kernel (kNeonDotprod, optimized for in-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are int32 accumulators. + // During accumulation, v0 -- v3 are used to load int8 data from LHS and + // RHS. + // + // int8 RHS 4x8 block + // /-----------------------------------------\ + // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| + // | ... ... | + // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| + // \-----------------------------------------/ + // int8 LHS 8x4 block + // /---------------------\ /-----------------------------------------\ + // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| + // | ... ... | | ... ... | + // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]| + // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| + // | ... ... | | ... ... | + // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // int32 accumulators 8x8 block + // + // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because + // we did not observe a benefit of such partial unrolling on in-order CPUs. + // + // v4 -- v7 are unused, and v8 -- v15 are used for loading parameters used for + // the post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. 
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + RUY_MAKE_ZERO(v16) + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + RUY_MAKE_ZERO(v17) + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + RUY_MAKE_ZERO(v18) + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + RUY_MAKE_ZERO(v19) + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + RUY_MAKE_ZERO(v20) + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + RUY_MAKE_ZERO(v21) + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + RUY_MAKE_ZERO(v22) + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + + // Clear accumulators. + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + // Perform the first few multiply-adds on the data that we have already + // loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + RUY_MAKE_ZERO(v28) + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + RUY_MAKE_ZERO(v29) + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + RUY_MAKE_ZERO(v30) + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + RUY_MAKE_ZERO(v31) + + + "1:\n" + + "add x5, %[lhs_ptr], x12, lsl #3\n" + "sub x5, x5, #32\n" + "cmp %[lhs_ptr], x5\n" + + "beq 79f\n" + + // Main accumulation loop + "2:\n" + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + "ldr x1, [%[lhs_ptr], #8]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + "ldr x3, [%[rhs_ptr], #8]\n" + ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + "ldr x4, [%[rhs_ptr], #24]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + "ldr d0, [%[lhs_ptr], #0]\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + "ins v0.d[1], x1\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + "ldr x2, [%[lhs_ptr], #24]\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + "add %[lhs_ptr], %[lhs_ptr], #32\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + "ldr d2, [%[rhs_ptr], #0]\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + "ins v2.d[1], x3\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + "cmp %[lhs_ptr], x5\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + "add %[rhs_ptr], %[rhs_ptr], #32\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + "ldr d3, [%[rhs_ptr], #-16]\n" + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + "ldr d1, [%[lhs_ptr], #-16]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + "ins v3.d[1], x4\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + "ins v1.d[1], x2\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + "blt 2b\n" + + // Last accumulation steps, nothing left to load. + "79:\n" + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + "cmp %w[row], w7\n" // Have we finished the last row? 
+ ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + // Load some parameters needed for the end work on current block. + RUY_MAKE_ZERO(v8) + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + "add x5, x4, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + "add x5, x1, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. 
+ "ld1 {v14.2s}, [x1], #8\n" + "ldr x5, [x1], #8\n" + "ins v14.d[1], x5\n" + "ld1 {v15.2s}, [x1], #8\n" + "ldr x5, [x1], #8\n" + "ins v15.d[1], x5\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + "add v15.4s, v15.4s, v9.4s\n" + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "add v16.4s, v16.4s, v14.4s\n" + "add v17.4s, v17.4s, v15.4s\n" + "add v18.4s, v18.4s, v14.4s\n" + "add v19.4s, v19.4s, v15.4s\n" + "add v20.4s, v20.4s, v14.4s\n" + "add v21.4s, v21.4s, v15.4s\n" + "add v22.4s, v22.4s, v14.4s\n" + "add v23.4s, v23.4s, v15.4s\n" + "add v24.4s, v24.4s, v14.4s\n" + "add v25.4s, v25.4s, v15.4s\n" + "add v26.4s, v26.4s, v14.4s\n" + "add v27.4s, v27.4s, v15.4s\n" + "add v28.4s, v28.4s, v14.4s\n" + "add v29.4s, v29.4s, v15.4s\n" + "add v30.4s, v30.4s, v14.4s\n" + "add v31.4s, v31.4s, v15.4s\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Load 8 rhs_sums values. + "ld1 {v14.2s}, [x3], #8\n" + "ldr x7, [x3], #8\n" + "ld1 {v15.2s}, [x3], #8\n" + "ins v14.d[1], x7\n" + "ldr x7, [x3], #8\n" + "ins v15.d[1], x7\n" + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[0]\n" + "mls v18.4s, v10.4s, v14.s[1]\n" + "mls v19.4s, v10.4s, v14.s[1]\n" + "mls v20.4s, v10.4s, v14.s[2]\n" + "mls v21.4s, v10.4s, v14.s[2]\n" + "mls v22.4s, v10.4s, v14.s[3]\n" + "mls v23.4s, v10.4s, v14.s[3]\n" + "mls v24.4s, v10.4s, v15.s[0]\n" + "mls v25.4s, v10.4s, v15.s[0]\n" + "mls v26.4s, v10.4s, v15.s[1]\n" + "mls v27.4s, v10.4s, v15.s[1]\n" + "mls v28.4s, v10.4s, v15.s[2]\n" + "mls v29.4s, v10.4s, v15.s[2]\n" + "mls v30.4s, v10.4s, v15.s[3]\n" + "mls v31.4s, v10.4s, v15.s[3]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Load 8 lhs_sums values. + "ld1 {v11.2s}, [x2], #8\n" + "ldr x6, [x2], #8\n" + "ins v11.d[1], x6\n" + "ld1 {v12.2s}, [x2], #8\n" + "ldr x6, [x2], #8\n" + "ins v12.d[1], x6\n" + // Compute lhs_sums * rhs_zero_point. 
+ "mul v11.4s, v11.4s, v13.s[1]\n" + "mul v12.4s, v12.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v12.4s\n" + "sub v18.4s, v18.4s, v11.4s\n" + "sub v19.4s, v19.4s, v12.4s\n" + "sub v20.4s, v20.4s, v11.4s\n" + "sub v21.4s, v21.4s, v12.4s\n" + "sub v22.4s, v22.4s, v11.4s\n" + "sub v23.4s, v23.4s, v12.4s\n" + "sub v24.4s, v24.4s, v11.4s\n" + "sub v25.4s, v25.4s, v12.4s\n" + "sub v26.4s, v26.4s, v11.4s\n" + "sub v27.4s, v27.4s, v12.4s\n" + "sub v28.4s, v28.4s, v11.4s\n" + "sub v29.4s, v29.4s, v12.4s\n" + "sub v30.4s, v30.4s, v11.4s\n" + "sub v31.4s, v31.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, %x[row], lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ldr q9, [x1]\n" + "ldr q10, [x1, #16]\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n" + "beq 403f\n" + "smax v11.4s, v9.4s, v8.4s\n" + "smax v12.4s, v10.4s, v8.4s\n" + "sshl v16.4s, v16.4s, v11.4s\n" + "sshl v17.4s, v17.4s, v12.4s\n" + "sshl v18.4s, v18.4s, v11.4s\n" + "sshl v19.4s, v19.4s, v12.4s\n" + "sshl v20.4s, v20.4s, v11.4s\n" + "sshl v21.4s, v21.4s, v12.4s\n" + "sshl v22.4s, v22.4s, v11.4s\n" + "sshl v23.4s, v23.4s, v12.4s\n" + "sshl v24.4s, v24.4s, v11.4s\n" + "sshl v25.4s, v25.4s, v12.4s\n" + "sshl v26.4s, v26.4s, v11.4s\n" + "sshl v27.4s, v27.4s, v12.4s\n" + "sshl v28.4s, v28.4s, v11.4s\n" + "sshl v29.4s, v29.4s, v12.4s\n" + "sshl v30.4s, v30.4s, v11.4s\n" + "sshl v31.4s, v31.4s, v12.4s\n" + "403:\n" + + "ldr q14, [x4]\n" // multiplier_fixedpoint + "ldr q15, [x4, #16]\n" // multiplier_fixedpoint + + "smin v11.4s, v9.4s, v8.4s\n" + "smin v12.4s, v10.4s, v8.4s\n" + + // Apply the fixed-point part of the multiplier. + // + // ... and, interleaved into that: + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. 
+ "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" + "sqrdmulh v16.4s, v16.4s, v14.4s\n" + "ldr x1, [%[lhs_ptr]], #8\n" + "sqrdmulh v17.4s, v17.4s, v15.4s\n" + "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" + "sqrdmulh v18.4s, v18.4s, v14.4s\n" + "ldr x2, [%[lhs_ptr]], #8\n" + "sqrdmulh v19.4s, v19.4s, v15.4s\n" + "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" + "sqrdmulh v20.4s, v20.4s, v14.4s\n" + "ldr x5, [%[rhs_ptr]], #8\n" + "sqrdmulh v21.4s, v21.4s, v15.4s\n" + "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" + "sqrdmulh v22.4s, v22.4s, v14.4s\n" + "ldr x6, [%[rhs_ptr]], #8\n" + "sqrdmulh v23.4s, v23.4s, v15.4s\n" + "sqrdmulh v24.4s, v24.4s, v14.4s\n" + "sqrdmulh v25.4s, v25.4s, v15.4s\n" + "sqrdmulh v26.4s, v26.4s, v14.4s\n" + "sqrdmulh v27.4s, v27.4s, v15.4s\n" + "sqrdmulh v28.4s, v28.4s, v14.4s\n" + "sqrdmulh v29.4s, v29.4s, v15.4s\n" + "sqrdmulh v30.4s, v30.4s, v14.4s\n" + "sqrdmulh v31.4s, v31.4s, v15.4s\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. 
+ "and v8.16b, v16.16b, v11.16b\n" + "and v9.16b, v17.16b, v12.16b\n" + "and v14.16b, v18.16b, v11.16b\n" + "and v15.16b, v19.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v16.4s, v16.4s, v8.4s\n" + "sqadd v17.4s, v17.4s, v9.4s\n" + "sqadd v18.4s, v18.4s, v14.4s\n" + "sqadd v19.4s, v19.4s, v15.4s\n" + "and v8.16b, v20.16b, v11.16b\n" + "and v9.16b, v21.16b, v12.16b\n" + "and v14.16b, v22.16b, v11.16b\n" + "and v15.16b, v23.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v20.4s, v20.4s, v8.4s\n" + "sqadd v21.4s, v21.4s, v9.4s\n" + "sqadd v22.4s, v22.4s, v14.4s\n" + "sqadd v23.4s, v23.4s, v15.4s\n" + "and v8.16b, v24.16b, v11.16b\n" + "and v9.16b, v25.16b, v12.16b\n" + "and v14.16b, v26.16b, v11.16b\n" + "and v15.16b, v27.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v24.4s, v24.4s, v8.4s\n" + "sqadd v25.4s, v25.4s, v9.4s\n" + "sqadd v26.4s, v26.4s, v14.4s\n" + "sqadd v27.4s, v27.4s, v15.4s\n" + "and v8.16b, v28.16b, v11.16b\n" + "and v9.16b, v29.16b, v12.16b\n" + "and v14.16b, v30.16b, v11.16b\n" + "and v15.16b, v31.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v28.4s, v28.4s, v8.4s\n" + "sqadd v29.4s, v29.4s, v9.4s\n" + "sqadd v30.4s, v30.4s, v14.4s\n" + "sqadd v31.4s, v31.4s, v15.4s\n" +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. + "srshl v16.4s, v16.4s, v11.4s\n" + "srshl v17.4s, v17.4s, v12.4s\n" + "srshl v18.4s, v18.4s, v11.4s\n" + "srshl v19.4s, v19.4s, v12.4s\n" + "srshl v20.4s, v20.4s, v11.4s\n" + "srshl v21.4s, v21.4s, v12.4s\n" + "srshl v22.4s, v22.4s, v11.4s\n" + "srshl v23.4s, v23.4s, v12.4s\n" + "srshl v24.4s, v24.4s, v11.4s\n" + "srshl v25.4s, v25.4s, v12.4s\n" + "srshl v26.4s, v26.4s, v11.4s\n" + "srshl v27.4s, v27.4s, v12.4s\n" + "ins v0.d[1], x1\n" + "srshl v28.4s, v28.4s, v11.4s\n" + "ins v1.d[1], x2\n" + "srshl v29.4s, v29.4s, v12.4s\n" + "ins v2.d[1], x5\n" + "srshl v30.4s, v30.4s, v11.4s\n" + "ins v3.d[1], x6\n" + "srshl v31.4s, v31.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // Destination zero_point + "dup v14.8h, v13.h[4]\n" + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + "add v18.8h, v18.8h, v14.8h\n" + "add v19.8h, v19.8h, v14.8h\n" + "add v20.8h, v20.8h, v14.8h\n" + "add v21.8h, v21.8h, v14.8h\n" + "add v22.8h, v22.8h, v14.8h\n" + "add v23.8h, v23.8h, v14.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + // Cast-and-saturate from int16 to uint8 + "sqxtun v16.8b, v16.8h\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "sqxtun2 v16.16b, v17.8h\n" + "sqxtun v17.8b, v18.8h\n" + "sqxtun2 v17.16b, v19.8h\n" + "sqxtun v18.8b, v20.8h\n" + "sqxtun2 v18.16b, v21.8h\n" + "sqxtun v19.8b, v22.8h\n" + "sqxtun2 v19.16b, v23.8h\n" + + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + "sub w2, %w[dst_cols], %w[col]\n" + "umax v17.16b, v17.16b, v14.16b\n" + "mov w3, #8\n" + "umax v18.16b, v18.16b, v14.16b\n" + "cmp w1, #8\n" + "umax v19.16b, v19.16b, v14.16b\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + "cmp w2, #8\n" + "umin v17.16b, v17.16b, v15.16b\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + "umin v18.16b, v18.16b, v15.16b\n" + "umin v19.16b, v19.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + "dup d21, v17.d[1]\n" + "dup d22, v18.d[1]\n" + "dup d23, v19.d[1]\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. 
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // Destination zero_point + "dup v14.8h, v13.h[4]\n" + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + "add v18.8h, v18.8h, v14.8h\n" + "add v19.8h, v19.8h, v14.8h\n" + "add v20.8h, v20.8h, v14.8h\n" + "add v21.8h, v21.8h, v14.8h\n" + "add v22.8h, v22.8h, v14.8h\n" + "add v23.8h, v23.8h, v14.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + // Cast-and-saturate from int16 to uint8 + "sqxtn v16.8b, v16.8h\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "sqxtn2 v16.16b, v17.8h\n" + "sqxtn v17.8b, v18.8h\n" + "sqxtn2 v17.16b, v19.8h\n" + "sqxtn v18.8b, v20.8h\n" + "sqxtn2 v18.16b, v21.8h\n" + "sqxtn v19.8b, v22.8h\n" + "sqxtn2 v19.16b, v23.8h\n" + + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + "sub w2, %w[dst_cols], %w[col]\n" + "smax v17.16b, v17.16b, v14.16b\n" + "mov w3, #8\n" + "smax v18.16b, v18.16b, v14.16b\n" + "cmp w1, #8\n" + "smax v19.16b, v19.16b, v14.16b\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + "cmp w2, #8\n" + "smin v17.16b, v17.16b, v15.16b\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + "smin v18.16b, v18.16b, v15.16b\n" + "smin v19.16b, v19.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + "dup d21, v17.d[1]\n" + "dup d22, v18.d[1]\n" + "dup d23, v19.d[1]\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 130f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 131f\n" + "130:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "131:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. 
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 141f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "150:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "151:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 151b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 150b\n" + "141:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + "saddw v18.4s, v18.4s, v14.4h\n" + "saddw v19.4s, v19.4s, v14.4h\n" + "saddw v20.4s, v20.4s, v14.4h\n" + "saddw v21.4s, v21.4s, v14.4h\n" + "saddw v22.4s, v22.4s, v14.4h\n" + "saddw v23.4s, v23.4s, v14.4h\n" + "saddw v24.4s, v24.4s, v14.4h\n" + "saddw v25.4s, v25.4s, v14.4h\n" + "saddw v26.4s, v26.4s, v14.4h\n" + "saddw v27.4s, v27.4s, v14.4h\n" + "saddw v28.4s, v28.4s, v14.4h\n" + "saddw v29.4s, v29.4s, v14.4h\n" + "saddw v30.4s, v30.4s, v14.4h\n" + "saddw v31.4s, v31.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Load the clamp_min, clamp_max bounds + "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.8h, w2\n" // clamp_min + "dup v15.8h, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + "smax v17.8h, v17.8h, v14.8h\n" + "smax v18.8h, v18.8h, v14.8h\n" + "smax v19.8h, v19.8h, v14.8h\n" + "smax v20.8h, v20.8h, v14.8h\n" + "smax v21.8h, v21.8h, v14.8h\n" + "smax v22.8h, v22.8h, v14.8h\n" + "smax v23.8h, v23.8h, v14.8h\n" + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + "smin v17.8h, v17.8h, v15.8h\n" + "smin v18.8h, v18.8h, v15.8h\n" + "smin v19.8h, v19.8h, v15.8h\n" + "smin v20.8h, v20.8h, v15.8h\n" + "smin v21.8h, v21.8h, v15.8h\n" + "smin v22.8h, v22.8h, v15.8h\n" + "smin v23.8h, v23.8h, v15.8h\n" + + // Compute how much of the 8x8 block of destination 16bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 230f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #16\n" + "b 231f\n" + "230:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "231:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 241f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. 
Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "250:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "251:\n" + "ldrsh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 251b\n" + "add w6, w6, #1\n" + "add x3, x3, #16\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 250b\n" + "241:\n" + "add %[dst_ptr], %[dst_ptr], #16\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" + "ldr x1, [%[lhs_ptr]], #8\n" + "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" + "ldr x2, [%[lhs_ptr]], #8\n" + "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" + "ldr x5, [%[rhs_ptr]], #8\n" + "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" + "ldr x6, [%[rhs_ptr]], #8\n" + "ins v0.d[1], x1\n" + "ins v1.d[1], x2\n" + "ins v2.d[1], x5\n" + "ins v3.d[1], x6\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // Compute how much of the 8x8 block of destination 32it values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 330f\n" + // Not all of the 8x8 block fits. + // Write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "st1 {v16.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v16) + "st1 {v17.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v17) + "st1 {v18.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v18) + "st1 {v19.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v19) + "st1 {v20.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v20) + "st1 {v21.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v21) + "st1 {v22.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v22) + "st1 {v23.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v23) + "st1 {v24.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v24) + "st1 {v25.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v25) + "st1 {v26.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v26) + "st1 {v27.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v27) + "st1 {v28.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v28) + "st1 {v29.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v29) + "st1 {v30.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v30) + "st1 {v31.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v31) + + "b 331f\n" + + "330:\n" + // Yes, all of the 8x8 block fits. 
+ "mov x4, %[dst_ptr]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v16.4s, v17.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v18.4s, v19.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v20.4s, v21.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v22.4s, v23.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v24.4s, v25.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v26.4s, v27.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v28.4s, v29.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v30.4s, v31.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + "331:\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 341f\n" + + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "350:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "351:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 351b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 350b\n" + "341:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? 
+ "cmp %w[col], w8\n" + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} +#undef RUY_OFFSET_BIAS +#undef RUY_OFFSET_LHS_SUMS +#undef RUY_OFFSET_RHS_SUMS +#undef RUY_OFFSET_LHS_BASE_PTR +#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT +#undef RUY_OFFSET_MULTIPLIER_EXPONENT +#undef RUY_OFFSET_RHS_BASE_PTR +#undef RUY_OFFSET_DST_BASE_PTR +#undef RUY_OFFSET_LHS_ZERO_POINT +#undef RUY_OFFSET_RHS_ZERO_POINT +#undef RUY_OFFSET_DST_ZERO_POINT +#undef RUY_OFFSET_PROD_ZP_DEPTH +#undef RUY_OFFSET_START_ROW +#undef RUY_OFFSET_START_COL +#undef RUY_OFFSET_LAST_ROW +#undef RUY_OFFSET_LAST_COL +#undef RUY_OFFSET_DST_ROWS +#undef RUY_OFFSET_DST_COLS +#undef RUY_OFFSET_LHS_STRIDE +#undef RUY_OFFSET_RHS_STRIDE +#undef RUY_OFFSET_DST_STRIDE +#undef RUY_OFFSET_DEPTH +#undef RUY_OFFSET_CLAMP_MIN +#undef RUY_OFFSET_CLAMP_MAX +#undef RUY_OFFSET_FLAGS + +#define RUY_OFFSET_LHS_BASE_PTR 0 +#define RUY_OFFSET_RHS_BASE_PTR 8 +#define RUY_OFFSET_DST_BASE_PTR 16 +#define RUY_OFFSET_BIAS 24 +#define RUY_OFFSET_START_ROW 32 +#define RUY_OFFSET_START_COL 36 +#define RUY_OFFSET_LAST_ROW 40 +#define RUY_OFFSET_LAST_COL 44 +#define RUY_OFFSET_LHS_STRIDE 56 +#define RUY_OFFSET_RHS_STRIDE 60 +#define RUY_OFFSET_DST_STRIDE 64 +#define RUY_OFFSET_DEPTH 68 +#define RUY_OFFSET_CLAMP_MIN 72 +#define RUY_OFFSET_CLAMP_MAX 76 +#define RUY_OFFSET_FLAGS 80 + +template +void CheckOffsetsInKernelParamsFloat(const Params&) { + static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); + static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, ""); + static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, ""); + static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); + static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); + static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, ""); + static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); + static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); + static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); + static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); + static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); + static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); + static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); + static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); + static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); +} + +// Just a plain float kernel; good enough for out-of-order cores. 
+// The closest to it in the gemmlowp collection would be +// NEON_64bit_GEMM_Float32_WithScalar, +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3925 +// +// Besides ruy-ification, the main nuance here is that we stick to a 8x8 +// width instead of the wider 12x8 that the register space permits and that +// the aforementioned gemmlowp kernel uses. Ruy likes powers of two for now +// and we don't have evidence that going beyond 8x8 is needed. +void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params) { + CheckOffsetsInKernelParamsFloat(params); + profiler::ScopeLabel label( + "Kernel (kNeon, optimized for out-of-order cores)"); + + const float* lhs_col_ptr = params.lhs_base_ptr; + const float* rhs_col_ptr = params.rhs_base_ptr; + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + float* dst_col_ptr = params.dst_base_ptr; + float* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are accumulators. + // During accumulation, v0 -- v15 are used to load data from LHS and RHS. + // At least v0 and v1 are used to load a 8x1 block of LHS, and v2 and + // v3 are used to load a 1x8 block of RHS, like this: + // + // RHS 1x8 block + // /-----------------------------------------\ + // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| + // \-----------------------------------------/ + // LHS 8x1 block + // /---------------------\ /-----------------------------------------\ + // | v0.s[0] | |v16.s[0] ... v30.s[0]| + // | ... | | ... ... | + // | v0.s[3] | |v16.s[3] ... v30.s[3]| + // | v1.s[0] | |v17.s[0] ... v31.s[0]| + // | ... | | ... ... | + // | v1.s[3] | |v17.s[3] ... v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // accumulators 8x8 block + // + // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step + // is repeated 4 times, using 4x more registers for LHS and RHS, so that + // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. + // + // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are + // unused, and v8 -- v15 are used for floading parameters used for the + // post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + + // Clear accumulators. 
+ RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 1. + "mov w1, #1\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + "fmla v16.4s, v0.4s, v2.s[0]\n" + "fmla v18.4s, v0.4s, v2.s[1]\n" + "fmla v20.4s, v0.4s, v2.s[2]\n" + "fmla v22.4s, v0.4s, v2.s[3]\n" + +#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) + "cmp w12, #8\n" + "blt 78f\n" + "and w2, w12, #-4\n" + + "ld1 {v4.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v5.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v6.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v7.4s}, [%[rhs_ptr]], #16\n" + + "ld1 {v8.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v9.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v10.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v11.4s}, [%[rhs_ptr]], #16\n" + + "ld1 {v12.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v13.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v14.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v15.4s}, [%[rhs_ptr]], #16\n" + "mov w1, #4\n" + + "80:\n" + + "add %[lhs_ptr], %[lhs_ptr], #128\n" + "add %[rhs_ptr], %[rhs_ptr], #128\n" + + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "ldr q0, [%[lhs_ptr], #-128]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ldr q3, [%[rhs_ptr], #-112]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "ldr q1, [%[lhs_ptr], #-112]\n" + "fmla v16.4s, v4.4s, v6.s[0]\n" + "fmla v18.4s, v4.4s, v6.s[1]\n" + "ldr q2, [%[rhs_ptr], #-128]\n" + "fmla v20.4s, v4.4s, v6.s[2]\n" + "fmla v22.4s, v4.4s, v6.s[3]\n" + + "fmla v24.4s, v4.4s, v7.s[0]\n" + "fmla v26.4s, v4.4s, v7.s[1]\n" + "fmla v28.4s, v4.4s, v7.s[2]\n" + "fmla v30.4s, v4.4s, v7.s[3]\n" + "ldr q4, [%[lhs_ptr], #-96]\n" + "fmla v25.4s, v5.4s, v7.s[0]\n" + "fmla v27.4s, v5.4s, v7.s[1]\n" + "fmla v29.4s, v5.4s, v7.s[2]\n" + "fmla v31.4s, v5.4s, v7.s[3]\n" + "ldr q7, [%[rhs_ptr], #-80]\n" + "fmla v17.4s, v5.4s, v6.s[0]\n" + "fmla v19.4s, v5.4s, v6.s[1]\n" + "fmla v21.4s, v5.4s, v6.s[2]\n" + "fmla v23.4s, v5.4s, v6.s[3]\n" + "ldr q5, [%[lhs_ptr], #-80]\n" + "fmla v16.4s, v8.4s, v10.s[0]\n" + "fmla v18.4s, v8.4s, v10.s[1]\n" + "ldr q6, [%[rhs_ptr], #-96]\n" + "fmla v20.4s, v8.4s, v10.s[2]\n" + "fmla v22.4s, v8.4s, v10.s[3]\n" + + "fmla v24.4s, v8.4s, v11.s[0]\n" + "fmla v26.4s, v8.4s, v11.s[1]\n" + "fmla v28.4s, v8.4s, v11.s[2]\n" + "fmla v30.4s, v8.4s, v11.s[3]\n" + "ldr q8, [%[lhs_ptr], #-64]\n" + "fmla v25.4s, v9.4s, v11.s[0]\n" + "fmla v27.4s, v9.4s, v11.s[1]\n" + "fmla v29.4s, v9.4s, v11.s[2]\n" + "fmla v31.4s, v9.4s, v11.s[3]\n" + "ldr q11, [%[rhs_ptr], #-48]\n" + "fmla v17.4s, v9.4s, v10.s[0]\n" + "fmla v19.4s, v9.4s, v10.s[1]\n" + "fmla v21.4s, v9.4s, v10.s[2]\n" + "fmla v23.4s, v9.4s, v10.s[3]\n" + "ldr q9, [%[lhs_ptr], #-48]\n" + "fmla v16.4s, v12.4s, v14.s[0]\n" + "fmla v18.4s, v12.4s, v14.s[1]\n" + "ldr q10, [%[rhs_ptr], #-64]\n" + "fmla v20.4s, v12.4s, v14.s[2]\n" + "fmla v22.4s, v12.4s, v14.s[3]\n" + + "fmla v24.4s, v12.4s, v15.s[0]\n" + "fmla v26.4s, 
v12.4s, v15.s[1]\n" + "fmla v28.4s, v12.4s, v15.s[2]\n" + "fmla v30.4s, v12.4s, v15.s[3]\n" + "ldr q12, [%[lhs_ptr], #-32]\n" + "fmla v25.4s, v13.4s, v15.s[0]\n" + "fmla v27.4s, v13.4s, v15.s[1]\n" + "fmla v29.4s, v13.4s, v15.s[2]\n" + "fmla v31.4s, v13.4s, v15.s[3]\n" + "ldr q15, [%[rhs_ptr], #-16]\n" + "fmla v17.4s, v13.4s, v14.s[0]\n" + "fmla v19.4s, v13.4s, v14.s[1]\n" + "fmla v21.4s, v13.4s, v14.s[2]\n" + "fmla v23.4s, v13.4s, v14.s[3]\n" + "ldr q13, [%[lhs_ptr], #-16]\n" + "fmla v16.4s, v0.4s, v2.s[0]\n" + "fmla v18.4s, v0.4s, v2.s[1]\n" + "ldr q14, [%[rhs_ptr], #-32]\n" + "fmla v20.4s, v0.4s, v2.s[2]\n" + "fmla v22.4s, v0.4s, v2.s[3]\n" + + "add w1, w1, #4\n" + "cmp w1, w2\n" + "blt 80b\n" + + "fmla v16.4s, v4.4s, v6.s[0]\n" + "fmla v18.4s, v4.4s, v6.s[1]\n" + "fmla v20.4s, v4.4s, v6.s[2]\n" + "fmla v22.4s, v4.4s, v6.s[3]\n" + "fmla v24.4s, v4.4s, v7.s[0]\n" + "fmla v26.4s, v4.4s, v7.s[1]\n" + "fmla v28.4s, v4.4s, v7.s[2]\n" + "fmla v30.4s, v4.4s, v7.s[3]\n" + "fmla v25.4s, v5.4s, v7.s[0]\n" + "fmla v27.4s, v5.4s, v7.s[1]\n" + "fmla v29.4s, v5.4s, v7.s[2]\n" + "fmla v31.4s, v5.4s, v7.s[3]\n" + "fmla v17.4s, v5.4s, v6.s[0]\n" + "fmla v19.4s, v5.4s, v6.s[1]\n" + "fmla v21.4s, v5.4s, v6.s[2]\n" + "fmla v23.4s, v5.4s, v6.s[3]\n" + + "fmla v16.4s, v8.4s, v10.s[0]\n" + "fmla v18.4s, v8.4s, v10.s[1]\n" + "fmla v20.4s, v8.4s, v10.s[2]\n" + "fmla v22.4s, v8.4s, v10.s[3]\n" + "fmla v24.4s, v8.4s, v11.s[0]\n" + "fmla v26.4s, v8.4s, v11.s[1]\n" + "fmla v28.4s, v8.4s, v11.s[2]\n" + "fmla v30.4s, v8.4s, v11.s[3]\n" + "fmla v25.4s, v9.4s, v11.s[0]\n" + "fmla v27.4s, v9.4s, v11.s[1]\n" + "fmla v29.4s, v9.4s, v11.s[2]\n" + "fmla v31.4s, v9.4s, v11.s[3]\n" + "fmla v17.4s, v9.4s, v10.s[0]\n" + "fmla v19.4s, v9.4s, v10.s[1]\n" + "fmla v21.4s, v9.4s, v10.s[2]\n" + "fmla v23.4s, v9.4s, v10.s[3]\n" + + "fmla v16.4s, v12.4s, v14.s[0]\n" + "fmla v18.4s, v12.4s, v14.s[1]\n" + "fmla v20.4s, v12.4s, v14.s[2]\n" + "fmla v22.4s, v12.4s, v14.s[3]\n" + "fmla v24.4s, v12.4s, v15.s[0]\n" + "fmla v26.4s, v12.4s, v15.s[1]\n" + "fmla v28.4s, v12.4s, v15.s[2]\n" + "fmla v30.4s, v12.4s, v15.s[3]\n" + "fmla v25.4s, v13.4s, v15.s[0]\n" + "fmla v27.4s, v13.4s, v15.s[1]\n" + "fmla v29.4s, v13.4s, v15.s[2]\n" + "fmla v31.4s, v13.4s, v15.s[3]\n" + "fmla v17.4s, v13.4s, v14.s[0]\n" + "fmla v19.4s, v13.4s, v14.s[1]\n" + "fmla v21.4s, v13.4s, v14.s[2]\n" + "fmla v23.4s, v13.4s, v14.s[3]\n" + + "78:\n" +#endif + + // Accumulation loop + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "ld1 {v4.4s}, [%[rhs_ptr]], #16\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "add w1, w1, #1\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "cmp w1, w12\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v18.4s, v0.4s, v4.s[1]\n" + "mov v2.16b, v4.16b\n" + "fmla v20.4s, v0.4s, v4.s[2]\n" + "fmla v22.4s, v0.4s, v4.s[3]\n" + "blt 2b\n" + + "79:\n" + + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last level of depth, for which the LHS + // and RHS data is already loaded. 
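+ // For reference, each depth step of the loop above is a rank-1 update of
+ // the 8x8 accumulator block: with the LHS column held in v0/v1 and the RHS
+ // row held in v2/v3, it is equivalent to the C-like sketch below
+ // (illustrative names only, not actual variables in this kernel):
+ //   for (int j = 0; j < 8; ++j) {
+ //     for (int i = 0; i < 8; ++i) {
+ //       acc[i][j] += lhs[i] * rhs[j];
+ //     }
+ //   }
+ // where column j of acc lives in the register pair (v16 + 2*j, v17 + 2*j),
+ // holding rows 0..3 and 4..7 respectively, as in the diagram above.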
+ + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Offset these base pointers as needed given the current row, col. + "add x5, x1, %x[row], lsl #2\n" + + "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
+ "fadd v16.4s, v16.4s, v14.4s\n" + "fadd v17.4s, v17.4s, v15.4s\n" + "fadd v18.4s, v18.4s, v14.4s\n" + "fadd v19.4s, v19.4s, v15.4s\n" + "fadd v20.4s, v20.4s, v14.4s\n" + "fadd v21.4s, v21.4s, v15.4s\n" + "fadd v22.4s, v22.4s, v14.4s\n" + "fadd v23.4s, v23.4s, v15.4s\n" + "fadd v24.4s, v24.4s, v14.4s\n" + "fadd v25.4s, v25.4s, v15.4s\n" + "fadd v26.4s, v26.4s, v14.4s\n" + "fadd v27.4s, v27.4s, v15.4s\n" + "fadd v28.4s, v28.4s, v14.4s\n" + "fadd v29.4s, v29.4s, v15.4s\n" + "fadd v30.4s, v30.4s, v14.4s\n" + "fadd v31.4s, v31.4s, v15.4s\n" + + // Load the clamp_min, clamp_max bounds + "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.4s, w2\n" // clamp_min + "dup v15.4s, w3\n" // clamp_max + + // Apply the clamp_min bound + "fmax v16.4s, v16.4s, v14.4s\n" + "fmax v17.4s, v17.4s, v14.4s\n" + "fmax v18.4s, v18.4s, v14.4s\n" + "fmax v19.4s, v19.4s, v14.4s\n" + "fmax v20.4s, v20.4s, v14.4s\n" + "fmax v21.4s, v21.4s, v14.4s\n" + "fmax v22.4s, v22.4s, v14.4s\n" + "fmax v23.4s, v23.4s, v14.4s\n" + "fmax v24.4s, v24.4s, v14.4s\n" + "fmax v25.4s, v25.4s, v14.4s\n" + "fmax v26.4s, v26.4s, v14.4s\n" + "fmax v27.4s, v27.4s, v14.4s\n" + "fmax v28.4s, v28.4s, v14.4s\n" + "fmax v29.4s, v29.4s, v14.4s\n" + "fmax v30.4s, v30.4s, v14.4s\n" + "fmax v31.4s, v31.4s, v14.4s\n" + + // Apply the clamp_max bound + "fmin v16.4s, v16.4s, v15.4s\n" + "fmin v17.4s, v17.4s, v15.4s\n" + "fmin v18.4s, v18.4s, v15.4s\n" + "fmin v19.4s, v19.4s, v15.4s\n" + "fmin v20.4s, v20.4s, v15.4s\n" + "fmin v21.4s, v21.4s, v15.4s\n" + "fmin v22.4s, v22.4s, v15.4s\n" + "fmin v23.4s, v23.4s, v15.4s\n" + "fmin v24.4s, v24.4s, v15.4s\n" + "fmin v25.4s, v25.4s, v15.4s\n" + "fmin v26.4s, v26.4s, v15.4s\n" + "fmin v27.4s, v27.4s, v15.4s\n" + "fmin v28.4s, v28.4s, v15.4s\n" + "fmin v29.4s, v29.4s, v15.4s\n" + "fmin v30.4s, v30.4s, v15.4s\n" + "fmin v31.4s, v31.4s, v15.4s\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #32\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). 
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "str q16, [x3, #0]\n" + "str q17, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + "str q18, [x3, #0]\n" + "str q19, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + "str q20, [x3, #0]\n" + "str q21, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + "str q22, [x3, #0]\n" + "str q23, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + "str q24, [x3, #0]\n" + "str q25, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + "str q26, [x3, #0]\n" + "str q27, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + "str q28, [x3, #0]\n" + "str q29, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + "str q30, [x3, #0]\n" + "str q31, [x3, #16]\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 1. 
+ "mov w1, #1\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} + +// Variant of KernelFloatNeonOutOfOrder tuned for in-order CPUs that do not +// support dotprod (while dotprod by itself is not relevant to floating-point, +// this additional bit of information that we have about the target happens to +// be useful here). +// +// So a typical target CPU here would be ARM Cortex-A53 or the original +// Cortex-A55. +// +// This kernel is similar to and inspired by gemmlowp's +// NEON_64bit_GEMM_Float32_WithScalar_A53. +// which was contributed by David Mansell with very helpful +// comments. Specifically, see this comment about tuning for Cortex-A53: +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215 +void KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params) { + profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)"); + + CheckOffsetsInKernelParamsFloat(params); + + const float* lhs_col_ptr = params.lhs_base_ptr; + const float* rhs_col_ptr = params.rhs_base_ptr; + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + float* dst_col_ptr = params.dst_base_ptr; + float* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are accumulators. + // During accumulation, v0 -- v3 are used to load data from LHS and RHS. + // + // RHS 1x8 block + // /-----------------------------------------\ + // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| + // \-----------------------------------------/ + // LHS 8x1 block + // /---------------------\ /-----------------------------------------\ + // | v0.s[0] | |v16.s[0] ... v30.s[0]| + // | ... | | ... ... | + // | v0.s[3] | |v16.s[3] ... v30.s[3]| + // | v1.s[0] | |v17.s[0] ... v31.s[0]| + // | ... | | ... ... | + // | v1.s[3] | |v17.s[3] ... v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // accumulators 8x8 block + // + // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because + // we did not observe a benefit of such partial unrolling on in-order CPUs. + // + // v4 -- v7 are unused, and v8 -- v15 are used for floading parameters used + // for the post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. 
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v17) + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v18) + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v19) + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n") + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n") + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n") + RUY_MAKE_ZERO(v23) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n") + RUY_MAKE_ZERO(v24) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n") + RUY_MAKE_ZERO(v25) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n") + RUY_MAKE_ZERO(v26) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") + RUY_MAKE_ZERO(v27) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that remain to load + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently depth - 1. + "sub w1, w12, #1\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + "cmp w1, #0\n" + "fmla v16.4s, v0.4s, v2.s[0]\n" + "fmla v18.4s, v0.4s, v2.s[1]\n" + "fmla v20.4s, v0.4s, v2.s[2]\n" + "fmla v22.4s, v0.4s, v2.s[3]\n" + + // Accumulation loop + "beq 79f\n" + + "2:\n" + + "fmla v24.4s, v0.4s, v3.s[0]\n" + "ldr x2, [%[lhs_ptr], #8]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "ldr x3, [%[lhs_ptr], #24]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "ldr x5, [%[rhs_ptr], #24]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "ldr x4, [%[rhs_ptr], #8]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "subs w1, w1, #1\n" + "ldr d0, [%[lhs_ptr]], #32\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ins v0.d[1], x2\n" + "ldr d3, [%[rhs_ptr], #16]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "ins v3.d[1], x5\n" + "ldr d4, [%[rhs_ptr]], #32\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "ins v4.d[1], x4\n" + "ldr d1, [%[lhs_ptr], #-16]\n" + "fmla v18.4s, v0.4s, v4.s[1]\n" + "fmla v20.4s, v0.4s, v4.s[2]\n" + "ins v1.d[1], x3\n" + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") + "mov v2.16b, v4.16b\n" + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") + "fmla v22.4s, v0.4s, v4.s[3]\n" + "bne 2b\n" + + "79:\n" + + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last level of depth, for which the LHS + // and RHS data is already loaded. 
+ + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Offset these base pointers as needed given the current row, col. + "add x5, x1, %x[row], lsl #2\n" + + "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. 
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "fadd v16.4s, v16.4s, v14.4s\n" + "fadd v17.4s, v17.4s, v15.4s\n" + "fadd v18.4s, v18.4s, v14.4s\n" + "fadd v19.4s, v19.4s, v15.4s\n" + "fadd v20.4s, v20.4s, v14.4s\n" + "fadd v21.4s, v21.4s, v15.4s\n" + "fadd v22.4s, v22.4s, v14.4s\n" + "fadd v23.4s, v23.4s, v15.4s\n" + "fadd v24.4s, v24.4s, v14.4s\n" + "fadd v25.4s, v25.4s, v15.4s\n" + "fadd v26.4s, v26.4s, v14.4s\n" + "fadd v27.4s, v27.4s, v15.4s\n" + "fadd v28.4s, v28.4s, v14.4s\n" + "fadd v29.4s, v29.4s, v15.4s\n" + "fadd v30.4s, v30.4s, v14.4s\n" + "fadd v31.4s, v31.4s, v15.4s\n" + + // Load the clamp_min, clamp_max bounds + "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.4s, w2\n" // clamp_min + "dup v15.4s, w3\n" // clamp_max + + // Apply the clamp_min bound + "fmax v16.4s, v16.4s, v14.4s\n" + "fmax v17.4s, v17.4s, v14.4s\n" + "fmax v18.4s, v18.4s, v14.4s\n" + "fmax v19.4s, v19.4s, v14.4s\n" + "fmax v20.4s, v20.4s, v14.4s\n" + "fmax v21.4s, v21.4s, v14.4s\n" + "fmax v22.4s, v22.4s, v14.4s\n" + "fmax v23.4s, v23.4s, v14.4s\n" + "fmax v24.4s, v24.4s, v14.4s\n" + "fmax v25.4s, v25.4s, v14.4s\n" + "fmax v26.4s, v26.4s, v14.4s\n" + "fmax v27.4s, v27.4s, v14.4s\n" + "fmax v28.4s, v28.4s, v14.4s\n" + "fmax v29.4s, v29.4s, v14.4s\n" + "fmax v30.4s, v30.4s, v14.4s\n" + "fmax v31.4s, v31.4s, v14.4s\n" + + // Apply the clamp_max bound + "fmin v16.4s, v16.4s, v15.4s\n" + "fmin v17.4s, v17.4s, v15.4s\n" + "fmin v18.4s, v18.4s, v15.4s\n" + "fmin v19.4s, v19.4s, v15.4s\n" + "fmin v20.4s, v20.4s, v15.4s\n" + "fmin v21.4s, v21.4s, v15.4s\n" + "fmin v22.4s, v22.4s, v15.4s\n" + "fmin v23.4s, v23.4s, v15.4s\n" + "fmin v24.4s, v24.4s, v15.4s\n" + "fmin v25.4s, v25.4s, v15.4s\n" + "fmin v26.4s, v26.4s, v15.4s\n" + "fmin v27.4s, v27.4s, v15.4s\n" + "fmin v28.4s, v28.4s, v15.4s\n" + "fmin v29.4s, v29.4s, v15.4s\n" + "fmin v30.4s, v30.4s, v15.4s\n" + "fmin v31.4s, v31.4s, v15.4s\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #32\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). 
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "str q16, [x3, #0]\n" + "str q17, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + "str q18, [x3, #0]\n" + "str q19, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + "str q20, [x3, #0]\n" + "str q21, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + "str q22, [x3, #0]\n" + "str q23, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + "str q24, [x3, #0]\n" + "str q25, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + "str q26, [x3, #0]\n" + "str q27, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + "str q28, [x3, #0]\n" + "str q29, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + "str q30, [x3, #0]\n" + "str q31, [x3, #16]\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that remain to load + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently depth - 1. 
+ "sub w1, w12, #1\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} + +// Variant of KernelFloatNeonInOrder tuned for in-order CPUs that do +// support dotprod (while dotprod by itself is not relevant to floating-point, +// this additional bit of information that we have about the target happens to +// be useful here). +// +// So a typical target CPU here would be ARM Cortex-A55r1. +// +// This kernel is similar to and inspired by gemmlowp's +// NEON_64bit_GEMM_Float32_WithScalar_A55r1. +// which was contributed by David Mansell with very helpful +// comments. Specifically, see this comment about tuning for Cortex-A55r1: +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412 +void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params) { + profiler::ScopeLabel label( + "Kernel (kNeonDotprod, optimized for in-order cores)"); + + CheckOffsetsInKernelParamsFloat(params); + + const float* lhs_col_ptr = params.lhs_base_ptr; + const float* rhs_col_ptr = params.rhs_base_ptr; + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + float* dst_col_ptr = params.dst_base_ptr; + float* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are accumulators. + // During accumulation, v0 -- v3 are used to load data from LHS and RHS. + // + // RHS 1x8 block + // /-----------------------------------------\ + // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| + // \-----------------------------------------/ + // LHS 8x1 block + // /---------------------\ /-----------------------------------------\ + // | v0.s[0] | |v16.s[0] ... v30.s[0]| + // | ... | | ... ... | + // | v0.s[3] | |v16.s[3] ... v30.s[3]| + // | v1.s[0] | |v17.s[0] ... v31.s[0]| + // | ... | | ... ... | + // | v1.s[3] | |v17.s[3] ... v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // accumulators 8x8 block + // + // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because + // we did not observe a benefit of such partial unrolling on in-order CPUs. + // + // v4 -- v7 are unused, and v8 -- v15 are used for floading parameters used + // for the post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. 
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v17) + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v18) + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v19) + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n") + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n") + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n") + RUY_MAKE_ZERO(v23) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n") + RUY_MAKE_ZERO(v24) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n") + RUY_MAKE_ZERO(v25) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n") + RUY_MAKE_ZERO(v26) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") + RUY_MAKE_ZERO(v27) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that remain to load + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently depth - 1. + "sub w1, w12, #1\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + "cmp w1, #0\n" + "fmla v16.4s, v0.4s, v2.s[0]\n" + "fmla v18.4s, v0.4s, v2.s[1]\n" + "fmla v20.4s, v0.4s, v2.s[2]\n" + "fmla v22.4s, v0.4s, v2.s[3]\n" + + // Accumulation loop + "beq 79f\n" + + "2:\n" + + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") + "fmla v24.4s, v0.4s, v3.s[0]\n" + "ldr x2, [%[lhs_ptr], #8]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "ldr x3, [%[lhs_ptr], #24]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "ldr x5, [%[rhs_ptr], #24]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "ldr d0, [%[lhs_ptr]], #32\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "ldr x4, [%[rhs_ptr], #8]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "subs w1, w1, #1\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "ins v0.d[1], x2\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ldr d3, [%[rhs_ptr], #16]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "ins v3.d[1], x5\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "ldr d4, [%[rhs_ptr]], #32\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "ins v4.d[1], x4\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") + "fmla v16.4s, v0.4s, v4.s[0]\n" + "ldr d1, [%[lhs_ptr], #-16]\n" + "fmla v18.4s, v0.4s, v4.s[1]\n" + "ins v1.d[1], x3\n" + "fmla v20.4s, v0.4s, v4.s[2]\n" + "mov v2.16b, v4.16b\n" + "fmla v22.4s, v0.4s, v4.s[3]\n" + "bne 2b\n" + + "79:\n" + + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last level of depth, for which the LHS + // and RHS data is already loaded. 
+ + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Offset these base pointers as needed given the current row, col. + "add x5, x1, %x[row], lsl #2\n" + + "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. 
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "fadd v16.4s, v16.4s, v14.4s\n" + "fadd v17.4s, v17.4s, v15.4s\n" + "fadd v18.4s, v18.4s, v14.4s\n" + "fadd v19.4s, v19.4s, v15.4s\n" + "fadd v20.4s, v20.4s, v14.4s\n" + "fadd v21.4s, v21.4s, v15.4s\n" + "fadd v22.4s, v22.4s, v14.4s\n" + "fadd v23.4s, v23.4s, v15.4s\n" + "fadd v24.4s, v24.4s, v14.4s\n" + "fadd v25.4s, v25.4s, v15.4s\n" + "fadd v26.4s, v26.4s, v14.4s\n" + "fadd v27.4s, v27.4s, v15.4s\n" + "fadd v28.4s, v28.4s, v14.4s\n" + "fadd v29.4s, v29.4s, v15.4s\n" + "fadd v30.4s, v30.4s, v14.4s\n" + "fadd v31.4s, v31.4s, v15.4s\n" + + // Load the clamp_min, clamp_max bounds + "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.4s, w2\n" // clamp_min + "dup v15.4s, w3\n" // clamp_max + + // Apply the clamp_min bound + "fmax v16.4s, v16.4s, v14.4s\n" + "fmax v17.4s, v17.4s, v14.4s\n" + "fmax v18.4s, v18.4s, v14.4s\n" + "fmax v19.4s, v19.4s, v14.4s\n" + "fmax v20.4s, v20.4s, v14.4s\n" + "fmax v21.4s, v21.4s, v14.4s\n" + "fmax v22.4s, v22.4s, v14.4s\n" + "fmax v23.4s, v23.4s, v14.4s\n" + "fmax v24.4s, v24.4s, v14.4s\n" + "fmax v25.4s, v25.4s, v14.4s\n" + "fmax v26.4s, v26.4s, v14.4s\n" + "fmax v27.4s, v27.4s, v14.4s\n" + "fmax v28.4s, v28.4s, v14.4s\n" + "fmax v29.4s, v29.4s, v14.4s\n" + "fmax v30.4s, v30.4s, v14.4s\n" + "fmax v31.4s, v31.4s, v14.4s\n" + + // Apply the clamp_max bound + "fmin v16.4s, v16.4s, v15.4s\n" + "fmin v17.4s, v17.4s, v15.4s\n" + "fmin v18.4s, v18.4s, v15.4s\n" + "fmin v19.4s, v19.4s, v15.4s\n" + "fmin v20.4s, v20.4s, v15.4s\n" + "fmin v21.4s, v21.4s, v15.4s\n" + "fmin v22.4s, v22.4s, v15.4s\n" + "fmin v23.4s, v23.4s, v15.4s\n" + "fmin v24.4s, v24.4s, v15.4s\n" + "fmin v25.4s, v25.4s, v15.4s\n" + "fmin v26.4s, v26.4s, v15.4s\n" + "fmin v27.4s, v27.4s, v15.4s\n" + "fmin v28.4s, v28.4s, v15.4s\n" + "fmin v29.4s, v29.4s, v15.4s\n" + "fmin v30.4s, v30.4s, v15.4s\n" + "fmin v31.4s, v31.4s, v15.4s\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #32\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). 
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "str q16, [x3, #0]\n" + "str q17, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + "str q18, [x3, #0]\n" + "str q19, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + "str q20, [x3, #0]\n" + "str q21, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + "str q22, [x3, #0]\n" + "str q23, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + "str q24, [x3, #0]\n" + "str q25, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + "str q26, [x3, #0]\n" + "str q27, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + "str q28, [x3, #0]\n" + "str q29, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + "str q30, [x3, #0]\n" + "str q31, [x3, #16]\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that remain to load + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently depth - 1. 
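+        // i.e., in C terms (illustrative): w1 = depth - 1; and the "ble 1b"
+        // below re-enters the main loop while col <= last_col.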
+ "sub w1, w12, #1\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} +#undef RUY_OFFSET_BIAS +#undef RUY_OFFSET_FLAGS +#undef RUY_OFFSET_LHS_BASE_PTR +#undef RUY_OFFSET_CLAMP_MIN +#undef RUY_OFFSET_CLAMP_MAX +#undef RUY_OFFSET_START_ROW +#undef RUY_OFFSET_LAST_ROW +#undef RUY_OFFSET_LAST_COL +#undef RUY_OFFSET_LHS_STRIDE +#undef RUY_OFFSET_RHS_STRIDE +#undef RUY_OFFSET_DST_STRIDE +#undef RUY_OFFSET_DEPTH +#undef RUY_OFFSET_START_COL +#undef RUY_OFFSET_RHS_BASE_PTR +#undef RUY_OFFSET_DST_BASE_PTR + +#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy diff --git a/ruy/kernel_avx2.cc b/ruy/kernel_avx2.cc new file mode 100644 index 0000000..13fe22b --- /dev/null +++ b/ruy/kernel_avx2.cc @@ -0,0 +1,1664 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/kernel.h" +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) +#include // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) { + // CPU-ID-based checks should disable the path that would reach this point. 
+ RUY_DCHECK(false); +} + +#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +static constexpr int kAvx8bitBlockSize = 8; +static constexpr int kAvx8bitInnerSize = 4; + +namespace { +namespace intrin_utils { + +inline __m256 mm256_n_loadu_epi32(int n, const std::int32_t* src) { + switch (n) { + case 0: + return _mm256_setzero_si256(); + case 1: + return _mm256_setr_m128(_mm_setr_epi32(src[0], 0, 0, 0), + _mm_setzero_si128()); + case 2: + return _mm256_setr_m128(_mm_setr_epi32(src[0], src[1], 0, 0), + _mm_setzero_si128()); + case 3: + return _mm256_setr_m128(_mm_setr_epi32(src[0], src[1], src[2], 0), + _mm_setzero_si128()); + case 4: + return _mm256_castsi128_si256( + _mm_loadu_si128(reinterpret_cast<__m128i const*>(src))); + case 5: + return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], 0, 0, 0); + case 6: + return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], src[5], + 0, 0); + case 7: + return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], src[5], + src[6], 0); + case 8: + return _mm256_loadu_si256(reinterpret_cast<__m256i const*>(src)); + default: + RUY_DCHECK_LT(n, 9); + return _mm256_setzero_si256(); + } +} + +inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows, + const __m256 v) { + // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. + const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); + __m256i shuffled_v; + if (residual_rows > 1) { + // This selects 0, 4, 8, 12, 0, 4, 8, 12, ..., but we only use the first 4 + // in each 128-bit lane. + shuffled_v = _mm256_shuffle_epi8(v, repack_perm); + } + switch (residual_rows) { + case 0: + break; + case 1: + dst[0] = _mm256_extract_epi8(v, 0); + break; + case 2: + _mm_storeu_si16(dst, _mm256_extracti128_si256(shuffled_v, 0)); + break; + case 3: { + __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 0); + _mm_storeu_si16(dst, trailing_packed); + dst[2] = _mm_extract_epi8(trailing_packed, 2); + break; + } + case 4: + _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); + break; + case 5: + _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); + dst[4] = _mm256_extract_epi8(shuffled_v, 16); + break; + case 6: + _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); + _mm_storeu_si16(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); + break; + case 7: { + _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); + __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 1); + _mm_storeu_si16(dst + 4, trailing_packed); + dst[6] = _mm_extract_epi8(trailing_packed, 2); + break; + } + case 8: + _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); + _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); + break; + default: + RUY_DCHECK_LE(residual_rows, 8); + break; + } +} + +inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256 v) { + // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. 
+ const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); + const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); + _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); + _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); +} + +inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows, + const __m256 v) { + intrin_utils::mm256_n_storeu_cvtepi32_epi8( + reinterpret_cast(dst), residual_rows, v); +} + +inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256 v) { + // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. + const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); + const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); + _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); + _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); +} + +inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows, + const __m256 v) { + // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively + // truncating each 16-bit integer. + const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); + __m256i shuffled_v; + __m128i shuffled_v_low; + if (residual_rows > 1) { + shuffled_v = _mm256_shuffle_epi8(v, repack_perm); + shuffled_v_low = _mm256_extracti128_si256(shuffled_v, 0); + } else { + shuffled_v_low = _mm256_extracti128_si256(v, 0); + } + switch (residual_rows) { + case 0: + break; + case 1: + _mm_storeu_si16(dst, shuffled_v_low); + break; + case 2: + _mm_storeu_si32(dst, shuffled_v_low); + break; + case 3: { + _mm_storeu_si32(dst, shuffled_v_low); + dst[2] = _mm_extract_epi16(shuffled_v_low, 2); + break; + } + case 4: + _mm_storeu_si64(dst, shuffled_v_low); + break; + case 5: + _mm_storeu_si64(dst, shuffled_v_low); + dst[4] = _mm256_extract_epi16(shuffled_v, 8); + break; + case 6: + _mm_storeu_si64(dst, shuffled_v_low); + _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); + break; + case 7: { + _mm_storeu_si64(dst, shuffled_v_low); + __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 1); + _mm_storeu_si32(dst + 4, trailing_packed); + dst[6] = _mm_extract_epi16(trailing_packed, 2); + break; + } + case 8: + _mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0)); + _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); + break; + default: + RUY_DCHECK_LE(residual_rows, 8); + break; + } +} + +inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256 v) { + // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively + // truncating each 16-bit integer. 
+ const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); + const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); + _mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0)); + _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); +} + +inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows, + const __m256 v) { + const __m128i v_low = _mm256_extracti128_si256(v, 0); + switch (residual_rows) { + case 0: + break; + case 1: + _mm_storeu_si32(dst, v_low); + break; + case 2: + _mm_storeu_si64(dst, v_low); + break; + case 3: { + __m128i trailing_packed = v_low; + _mm_storeu_si64(dst, trailing_packed); + dst[2] = _mm_extract_epi32(trailing_packed, 2); + break; + } + case 4: + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + break; + case 5: + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + dst[4] = _mm256_extract_epi32(v, 4); + break; + case 6: + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(v, 1)); + break; + case 7: { + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + __m128i trailing_packed = _mm256_extracti128_si256(v, 1); + _mm_storeu_si64(dst + 4, trailing_packed); + dst[6] = _mm_extract_epi32(trailing_packed, 2); + break; + } + case 8: + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); + break; + default: + RUY_DCHECK_LE(residual_rows, 8); + break; + } +} + +inline void mm256_storeu_epi32(std::int32_t* dst, const __m256 v) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); +} + +inline float mm256_get1_ps(const __m256 a, int i) { + __m256i ai = _mm256_castps_si256(a); + int float_val_as_int; + switch (i) { + case 0: + float_val_as_int = _mm256_extract_epi32(ai, 0); + break; + case 1: + float_val_as_int = _mm256_extract_epi32(ai, 1); + break; + case 2: + float_val_as_int = _mm256_extract_epi32(ai, 2); + break; + case 3: + float_val_as_int = _mm256_extract_epi32(ai, 3); + break; + case 4: + float_val_as_int = _mm256_extract_epi32(ai, 4); + break; + case 5: + float_val_as_int = _mm256_extract_epi32(ai, 5); + break; + case 6: + float_val_as_int = _mm256_extract_epi32(ai, 6); + break; + case 7: + float_val_as_int = _mm256_extract_epi32(ai, 7); + break; + default: + RUY_DCHECK_LT(i, 8); + return .0f; + } + return reinterpret_cast(float_val_as_int); +} + +inline __m256 mm256_n_loadu_ps(int i, const float* src) { + switch (i) { + case 0: + return _mm256_setzero_ps(); + case 1: + return _mm256_setr_m128(_mm_setr_ps(src[0], .0f, .0f, .0f), + _mm_setzero_ps()); + case 2: + return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], .0f, .0f), + _mm_setzero_ps()); + case 3: + return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], .0f), + _mm_setzero_ps()); + case 4: + return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], src[3]), + _mm_setzero_ps()); + case 5: + return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], .0f, .0f, + .0f); + case 6: + return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5], .0f, + .0f); + case 7: + return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5], + src[6], .0f); + case 8: + return _mm256_loadu_ps(src); + default: + RUY_DCHECK_LT(i, 9); + return _mm256_setzero_ps(); + } +} + +inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) { + for (int i = 0; i < residual_rows; ++i) { + dst[i] = intrin_utils::mm256_get1_ps(v, i); + } +} +} // namespace intrin_utils +} // namespace + +void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { + 
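+  // High-level shape of this kernel (illustrative sketch, not original code):
+  //   for (col = start_col; col <= last_col; col += 8)      // column blocks
+  //     for (row = start_row; row <= last_row; row += 8)    // row blocks
+  //       acc[8 cols][8 lanes] = bias / zero-point terms;
+  //       for (d = 0; d < depth; d += 4)
+  //         acc[j] += madd(lhs_16bit, broadcast(rhs[j]));   // _mm256_madd_epi16
+  //       rescale (unless int32 dst), clamp, store as int8/uint8/int16/int32.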
profiler::ScopeLabel label("Kernel kAvx2 8-bit"); + const std::int8_t splitter_idx_data[32] = { + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15, // + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15 // + }; + + std::int32_t dst_stride; + if ((params.dst_type_id == DstTypeId::kValue) || + (params.dst_type_id == DstTypeId::kValue)) { + dst_stride = params.dst_stride; + } else if (params.dst_type_id == DstTypeId::kValue) { + dst_stride = params.dst_stride / sizeof(std::int16_t); + } else if (params.dst_type_id == DstTypeId::kValue) { + dst_stride = params.dst_stride / sizeof(std::int32_t); + } else { + RUY_DCHECK(false); + } + + int bias_ptr_block_increment = + params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; + + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + void* dst_col_ptr = params.dst_base_ptr; + const std::int32_t* bias_col_ptr = params.bias; + if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { + bias_col_ptr += params.start_row; + } + + for (int col = params.start_col; col <= params.last_col; + col += kAvx8bitBlockSize) { + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + void* dst_ptr = dst_col_ptr; + const std::int32_t* bias_ptr = bias_col_ptr; + + const std::int32_t lhs_zero_point = params.lhs_zero_point; + const bool has_rhs_sums_offsets = + (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; + std::int32_t rhs_sums_offsets[8]; + if (has_rhs_sums_offsets) { + const __m256i rhs_sums_offset_v = _mm256_mullo_epi32( + _mm256_set1_epi32(lhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.rhs_sums[col]))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), + rhs_sums_offset_v); + } + + for (int row = params.start_row; row <= params.last_row; + row += kAvx8bitBlockSize) { + const int residual_rows = + std::min(params.dst_rows - row, kAvx8bitBlockSize); + const int residual_cols = + std::min(params.dst_cols - col, kAvx8bitBlockSize); + + const __m256i splitter_idx = _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(splitter_idx_data)); + + __m256i accum_data_v0; + __m256i accum_data_v1; + __m256i accum_data_v2; + __m256i accum_data_v3; + __m256i accum_data_v4; + __m256i accum_data_v5; + __m256i accum_data_v6; + __m256i accum_data_v7; + + // Initialize with bias. + __m256i initial_accum_data = + intrin_utils::mm256_n_loadu_epi32(residual_rows, bias_ptr); + bias_ptr += bias_ptr_block_increment; + + // Adjustments common across columns. + const std::int32_t rhs_zero_point = params.rhs_zero_point; + if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { + const __m256i lhs_sums_offset = _mm256_mullo_epi32( + _mm256_set1_epi32(rhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); + initial_accum_data = + _mm256_sub_epi32(initial_accum_data, lhs_sums_offset); + } + const std::int32_t prod_zp_depth = params.prod_zp_depth; + if (prod_zp_depth) { + initial_accum_data = _mm256_add_epi32(initial_accum_data, + _mm256_set1_epi32(prod_zp_depth)); + } + + // Adjustments differing across columns. 
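+      // For reference, the quantized-GEMM identity being assembled here is
+      // (illustrative; zp = zero point):
+      //   acc[r][c] = bias[r] + depth * lhs_zp * rhs_zp
+      //               - rhs_zp * lhs_sums[r] - lhs_zp * rhs_sums[c]
+      //               + sum_d lhs[r][d] * rhs[d][c]
+      // The bias, prod_zp_depth and lhs_sums terms were just folded into
+      // initial_accum_data; the rhs_sums term varies per column and is
+      // applied below.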
+ if (has_rhs_sums_offsets) { + accum_data_v0 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0])); + accum_data_v1 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1])); + accum_data_v2 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2])); + accum_data_v3 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3])); + accum_data_v4 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4])); + accum_data_v5 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5])); + accum_data_v6 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6])); + accum_data_v7 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7])); + } else { + accum_data_v0 = initial_accum_data; + accum_data_v1 = initial_accum_data; + accum_data_v2 = initial_accum_data; + accum_data_v3 = initial_accum_data; + accum_data_v4 = initial_accum_data; + accum_data_v5 = initial_accum_data; + accum_data_v6 = initial_accum_data; + accum_data_v7 = initial_accum_data; + } + + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { + const __m256i lhs_data = + _mm256_load_si256(reinterpret_cast(lhs_ptr)); + const __m256i rhs_data_8bit = + _mm256_load_si256(reinterpret_cast(rhs_ptr)); + + // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. + std::int32_t rhs_data[16]; + const __m128i rhs_data_bottom_lane = + _mm256_castsi256_si128(rhs_data_8bit); + const __m128i rhs_data_top_lane = + _mm256_extracti128_si256(rhs_data_8bit, 1); + const __m256i rhs_16_bit_dup_low = + _mm256_cvtepi8_epi16(rhs_data_bottom_lane); + const __m256i rhs_16_bit_dup_high = + _mm256_cvtepi8_epi16(rhs_data_top_lane); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data), + rhs_16_bit_dup_low); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8), + rhs_16_bit_dup_high); + + // NOTE: There may be opportunities for permuting the data in the + // packing code instead of here. + const __m256i lhs_data_split = + _mm256_shuffle_epi8(lhs_data, splitter_idx); + const __m256i lhs_data_split_expand_bottom = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0)); + const __m256i lhs_data_split_expand_top = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1)); + + // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. + const __m256i lhs_16_bit_low = _mm256_permute2x128_si256( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); + // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. + const __m256i lhs_16_bit_high = _mm256_permute2x128_si256( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); + // Accumulate for column 0. + { + const std::int32_t low_rhs_value = rhs_data[0]; + const std::int32_t high_rhs_value = rhs_data[1]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v0 = _mm256_add_epi32( + accum_data_v0, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v0 = _mm256_add_epi32( + accum_data_v0, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + // Accumulate for column 1. 
+ { + const std::int32_t low_rhs_value = rhs_data[2]; + const std::int32_t high_rhs_value = rhs_data[3]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v1 = _mm256_add_epi32( + accum_data_v1, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v1 = _mm256_add_epi32( + accum_data_v1, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + // Accumulate for column 2. + { + const std::int32_t low_rhs_value = rhs_data[4]; + const std::int32_t high_rhs_value = rhs_data[5]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v2 = _mm256_add_epi32( + accum_data_v2, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v2 = _mm256_add_epi32( + accum_data_v2, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + // Accumulate for column 3. + { + const std::int32_t low_rhs_value = rhs_data[6]; + const std::int32_t high_rhs_value = rhs_data[7]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v3 = _mm256_add_epi32( + accum_data_v3, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v3 = _mm256_add_epi32( + accum_data_v3, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + // Accumulate for column 4. + { + const std::int32_t low_rhs_value = rhs_data[8]; + const std::int32_t high_rhs_value = rhs_data[9]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v4 = _mm256_add_epi32( + accum_data_v4, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v4 = _mm256_add_epi32( + accum_data_v4, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + // Accumulate for column 5. + { + const std::int32_t low_rhs_value = rhs_data[10]; + const std::int32_t high_rhs_value = rhs_data[11]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v5 = _mm256_add_epi32( + accum_data_v5, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v5 = _mm256_add_epi32( + accum_data_v5, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + // Accumulate for column 6. + { + const std::int32_t low_rhs_value = rhs_data[12]; + const std::int32_t high_rhs_value = rhs_data[13]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v6 = _mm256_add_epi32( + accum_data_v6, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v6 = _mm256_add_epi32( + accum_data_v6, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + // Accumulate for column 7. 
+ { + const std::int32_t low_rhs_value = rhs_data[14]; + const std::int32_t high_rhs_value = rhs_data[15]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v7 = _mm256_add_epi32( + accum_data_v7, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v7 = _mm256_add_epi32( + accum_data_v7, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + + lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + } + + if (params.dst_type_id != DstTypeId::kValue) { + __m256i m_vector; + __m256i e_vector; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. + if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { + m_vector = intrin_utils::mm256_n_loadu_epi32( + residual_rows, ¶ms.multiplier_fixedpoint[row]); + e_vector = intrin_utils::mm256_n_loadu_epi32( + residual_rows, ¶ms.multiplier_exponent[row]); + } else { + // These arrays have size LhsCols, and are pre-filled. + m_vector = _mm256_set1_epi32(params.multiplier_fixedpoint[0]); + e_vector = _mm256_set1_epi32(params.multiplier_exponent[0]); + } + + const __m256i m_64bit_low = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0)); + const __m256i m_64bit_high = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1)); + + const __m256i zero_vector = _mm256_setzero_si256(); + const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector); + const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector); + const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector); + const __m256i final_right_shift = + _mm256_add_epi32(right_shift, _mm256_set1_epi32(31)); + const __m256i final_right_shift_low = _mm256_cvtepi32_epi64( + _mm256_extracti128_si256(final_right_shift, 0)); + const __m256i final_right_shift_high = _mm256_cvtepi32_epi64( + _mm256_extracti128_si256(final_right_shift, 1)); + // Really we want 0x100000000, but use half to avoid overflowing. + const __m256i convert_to_signed_halved = + _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift); + const __m256i convert_to_unsigned_64 = + _mm256_set1_epi64x(0x8000000000000000); + + __m256i post_scaling_offset = _mm256_add_epi32( + convert_to_signed_halved, convert_to_signed_halved); + + const __m256i offset_vector = + _mm256_slli_epi64(_mm256_set1_epi64x(1), 30); + // Really these should be shifted by neg_e_vector, but tests pass when + // using right_shift. + const __m256i offset_vector_low = _mm256_add_epi64( + _mm256_sllv_epi64(offset_vector, + _mm256_cvtepi32_epi64( + _mm256_extracti128_si256(right_shift, 0))), + convert_to_unsigned_64); + const __m256i offset_vector_high = _mm256_add_epi64( + _mm256_sllv_epi64(offset_vector, + _mm256_cvtepi32_epi64( + _mm256_extracti128_si256(right_shift, 1))), + convert_to_unsigned_64); + + if (params.dst_zero_point) { + const __m256i dst_zero_point = + _mm256_set1_epi32(params.dst_zero_point); + // The post-scaling offset is subtracted later, so this has the effect + // of adding the zero point. 
+ post_scaling_offset = + _mm256_sub_epi32(post_scaling_offset, dst_zero_point); + } + +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + RUY_DCHECK(false); +#endif + const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); + + // We cannot do + // + // scaled_v_low = + // _mm256_srav_epi64(scaled_v_low, final_right_shift_low); + // scaled_v_high = + // _mm256_srav_epi64(scaled_v_high, final_right_shift_high); + // + // since this instruction is not in AVX2. Instead we use + // _mm256_srlv_epi64, but this is an unsigned shift, so we applied + // offsets before (convert_to_unsigned_64) and after + // (convert_to_signed_halved). + // + // The overall process is, for 64-bit scaled accumulator: + // unsigned_accum = signed_accum + 1 << 63; + // unsigned_accum = (unsigned_accum >> right_shift) >> 31; + // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2; + + // There are various ways to repack the results, in the absence of + // _mm256_cvtepi64_epi32() or anything like it. + // A. + // accum_data_v[j] = + // _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6), + // _mm256_extract_epi32(scaled_v_high, 4), + // _mm256_extract_epi32(scaled_v_high, 2), + // _mm256_extract_epi32(scaled_v_high, 0), + // _mm256_extract_epi32(scaled_v_low, 6), + // _mm256_extract_epi32(scaled_v_low, 4), + // _mm256_extract_epi32(scaled_v_low, 2), + // _mm256_extract_epi32(scaled_v_low, 0)); + // B. + // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8); + // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8); + // accum_data_v[j] = + // _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2), + // _mm256_extract_epi64(scaled_v_high, 0), + // _mm256_extract_epi64(scaled_v_low, 2), + // _mm256_extract_epi64(scaled_v_low, 0)); + // C. + // scaled_v_low = + // _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm); + // scaled_v_high = + // _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm); + // accum_data_v[j] = + // _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20); + // + // However, we choose the following because it uses two lighter + // instructions. The permutation does have a longer latency, but this + // loop can be unrolled. + // D. + // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + // __m256i results = + // _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + // results = _mm256_permutevar8x32_epi32(results, repack_perm); + // accum_data_v[j] = _mm256_sub_epi32(results, post_scaling_offset); + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift); + // Apply the fixed-point part of the multiplier. 
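+          // Scalar reference for one lane of the whole rescale step
+          // (illustrative only; follows the process described above):
+          //   int64_t prod = (int64_t)(acc << left_shift) * mult_fixedpoint;
+          //   uint64_t u   = (uint64_t)prod + (1ull << 63)            // to unsigned
+          //                  + ((uint64_t)1 << (30 + right_shift));   // rounding
+          //   int32_t out  = (int32_t)(u >> (31 + right_shift))
+          //                  - 2 * (0x80000000u >> right_shift) + dst_zero_point;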
+ __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset); + } + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v1, left_shift); + // Apply the fixed-point part of the multiplier. + __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v1 = _mm256_sub_epi32(results, post_scaling_offset); + } + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v2, left_shift); + // Apply the fixed-point part of the multiplier. + __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v2 = _mm256_sub_epi32(results, post_scaling_offset); + } + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v3, left_shift); + // Apply the fixed-point part of the multiplier. 
+ __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v3 = _mm256_sub_epi32(results, post_scaling_offset); + } + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v4, left_shift); + // Apply the fixed-point part of the multiplier. + __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v4 = _mm256_sub_epi32(results, post_scaling_offset); + } + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v5, left_shift); + // Apply the fixed-point part of the multiplier. + __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v5 = _mm256_sub_epi32(results, post_scaling_offset); + } + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v6, left_shift); + // Apply the fixed-point part of the multiplier. 
+ __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v6 = _mm256_sub_epi32(results, post_scaling_offset); + } + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v7, left_shift); + // Apply the fixed-point part of the multiplier. + __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v7 = _mm256_sub_epi32(results, post_scaling_offset); + } + } + const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); + const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); + const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && + (residual_cols == kAvx8bitBlockSize); + + __m256i accum_data_v[kAvx8bitBlockSize]; + if (!store_full_block) { + accum_data_v[0] = accum_data_v0; + accum_data_v[1] = accum_data_v1; + accum_data_v[2] = accum_data_v2; + accum_data_v[3] = accum_data_v3; + accum_data_v[4] = accum_data_v4; + accum_data_v[5] = accum_data_v5; + accum_data_v[6] = accum_data_v6; + accum_data_v[7] = accum_data_v7; + } + + if (params.dst_type_id == DstTypeId::kValue) { + std::int8_t* tmp_ptr = static_cast(dst_ptr); + if (store_full_block) { + accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); + accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); + accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); + accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); + accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); + accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); + accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); + accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); + 
intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[0 * dst_stride], + accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[1 * dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[2 * dst_stride], + accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[3 * dst_stride], + accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[4 * dst_stride], + accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[5 * dst_stride], + accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[6 * dst_stride], + accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[7 * dst_stride], + accum_data_v7); + } else { + for (int j = 0; j < residual_cols; ++j) { + __m256 result = accum_data_v[j]; + result = _mm256_min_epi32(result, clamp_max_v); + result = _mm256_max_epi32(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, + result); + tmp_ptr += dst_stride; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::uint8_t* tmp_ptr = static_cast(dst_ptr); + if (store_full_block) { + accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); + accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); + accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); + accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); + accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); + accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); + accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); + accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[0], accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[2 * dst_stride], + accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[3 * dst_stride], + accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[4 * dst_stride], + accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[5 * dst_stride], + accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[6 * dst_stride], + accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[7 * dst_stride], + accum_data_v7); + } else { + for (int j = 0; j < residual_cols; ++j) { + __m256 result = accum_data_v[j]; + result = _mm256_min_epi32(result, clamp_max_v); + result = _mm256_max_epi32(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, + result); + tmp_ptr += dst_stride; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::int16_t* tmp_ptr = static_cast(dst_ptr); + if (store_full_block) { + accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); + accum_data_v1 = 
_mm256_min_epi32(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); + accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); + accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); + accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); + accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); + accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); + accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[0], accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[2 * dst_stride], + accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[3 * dst_stride], + accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[4 * dst_stride], + accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[5 * dst_stride], + accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[6 * dst_stride], + accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[7 * dst_stride], + accum_data_v7); + } else { + for (int j = 0; j < residual_cols; ++j) { + __m256 result = accum_data_v[j]; + result = _mm256_min_epi32(result, clamp_max_v); + result = _mm256_max_epi32(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi16(tmp_ptr, residual_rows, + result); + tmp_ptr += dst_stride; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + if (store_full_block) { + std::int32_t* tmp_ptr = static_cast(dst_ptr); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[0], accum_data_v0); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[dst_stride], accum_data_v1); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[2 * dst_stride], + accum_data_v2); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[3 * dst_stride], + accum_data_v3); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[4 * dst_stride], + accum_data_v4); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[5 * dst_stride], + accum_data_v5); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[6 * dst_stride], + accum_data_v6); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[7 * dst_stride], + accum_data_v7); + } else { + std::int32_t* dst_block_ptr = static_cast(dst_ptr); + for (int j = 0; j < residual_cols; ++j) { + intrin_utils::mm256_n_storeu_epi32(dst_block_ptr, residual_rows, + accum_data_v[j]); + dst_block_ptr += dst_stride; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; + } // End row-block loop. + + dst_col_ptr = static_cast(static_cast(dst_col_ptr) + + kAvx8bitBlockSize * params.dst_stride); + rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; + } // End col-block loop. 
+} // NOLINT(readability/fn_size) + +void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label("Kernel kAvx2 8-bit GEMV"); + + RUY_DCHECK_EQ(params.dst_cols, 1); + RUY_DCHECK_EQ(params.last_col, 0); + RUY_DCHECK_EQ(params.start_col, 0); + + const std::int8_t splitter_idx_data[32] = { + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15, // + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15 // + }; + + int bias_ptr_block_increment = + params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; + + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + void* dst_col_ptr = params.dst_base_ptr; + const std::int32_t* bias_col_ptr = params.bias; + if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { + bias_col_ptr += params.start_row; + } + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + void* dst_ptr = dst_col_ptr; + const std::int32_t* bias_ptr = bias_col_ptr; + + const std::int32_t lhs_zero_point = params.lhs_zero_point; + const bool has_rhs_sums_offsets = + (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; + std::int32_t rhs_sums_offsets[8]; + if (has_rhs_sums_offsets) { + const __m256i rhs_sums_offset_v = _mm256_mullo_epi32( + _mm256_set1_epi32(lhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.rhs_sums[0]))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), + rhs_sums_offset_v); + } + + for (int row = params.start_row; row <= params.last_row; + row += kAvx8bitBlockSize) { + const int residual_rows = + std::min(params.dst_rows - row, kAvx8bitBlockSize); + + const __m256i splitter_idx = + _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data)); + + __m256i accum_data_v0; + + // Initialize with bias. + __m256i initial_accum_data = + intrin_utils::mm256_n_loadu_epi32(residual_rows, bias_ptr); + bias_ptr += bias_ptr_block_increment; + + // Adjustments common across columns. + const std::int32_t rhs_zero_point = params.rhs_zero_point; + if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { + const __m256i lhs_sums_offset = _mm256_mullo_epi32( + _mm256_set1_epi32(rhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); + initial_accum_data = + _mm256_sub_epi32(initial_accum_data, lhs_sums_offset); + } + const std::int32_t prod_zp_depth = params.prod_zp_depth; + if (prod_zp_depth) { + initial_accum_data = _mm256_add_epi32(initial_accum_data, + _mm256_set1_epi32(prod_zp_depth)); + } + + // Adjustments differing across columns. + if (has_rhs_sums_offsets) { + accum_data_v0 = _mm256_sub_epi32(initial_accum_data, + _mm256_set1_epi32(rhs_sums_offsets[0])); + } else { + accum_data_v0 = initial_accum_data; + } + + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { + const __m256i lhs_data = + _mm256_load_si256(reinterpret_cast(lhs_ptr)); + const __m128i rhs_data_8bit = _mm_loadu_si32(rhs_ptr); + + // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. + // For simplicity we load 4x the data that we need and process twice the + // data that we need and store only the data we need. + std::int32_t rhs_data[2]; + const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. 
+ _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); + + // NOTE: There may be opportunities for permuting the data in the packing + // code instead of here. + const __m256i lhs_data_split = + _mm256_shuffle_epi8(lhs_data, splitter_idx); + const __m256i lhs_data_split_expand_bottom = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0)); + const __m256i lhs_data_split_expand_top = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1)); + + // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. + const __m256i lhs_16_bit_low = _mm256_permute2x128_si256( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); + // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. + const __m256i lhs_16_bit_high = _mm256_permute2x128_si256( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); + // Accumulate for column 0. + const std::int32_t low_rhs_value = rhs_data[0]; + const std::int32_t high_rhs_value = rhs_data[1]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v0 = _mm256_add_epi32( + accum_data_v0, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v0 = _mm256_add_epi32( + accum_data_v0, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + + lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + } + + if (params.dst_type_id != DstTypeId::kValue) { + __m256i m_vector; + __m256i e_vector; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. + if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { + m_vector = intrin_utils::mm256_n_loadu_epi32( + residual_rows, ¶ms.multiplier_fixedpoint[row]); + e_vector = intrin_utils::mm256_n_loadu_epi32( + residual_rows, ¶ms.multiplier_exponent[row]); + } else { + // These arrays have size LhsCols, and are pre-filled. + m_vector = _mm256_set1_epi32(params.multiplier_fixedpoint[0]); + e_vector = _mm256_set1_epi32(params.multiplier_exponent[0]); + } + + const __m256i m_64bit_low = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0)); + const __m256i m_64bit_high = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1)); + + const __m256i zero_vector = _mm256_setzero_si256(); + const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector); + const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector); + const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector); + const __m256i final_right_shift = + _mm256_add_epi32(right_shift, _mm256_set1_epi32(31)); + const __m256i final_right_shift_low = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0)); + const __m256i final_right_shift_high = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1)); + // Really we want 0x100000000, but use half to avoid overflowing. + const __m256i convert_to_signed_halved = + _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift); + const __m256i convert_to_unsigned_64 = + _mm256_set1_epi64x(0x8000000000000000); + + __m256i post_scaling_offset = + _mm256_add_epi32(convert_to_signed_halved, convert_to_signed_halved); + + const __m256i offset_vector = + _mm256_slli_epi64(_mm256_set1_epi64x(1), 30); + // Really these should be shifted by neg_e_vector, but tests pass when + // using right_shift. 
+ const __m256i offset_vector_low = _mm256_add_epi64( + _mm256_sllv_epi64( + offset_vector, + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 0))), + convert_to_unsigned_64); + const __m256i offset_vector_high = _mm256_add_epi64( + _mm256_sllv_epi64( + offset_vector, + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 1))), + convert_to_unsigned_64); + + if (params.dst_zero_point) { + const __m256i dst_zero_point = _mm256_set1_epi32(params.dst_zero_point); + // The post-scaling offset is subtracted later, so this has the effect + // of adding the zero point. + post_scaling_offset = + _mm256_sub_epi32(post_scaling_offset, dst_zero_point); + } + +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + RUY_DCHECK(false); +#endif + const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); + + // See GEMM version for details of this process. + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift); + // Apply the fixed-point part of the multiplier. + __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset); + } + } + const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); + const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); + + if (params.dst_type_id == DstTypeId::kValue) { + std::int8_t* tmp_ptr = static_cast(dst_ptr); + __m256 result = accum_data_v0; + result = _mm256_min_epi32(result, clamp_max_v); + result = _mm256_max_epi32(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, + result); + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::uint8_t* tmp_ptr = static_cast(dst_ptr); + __m256 result = accum_data_v0; + result = _mm256_min_epi32(result, clamp_max_v); + result = _mm256_max_epi32(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, + result); + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::int16_t* tmp_ptr = static_cast(dst_ptr); + __m256 result = accum_data_v0; + result = _mm256_min_epi32(result, clamp_max_v); + result = _mm256_max_epi32(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi16(tmp_ptr, residual_rows, + result); + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::int32_t* dst_block_ptr = static_cast(dst_ptr); + intrin_utils::mm256_n_storeu_epi32(dst_block_ptr, residual_rows, + accum_data_v0); + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; + } // End row-block loop. 
+
+  dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
+                                   kAvx8bitBlockSize * params.dst_stride);
+  rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
+}  // NOLINT(readability/fn_size)
+
+void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
+  profiler::ScopeLabel label("Kernel kAvx2 float");
+
+  // As parameters are defined, we need to scale by sizeof(float).
+  const std::int64_t lhs_stride = params.lhs_stride >> 2;
+  const std::int64_t dst_stride = params.dst_stride >> 2;
+  const std::int64_t rhs_stride = params.rhs_stride >> 2;
+  //
+  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+  // AVX2 float block size = 8.
+  const int end_row = std::min(params.dst_rows, params.last_row + 8);
+  const int end_col = std::min(params.dst_cols, params.last_col + 8);
+  //
+  const float* adj_rhs_col_ptr =
+      params.rhs_base_ptr - params.start_col * rhs_stride;
+  float* adj_dst_col_ptr =
+      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
+  const float* adj_lhs_col_ptr =
+      params.lhs_base_ptr - params.start_row * lhs_stride;
+  const float* bias_col_ptr = params.bias;
+
+  const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
+  const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
+
+  int col = params.start_col;
+  // Loop through cols by the float block size, leaving the incomplete
+  // remainder for the tail handling below.
+  for (; col <= end_col - 8; col += 8) {
+    __m256 accum_data_v[8];
+
+    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+
+    for (int row = params.start_row; row < end_row; row += 8) {
+      const int residual_rows = std::min(end_row - row, 8);
+
+      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+      float* dst_ptr = dst_col_ptr + row;
+      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+      // Initialize with bias.
+      const __m256 initial_accum_data =
+          intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr);
+
+      for (int j = 0; j < 8; ++j) {
+        accum_data_v[j] = initial_accum_data;
+      }
+
+      const float* lhs_ptr = lhs_col_ptr;
+      const float* rhs_ptr = rhs_col_ptr;
+      for (int d = 0; d < params.depth; ++d) {
+        const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
+        // In this version RHS values are loaded individually rather than first
+        // loading them together and then extracting with broadcasts. This is
+        // because AVX flavours and intrinsics and compilers in combination do
+        // not handle this pattern of extraction very well.
+        const float* rhs_data = rhs_ptr;
+
+        for (int j = 0; j < 8; ++j) {
+          const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
+          accum_data_v[j] =
+              _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]);
+        }
+        lhs_ptr += 8;
+        rhs_ptr += 8;
+      }
+
+      if (residual_rows == 8) {
+        for (int j = 0; j < 8; ++j) {
+          float* block_ptr = dst_ptr + j * dst_stride;
+          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+          _mm256_storeu_ps(block_ptr, accum_data_v[j]);
+        }
+      } else {
+        for (int j = 0; j < 8; ++j) {
+          float* block_ptr = dst_ptr + j * dst_stride;
+          accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v);
+          accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v);
+          intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows,
+                                          accum_data_v[j]);
+        }
+      }
+    }  // End row-block loop.
+  }  // End col-block loop.
+
+  if (col < end_col) {
+    // Remaining cols in [0, float block size).
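+    // The tail below mirrors the full-block path above; only the stores
+    // change. For exposition, a scalar sketch of what one depth step of the
+    // FMA loop accumulates, with i indexing the 8 rows held in one __m256
+    // register and j the destination column:
+    //   for (int j = 0; j < 8; ++j) {
+    //     for (int i = 0; i < 8; ++i) {
+    //       accum[j][i] += lhs[i] * rhs[j];
+    //     }
+    //   }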
+ RUY_DCHECK_GE(end_col - col, 0); + RUY_DCHECK_LT(end_col - col, 8); + + __m256 accum_data_v[8]; + + const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; + float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; + const int residual_cols = std::min(end_col - col, 8); + + for (int row = params.start_row; row < end_row; row += 8) { + const int residual_rows = std::min(end_row - row, 8); + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + const __m256 initial_accum_data = + intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr); + + for (int j = 0; j < 8; ++j) { + accum_data_v[j] = initial_accum_data; + } + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + + for (int j = 0; j < 8; ++j) { + const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]); + accum_data_v[j] = + _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); + } + lhs_ptr += 8; + rhs_ptr += 8; + } + + for (int j = 0; j < residual_cols; ++j) { + float* block_ptr = dst_ptr + j * dst_stride; + accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); + intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows, + accum_data_v[j]); + } + } // End row-block loop. + } // End col-block terminal conditional. +} + +void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) { + profiler::ScopeLabel label("Kernel kAvx2 float GEMV"); + + RUY_DCHECK_EQ(params.dst_cols, 1); + RUY_DCHECK_EQ(params.last_col, 0); + RUY_DCHECK_EQ(params.start_col, 0); + + // As parameters are defined, we need to scale by sizeof(float). + const std::int64_t lhs_stride = params.lhs_stride >> 2; + // + int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; + // AVX2 float block size = 8. + const int end_row = std::min(params.dst_rows, params.last_row + 8); + + float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row; + const float* adj_lhs_col_ptr = + params.lhs_base_ptr - params.start_row * lhs_stride; + const float* bias_col_ptr = params.bias; + + const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max); + const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min); + + __m256 accum_data_v; + + const float* rhs_col_ptr = params.rhs_base_ptr; + float* dst_col_ptr = adj_dst_col_ptr; + + int row = params.start_row; + for (; row <= end_row - 8; row += 8) { + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. 
+ accum_data_v = _mm256_loadu_ps(bias_ptr); + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + int d = 0; + for (; d <= params.depth - 4; d += 4) { + const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr); + const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]); + accum_data_v = + _mm256_fmadd_ps(lhs_data_0, dup_rhs_element_0, accum_data_v); + const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]); + const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8); + accum_data_v = + _mm256_fmadd_ps(lhs_data_1, dup_rhs_element_1, accum_data_v); + + const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16); + const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]); + accum_data_v = + _mm256_fmadd_ps(lhs_data_2, dup_rhs_element_2, accum_data_v); + const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]); + const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24); + accum_data_v = + _mm256_fmadd_ps(lhs_data_3, dup_rhs_element_3, accum_data_v); + lhs_ptr += 32; // Loaded 8 * 4 floats. + rhs_ptr += 32; + } + for (; d < params.depth; ++d) { + const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + + const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]); + accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); + lhs_ptr += 8; + rhs_ptr += 8; + } + + accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v); + accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v); + _mm256_storeu_ps(dst_ptr, accum_data_v); + } // End row-block loop. + + if (row < end_row) { + const int residual_rows = end_row - row; + RUY_CHECK_GE(residual_rows, 1); + RUY_CHECK_LT(residual_rows, 8); + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + accum_data_v = intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr); + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + + const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]); + accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); + lhs_ptr += 8; + rhs_ptr += 8; + } + + accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v); + accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v); + intrin_utils::mm256_n_storeu_ps(dst_ptr, residual_rows, accum_data_v); + } // End handling of residual rows. +} + +#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc new file mode 100644 index 0000000..5e771a5 --- /dev/null +++ b/ruy/kernel_avx512.cc @@ -0,0 +1,1820 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <algorithm>
+#include <cstdint>
+
+#include "ruy/check_macros.h"
+#include "ruy/kernel.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+#include <immintrin.h>  // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
+  // CPU-ID-based checks should disable the path that would reach this point.
+  RUY_DCHECK(false);
+}
+
+void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
+  // CPU-ID-based checks should disable the path that would reach this point.
+  RUY_DCHECK(false);
+}
+
+void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
+  // CPU-ID-based checks should disable the path that would reach this point.
+  RUY_DCHECK(false);
+}
+
+void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) {
+  // CPU-ID-based checks should disable the path that would reach this point.
+  RUY_DCHECK(false);
+}
+
+#else  // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
+  profiler::ScopeLabel label("Kernel kAvx512 8-bit");
+
+  std::int32_t dst_stride;
+  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
+      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
+    dst_stride = params.dst_stride;
+  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+    dst_stride = params.dst_stride / sizeof(std::int16_t);
+  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+    dst_stride = params.dst_stride / sizeof(std::int32_t);
+  } else {
+    RUY_DCHECK(false);
+  }
+
+  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
+
+  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  void* dst_col_ptr = params.dst_base_ptr;
+  const std::int32_t* bias_col_ptr = params.bias;
+  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+    bias_col_ptr += params.start_row;
+  }
+
+  for (int col = params.start_col; col <= params.last_col; col += 16) {
+    const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+    void* dst_ptr = dst_col_ptr;
+    const std::int32_t* bias_ptr = bias_col_ptr;
+
+    const std::int32_t lhs_zero_point = params.lhs_zero_point;
+    const bool has_rhs_sums_offsets =
+        (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+    std::int32_t rhs_sums_offsets[16];
+    if (has_rhs_sums_offsets) {
+      const __m512i rhs_sums_offset_v =
+          _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
+                             _mm512_loadu_epi32(&params.rhs_sums[col]));
+      _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
+                          rhs_sums_offset_v);
+    }
+
+    for (int row = params.start_row; row <= params.last_row; row += 16) {
+      const int residual_rows = std::min(params.dst_rows - row, 16);
+      const int residual_cols = std::min(params.dst_cols - col, 16);
+
+      __m512i accum_data_v0;
+      __m512i accum_data_v1;
+      __m512i accum_data_v2;
+      __m512i accum_data_v3;
+      __m512i accum_data_v4;
+      __m512i accum_data_v5;
+      __m512i accum_data_v6;
+      __m512i accum_data_v7;
+      __m512i accum_data_v8;
+      __m512i accum_data_v9;
+      __m512i accum_data_va;
+      __m512i accum_data_vb;
+      __m512i accum_data_vc;
+      __m512i accum_data_vd;
+      __m512i accum_data_ve;
+      __m512i accum_data_vf;
+
+      // Initialize with bias.
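+      // Residual rows are handled with an AVX-512 opmask instead of the
+      // partial-load helpers used by the AVX2 kernels. For exposition, the
+      // mask built below simply has the low residual_rows bits set, e.g.
+      // residual_rows == 5 gives row_mask == 0b11111, and
+      // _mm512_maskz_loadu_epi32 zero-fills the masked-off lanes.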
+      const __mmask16 row_mask =
+          (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+      __m512i initial_accum_data = _mm512_maskz_loadu_epi32(row_mask, bias_ptr);
+      bias_ptr += bias_ptr_block_increment;
+
+      const std::int32_t rhs_zero_point = params.rhs_zero_point;
+      if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+        const __m512i lhs_sums_offset =
+            _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
+                               _mm512_loadu_epi32(&params.lhs_sums[row]));
+        initial_accum_data =
+            _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
+      }
+
+      const std::int32_t prod_zp_depth = params.prod_zp_depth;
+      if (prod_zp_depth != 0) {
+        initial_accum_data = _mm512_add_epi32(initial_accum_data,
+                                              _mm512_set1_epi32(prod_zp_depth));
+      }
+
+      // Adjustments differing across columns.
+      if (has_rhs_sums_offsets) {
+        accum_data_v0 = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[0]));
+        accum_data_v1 = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[1]));
+        accum_data_v2 = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[2]));
+        accum_data_v3 = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[3]));
+        accum_data_v4 = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[4]));
+        accum_data_v5 = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[5]));
+        accum_data_v6 = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[6]));
+        accum_data_v7 = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[7]));
+        accum_data_v8 = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[8]));
+        accum_data_v9 = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[9]));
+        accum_data_va = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[10]));
+        accum_data_vb = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[11]));
+        accum_data_vc = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[12]));
+        accum_data_vd = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[13]));
+        accum_data_ve = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[14]));
+        accum_data_vf = _mm512_sub_epi32(
+            initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[15]));
+      } else {
+        accum_data_v0 = initial_accum_data;
+        accum_data_v1 = initial_accum_data;
+        accum_data_v2 = initial_accum_data;
+        accum_data_v3 = initial_accum_data;
+        accum_data_v4 = initial_accum_data;
+        accum_data_v5 = initial_accum_data;
+        accum_data_v6 = initial_accum_data;
+        accum_data_v7 = initial_accum_data;
+        accum_data_v8 = initial_accum_data;
+        accum_data_v9 = initial_accum_data;
+        accum_data_va = initial_accum_data;
+        accum_data_vb = initial_accum_data;
+        accum_data_vc = initial_accum_data;
+        accum_data_vd = initial_accum_data;
+        accum_data_ve = initial_accum_data;
+        accum_data_vf = initial_accum_data;
+      }
+
+      const std::int8_t* lhs_ptr = lhs_col_ptr;
+      const std::int8_t* rhs_ptr = rhs_col_ptr;
+      for (int d = 0; d < params.depth; d += 4) {
+        const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr);
+        __m512i rhs_data_8bit = _mm512_loadu_epi8(rhs_ptr);
+
+        // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
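+        // For exposition: one iteration of this loop advances 4 levels of
+        // depth. After the repacking below, _mm512_madd_epi16 computes, for
+        // each row i and column j,
+        //   lhs[i][0] * rhs[j][0] + lhs[i][1] * rhs[j][1]   (low pairs)
+        //   lhs[i][2] * rhs[j][2] + lhs[i][3] * rhs[j][3]   (high pairs)
+        // and both partial sums are added into the 32-bit accumulator lane.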
+ std::int32_t rhs_data[32]; + const __m256i rhs_data_bottom_lane = + _mm512_castsi512_si256(rhs_data_8bit); + const __m256i rhs_data_top_lane = + _mm512_extracti32x8_epi32(rhs_data_8bit, 1); + const __m512i rhs_16_bit_dup_low = + _mm512_cvtepi8_epi16(rhs_data_bottom_lane); + const __m512i rhs_16_bit_dup_high = + _mm512_cvtepi8_epi16(rhs_data_top_lane); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data), + rhs_16_bit_dup_low); + _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data + 16), + rhs_16_bit_dup_high); + + // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. + const __m512i lhs_16_bit_low = + _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data)); + // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit. + const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16( + _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16))); + + // Process column 0. + { + __m512i accum_v = accum_data_v0; + constexpr int index = 0; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v0 = accum_v; + } + // Process column 1. + { + __m512i accum_v = accum_data_v1; + constexpr int index = 2; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v1 = accum_v; + } + // Process column 2. + { + __m512i accum_v = accum_data_v2; + constexpr int index = 4; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v2 = accum_v; + } + // Process column 3. + { + __m512i accum_v = accum_data_v3; + constexpr int index = 6; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v3 = accum_v; + } + // Process column 4. + { + __m512i accum_v = accum_data_v4; + constexpr int index = 8; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v4 = accum_v; + } + // Process column 5. 
+ { + __m512i accum_v = accum_data_v5; + constexpr int index = 10; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v5 = accum_v; + } + // Process column 6. + { + __m512i accum_v = accum_data_v6; + constexpr int index = 12; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v6 = accum_v; + } + // Process column 7. + { + __m512i accum_v = accum_data_v7; + constexpr int index = 14; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v7 = accum_v; + } + // Process column 8. + { + __m512i accum_v = accum_data_v8; + constexpr int index = 16; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v8 = accum_v; + } + // Process column 9. + { + __m512i accum_v = accum_data_v9; + constexpr int index = 18; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v9 = accum_v; + } + // Process column 10. + { + __m512i accum_v = accum_data_va; + constexpr int index = 20; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_va = accum_v; + } + // Process column 11. + { + __m512i accum_v = accum_data_vb; + constexpr int index = 22; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_vb = accum_v; + } + // Process column 12. 
+          {
+            __m512i accum_v = accum_data_vc;
+            constexpr int index = 24;
+
+            const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+            const __m512i rhs_16_bit_dup_high =
+                _mm512_set1_epi32(rhs_data[index + 1]);
+
+            accum_v = _mm512_add_epi32(
+                accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+            accum_v = _mm512_add_epi32(
+                accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+            accum_data_vc = accum_v;
+          }
+          // Process column 13.
+          {
+            __m512i accum_v = accum_data_vd;
+            constexpr int index = 26;
+
+            const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+            const __m512i rhs_16_bit_dup_high =
+                _mm512_set1_epi32(rhs_data[index + 1]);
+
+            accum_v = _mm512_add_epi32(
+                accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+            accum_v = _mm512_add_epi32(
+                accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+            accum_data_vd = accum_v;
+          }
+          // Process column 14.
+          {
+            __m512i accum_v = accum_data_ve;
+            constexpr int index = 28;
+
+            const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+            const __m512i rhs_16_bit_dup_high =
+                _mm512_set1_epi32(rhs_data[index + 1]);
+
+            accum_v = _mm512_add_epi32(
+                accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+            accum_v = _mm512_add_epi32(
+                accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+            accum_data_ve = accum_v;
+          }
+          // Process column 15.
+          {
+            __m512i accum_v = accum_data_vf;
+            constexpr int index = 30;
+
+            const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+            const __m512i rhs_16_bit_dup_high =
+                _mm512_set1_epi32(rhs_data[index + 1]);
+
+            accum_v = _mm512_add_epi32(
+                accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+            accum_v = _mm512_add_epi32(
+                accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+            accum_data_vf = accum_v;
+          }
+
+        lhs_ptr += 16 * 4;
+        rhs_ptr += 16 * 4;
+      }
+
+      if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+        __m512i m_vector;
+        __m512i e_vector;
+        // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+        if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
+          m_vector = _mm512_maskz_loadu_epi32(
+              row_mask, &params.multiplier_fixedpoint[row]);
+          e_vector = _mm512_maskz_loadu_epi32(row_mask,
+                                              &params.multiplier_exponent[row]);
+        } else {
+          // These arrays have size LhsCols, and are pre-filled.
+          m_vector = _mm512_set1_epi32(params.multiplier_fixedpoint[0]);
+          e_vector = _mm512_set1_epi32(params.multiplier_exponent[0]);
+        }
+
+        const __m512i m_64bit_low =
+            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
+        const __m512i m_64bit_high =
+            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));
+
+        const __m512i zero_vector = _mm512_setzero_epi32();
+        const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
+        const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
+        const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
+        const __m512i final_right_shift =
+            _mm512_add_epi32(right_shift, _mm512_set1_epi32(31));
+        const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
+            _mm512_extracti32x8_epi32(final_right_shift, 0));
+        const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
+            _mm512_extracti32x8_epi32(final_right_shift, 1));
+
+        const __m512i offset_vector =
+            _mm512_slli_epi64(_mm512_set1_epi64(1), 30);
+        // Really these should be shifted by neg_e_vector, but tests pass when
+        // using right_shift.
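+        // Unlike the AVX2 kernel, AVX-512 has a 64-bit arithmetic variable
+        // shift (_mm512_srav_epi64), so no convert-to-unsigned bias is needed
+        // here: the rounding offset is just (1 << 30) << right_shift, i.e.
+        // half of the 2^(31 + right_shift) divisor applied below.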
+ const __m512i offset_vector_low = _mm512_sllv_epi64( + offset_vector, + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0))); + const __m512i offset_vector_high = _mm512_sllv_epi64( + offset_vector, + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1))); + + // Shift and round column 0. + { + accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v0, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v0, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v0 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v0 = _mm512_inserti32x8( + accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 1. + { + accum_data_v1 = _mm512_sllv_epi32(accum_data_v1, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v1, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v1, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v1 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v1 = _mm512_inserti32x8( + accum_data_v1, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 2. + { + accum_data_v2 = _mm512_sllv_epi32(accum_data_v2, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v2, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v2, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v2 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v2 = _mm512_inserti32x8( + accum_data_v2, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 3. + { + accum_data_v3 = _mm512_sllv_epi32(accum_data_v3, left_shift); + // Apply the fixed-point part of the multiplier. 
+ __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v3, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v3, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v3 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v3 = _mm512_inserti32x8( + accum_data_v3, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 4. + { + accum_data_v4 = _mm512_sllv_epi32(accum_data_v4, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v4, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v4, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v4 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v4 = _mm512_inserti32x8( + accum_data_v4, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 5. + { + accum_data_v5 = _mm512_sllv_epi32(accum_data_v5, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v5, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v5, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v5 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v5 = _mm512_inserti32x8( + accum_data_v5, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 6. + { + accum_data_v6 = _mm512_sllv_epi32(accum_data_v6, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v6, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v6, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v6 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v6 = _mm512_inserti32x8( + accum_data_v6, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 7. 
+ { + accum_data_v7 = _mm512_sllv_epi32(accum_data_v7, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v7, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v7, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v7 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v7 = _mm512_inserti32x8( + accum_data_v7, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 8. + { + accum_data_v8 = _mm512_sllv_epi32(accum_data_v8, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v8, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v8, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v8 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v8 = _mm512_inserti32x8( + accum_data_v8, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 9. + { + accum_data_v9 = _mm512_sllv_epi32(accum_data_v9, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v9, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v9, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v9 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v9 = _mm512_inserti32x8( + accum_data_v9, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 10. + { + accum_data_va = _mm512_sllv_epi32(accum_data_va, left_shift); + // Apply the fixed-point part of the multiplier. 
+ __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_va, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_va, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_va = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_va = _mm512_inserti32x8( + accum_data_va, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 11. + { + accum_data_vb = _mm512_sllv_epi32(accum_data_vb, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vb, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vb, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_vb = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_vb = _mm512_inserti32x8( + accum_data_vb, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 12. + { + accum_data_vc = _mm512_sllv_epi32(accum_data_vc, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vc, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vc, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_vc = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_vc = _mm512_inserti32x8( + accum_data_vc, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 13. + { + accum_data_vd = _mm512_sllv_epi32(accum_data_vd, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vd, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vd, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_vd = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_vd = _mm512_inserti32x8( + accum_data_vd, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 14. 
+ { + accum_data_ve = _mm512_sllv_epi32(accum_data_ve, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_ve, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_ve, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_ve = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_ve = _mm512_inserti32x8( + accum_data_ve, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 15. + { + accum_data_vf = _mm512_sllv_epi32(accum_data_vf, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vf, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vf, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_vf = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_vf = _mm512_inserti32x8( + accum_data_vf, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + RUY_DCHECK(false); +#endif + + if (params.dst_zero_point != 0) { + __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point); + accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point); + accum_data_v1 = _mm512_add_epi32(accum_data_v1, dst_zero_point); + accum_data_v2 = _mm512_add_epi32(accum_data_v2, dst_zero_point); + accum_data_v3 = _mm512_add_epi32(accum_data_v3, dst_zero_point); + accum_data_v4 = _mm512_add_epi32(accum_data_v4, dst_zero_point); + accum_data_v5 = _mm512_add_epi32(accum_data_v5, dst_zero_point); + accum_data_v6 = _mm512_add_epi32(accum_data_v6, dst_zero_point); + accum_data_v7 = _mm512_add_epi32(accum_data_v7, dst_zero_point); + accum_data_v8 = _mm512_add_epi32(accum_data_v8, dst_zero_point); + accum_data_v9 = _mm512_add_epi32(accum_data_v9, dst_zero_point); + accum_data_va = _mm512_add_epi32(accum_data_va, dst_zero_point); + accum_data_vb = _mm512_add_epi32(accum_data_vb, dst_zero_point); + accum_data_vc = _mm512_add_epi32(accum_data_vc, dst_zero_point); + accum_data_vd = _mm512_add_epi32(accum_data_vd, dst_zero_point); + accum_data_ve = _mm512_add_epi32(accum_data_ve, dst_zero_point); + accum_data_vf = _mm512_add_epi32(accum_data_vf, dst_zero_point); + } + } + + const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max); + const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min); + + const bool store_full_block = + (residual_rows == 16) && (residual_cols == 16); + + __m512i accum_data_v[16]; + + // In most cases we would make this conditional on (!store_full_block) and + // unwind the clamp-and-store loop, but the benefit appears small. 
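+      // For exposition: when the block is not full, the store branches below
+      // fall back to masked variants, e.g. _mm_mask_storeu_epi8 writes only
+      // the low residual_rows bytes of the narrowed result, and the column
+      // loops are bounded by residual_cols, so nothing outside the
+      // destination block is written.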
+      {
+        accum_data_v[0] = accum_data_v0;
+        accum_data_v[1] = accum_data_v1;
+        accum_data_v[2] = accum_data_v2;
+        accum_data_v[3] = accum_data_v3;
+        accum_data_v[4] = accum_data_v4;
+        accum_data_v[5] = accum_data_v5;
+        accum_data_v[6] = accum_data_v6;
+        accum_data_v[7] = accum_data_v7;
+        accum_data_v[8] = accum_data_v8;
+        accum_data_v[9] = accum_data_v9;
+        accum_data_v[10] = accum_data_va;
+        accum_data_v[11] = accum_data_vb;
+        accum_data_v[12] = accum_data_vc;
+        accum_data_v[13] = accum_data_vd;
+        accum_data_v[14] = accum_data_ve;
+        accum_data_v[15] = accum_data_vf;
+      }
+
+      if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+        std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+        const int block_col_offset = dst_stride;
+        if (store_full_block) {
+          for (int j = 0; j < 16; ++j) {
+            __m512i result = accum_data_v[j];
+            result = _mm512_min_epi32(result, clamp_max_v);
+            result = _mm512_max_epi32(result, clamp_min_v);
+            _mm_storeu_epi8(tmp_ptr + j * block_col_offset,
+                            _mm512_cvtepi32_epi8(result));
+          }
+        } else {
+          for (int j = 0; j < residual_cols; ++j) {
+            __m512i result = accum_data_v[j];
+            result = _mm512_min_epi32(result, clamp_max_v);
+            result = _mm512_max_epi32(result, clamp_min_v);
+            _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
+                                 _mm512_cvtepi32_epi8(result));
+          }
+        }
+        dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
+      } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+        std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+        const int block_col_offset = dst_stride;
+        if (store_full_block) {
+          for (int j = 0; j < residual_cols; ++j) {
+            __m512i result = accum_data_v[j];
+            result = _mm512_min_epi32(result, clamp_max_v);
+            result = _mm512_max_epi32(result, clamp_min_v);
+            _mm_storeu_epi8(tmp_ptr + j * block_col_offset,
+                            _mm512_cvtepi32_epi8(result));
+          }
+        } else {
+          for (int j = 0; j < residual_cols; ++j) {
+            __m512i result = accum_data_v[j];
+            result = _mm512_min_epi32(result, clamp_max_v);
+            result = _mm512_max_epi32(result, clamp_min_v);
+            _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask,
+                                 _mm512_cvtepi32_epi8(result));
+          }
+        }
+        dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
+      } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+        std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+        const int block_col_offset = dst_stride;
+        if (store_full_block) {
+          for (int j = 0; j < 16; ++j) {
+            __m512i result = accum_data_v[j];
+            result = _mm512_min_epi32(result, clamp_max_v);
+            result = _mm512_max_epi32(result, clamp_min_v);
+            _mm256_storeu_epi16(tmp_ptr + j * block_col_offset,
+                                _mm512_cvtepi32_epi16(result));
+          }
+        } else {
+          for (int j = 0; j < residual_cols; ++j) {
+            __m512i result = accum_data_v[j];
+            result = _mm512_min_epi32(result, clamp_max_v);
+            result = _mm512_max_epi32(result, clamp_min_v);
+            _mm256_mask_storeu_epi16(tmp_ptr + j * block_col_offset, row_mask,
+                                     _mm512_cvtepi32_epi16(result));
+          }
+        }
+        dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
+      } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+        if (store_full_block) {
+          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+          for (int j = 0; j < 16; ++j) {
+            _mm512_storeu_epi32(tmp_ptr + j * dst_stride, accum_data_v[j]);
+          }
+        } else {
+          std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+          for (int j = 0; j < residual_cols; ++j) {
+            _mm512_mask_storeu_epi32(tmp_ptr + j * dst_stride, row_mask,
+                                     accum_data_v[j]);
+          }
+        }
+        dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
+      } else {
+        RUY_DCHECK(false);
+      }
+
+      lhs_col_ptr += 16 * params.lhs_stride;
+    }  // End row-block loop.
+
+    dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
+                                     16 * params.dst_stride);
+    rhs_col_ptr += 16 * params.rhs_stride;
+  }  // End col-block loop.
+}  // NOLINT(readability/fn_size)
+
+void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
+  profiler::ScopeLabel label("Kernel kAvx512 8-bit GEMV");
+
+  RUY_DCHECK_EQ(params.dst_cols, 1);
+  RUY_DCHECK_EQ(params.last_col, 0);
+  RUY_DCHECK_EQ(params.start_col, 0);
+
+  std::int32_t dst_stride;
+  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
+      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
+    dst_stride = params.dst_stride;
+  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+    dst_stride = params.dst_stride / sizeof(std::int16_t);
+  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+    dst_stride = params.dst_stride / sizeof(std::int32_t);
+  } else {
+    RUY_DCHECK(false);
+  }
+
+  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
+
+  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  void* dst_col_ptr = params.dst_base_ptr;
+  const std::int32_t* bias_col_ptr = params.bias;
+  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+    bias_col_ptr += params.start_row;
+  }
+
+  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+  void* dst_ptr = dst_col_ptr;
+  const std::int32_t* bias_ptr = bias_col_ptr;
+
+  const std::int32_t lhs_zero_point = params.lhs_zero_point;
+  const bool has_rhs_sums_offsets =
+      (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+  std::int32_t rhs_sums_offsets[16];
+  if (has_rhs_sums_offsets) {
+    const __m512i rhs_sums_offset_v =
+        _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
+                           _mm512_loadu_epi32(&params.rhs_sums[0]));
+    _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
+                        rhs_sums_offset_v);
+  }
+
+  for (int row = params.start_row; row <= params.last_row; row += 16) {
+    const int residual_rows = std::min(params.dst_rows - row, 16);
+
+    __m512i accum_data_v0;
+
+    // Initialize with bias.
+    const __mmask16 row_mask =
+        (static_cast<std::uint32_t>(1) << residual_rows) - 1;
+    __m512i initial_accum_data = _mm512_maskz_loadu_epi32(row_mask, bias_ptr);
+    bias_ptr += bias_ptr_block_increment;
+
+    const std::int32_t rhs_zero_point = params.rhs_zero_point;
+    if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
+      const __m512i lhs_sums_offset =
+          _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point),
+                             _mm512_loadu_epi32(&params.lhs_sums[row]));
+      initial_accum_data =
+          _mm512_sub_epi32(initial_accum_data, lhs_sums_offset);
+    }
+
+    const std::int32_t prod_zp_depth = params.prod_zp_depth;
+    if (prod_zp_depth != 0) {
+      initial_accum_data = _mm512_add_epi32(initial_accum_data,
+                                            _mm512_set1_epi32(prod_zp_depth));
+    }
+
+    // Adjustments differing across columns.
+    if (has_rhs_sums_offsets) {
+      accum_data_v0 = _mm512_sub_epi32(initial_accum_data,
+                                       _mm512_set1_epi32(rhs_sums_offsets[0]));
+    } else {
+      accum_data_v0 = initial_accum_data;
+    }
+
+    const std::int8_t* lhs_ptr = lhs_col_ptr;
+    const std::int8_t* rhs_ptr = rhs_col_ptr;
+    for (int d = 0; d < params.depth; d += 4) {
+      const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr);
+      const __m128i rhs_data_8bit = _mm_loadu_epi8(rhs_ptr);
+
+      // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
+      // For simplicity we load 4x the data that we need and process twice the
+      // data that we need and store only the data we need.
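+      // For exposition: in this GEMV variant only RHS column 0 is used, so of
+      // the 16 bytes loaded above just the first 4 (that column's depth-4
+      // slice) survive below: rhs_data[0] ends up holding depth elements 0, 1
+      // as a pair of int16s and rhs_data[1] holds elements 2, 3.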
+      std::int32_t rhs_data[2];
+      const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
+      // Now that we have cast the RHS data, we store it so that each value
+      // can be separately loaded in the accumulation loop.
+      _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
+
+      // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
+      const __m512i lhs_16_bit_low =
+          _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
+      // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
+      const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
+          _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));
+
+      // Process column 0.
+      __m512i accum_v = accum_data_v0;
+      constexpr int index = 0;
+
+      const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
+      const __m512i rhs_16_bit_dup_high =
+          _mm512_set1_epi32(rhs_data[index + 1]);
+
+      accum_v = _mm512_add_epi32(
+          accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+      accum_v = _mm512_add_epi32(
+          accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+      accum_data_v0 = accum_v;
+
+      lhs_ptr += 16 * 4;
+      rhs_ptr += 16 * 4;
+    }
+
+    if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+      __m512i m_vector;
+      __m512i e_vector;
+      // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+      if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
+        m_vector = _mm512_maskz_loadu_epi32(row_mask,
+                                            &params.multiplier_fixedpoint[row]);
+        e_vector = _mm512_maskz_loadu_epi32(row_mask,
+                                            &params.multiplier_exponent[row]);
+      } else {
+        // These arrays have size LhsCols, and are pre-filled.
+        m_vector = _mm512_set1_epi32(params.multiplier_fixedpoint[0]);
+        e_vector = _mm512_set1_epi32(params.multiplier_exponent[0]);
+      }
+
+      const __m512i m_64bit_low =
+          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
+      const __m512i m_64bit_high =
+          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));
+
+      const __m512i zero_vector = _mm512_setzero_epi32();
+      const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
+      const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
+      const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
+      const __m512i final_right_shift =
+          _mm512_add_epi32(right_shift, _mm512_set1_epi32(31));
+      const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
+          _mm512_extracti32x8_epi32(final_right_shift, 0));
+      const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
+          _mm512_extracti32x8_epi32(final_right_shift, 1));
+
+      const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30);
+      // Really these should be shifted by neg_e_vector, but tests pass when
+      // using right_shift.
+      const __m512i offset_vector_low = _mm512_sllv_epi64(
+          offset_vector,
+          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0)));
+      const __m512i offset_vector_high = _mm512_sllv_epi64(
+          offset_vector,
+          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)));
+
+      // Shift and round column 0.
+      accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift);
+      // Apply the fixed-point part of the multiplier.
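+      // _mm512_mul_epi32 multiplies only the even 32-bit lanes and produces
+      // 64-bit products, which is why the accumulator is widened into two
+      // 8x64-bit halves (cvtepi32_epi64 of each 256-bit half) and recombined
+      // with cvtepi64_epi32 / inserti32x8 after the shifts.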
+      __m512i scaled_v_low = _mm512_mul_epi32(
+          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 0)),
+          m_64bit_low);
+      __m512i scaled_v_high = _mm512_mul_epi32(
+          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)),
+          m_64bit_high);
+
+      scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+      scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+      scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+      scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+      accum_data_v0 =
+          _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+      accum_data_v0 = _mm512_inserti32x8(
+          accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1);
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+      RUY_DCHECK(false);
+#endif
+
+      if (params.dst_zero_point != 0) {
+        __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
+        accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
+      }
+    }
+
+    const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
+    const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);
+
+    if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+      std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+      __m512i result = accum_data_v0;
+      result = _mm512_min_epi32(result, clamp_max_v);
+      result = _mm512_max_epi32(result, clamp_min_v);
+      _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
+      dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
+    } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+      std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+      __m512i result = accum_data_v0;
+      result = _mm512_min_epi32(result, clamp_max_v);
+      result = _mm512_max_epi32(result, clamp_min_v);
+      _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
+      dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
+    } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+      std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+      __m512i result = accum_data_v0;
+      result = _mm512_min_epi32(result, clamp_max_v);
+      result = _mm512_max_epi32(result, clamp_min_v);
+      _mm256_mask_storeu_epi16(tmp_ptr, row_mask,
+                               _mm512_cvtepi32_epi16(result));
+      dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
+    } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+      std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+      _mm512_mask_storeu_epi32(tmp_ptr, row_mask, accum_data_v0);
+      dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
+    } else {
+      RUY_DCHECK(false);
+    }
+
+    lhs_col_ptr += 16 * params.lhs_stride;
+  }  // End row-block loop.
+}  // NOLINT(readability/fn_size)
+
+void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
+  profiler::ScopeLabel label("Kernel kAvx512 float");
+
+  // As parameters are defined, we need to scale by sizeof(float).
+  const std::int64_t lhs_stride = params.lhs_stride >> 2;
+  const std::int64_t dst_stride = params.dst_stride >> 2;
+  const std::int64_t rhs_stride = params.rhs_stride >> 2;
+
+  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0;
+  const int end_row = std::min(params.dst_rows, params.last_row + 16);
+  const int end_col = std::min(params.dst_cols, params.last_col + 16);
+
+  const float* adj_rhs_col_ptr =
+      params.rhs_base_ptr - params.start_col * rhs_stride;
+  float* adj_dst_col_ptr =
+      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
+  const float* adj_lhs_col_ptr =
+      params.lhs_base_ptr - params.start_row * lhs_stride;
+  const float* bias_col_ptr = params.bias;
+
+  const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
+  const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
+
+  int col = params.start_col;
+  for (; col <= end_col - 16; col += 16) {
+    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+
+    int row = params.start_row;
+    for (; row <= end_row - 16; row += 16) {
+      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+      float* dst_ptr = dst_col_ptr + row;
+      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+      // Initialize with bias.
+      const __m512 initial_accum_data = _mm512_loadu_ps(bias_ptr);
+
+      // Process block in two halves, split by columns.
+      {
+        constexpr int mmm = 0;
+
+        __m512 accum_data_v0 = initial_accum_data;
+        __m512 accum_data_v1 = initial_accum_data;
+        __m512 accum_data_v2 = initial_accum_data;
+        __m512 accum_data_v3 = initial_accum_data;
+        __m512 accum_data_v4 = initial_accum_data;
+        __m512 accum_data_v5 = initial_accum_data;
+        __m512 accum_data_v6 = initial_accum_data;
+        __m512 accum_data_v7 = initial_accum_data;
+
+        const float* lhs_ptr = lhs_col_ptr;
+        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
+        for (int d = 0; d < (params.depth - 1); ++d) {
+          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+          // In this version RHS values are loaded individually rather than
+          // first loading them together and then extracting with broadcasts.
+          // This is because AVX flavours and intrinsics and compilers in
+          // combination do not handle this pattern of extraction very well.
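+          // For exposition: each rhs_data[j] below is broadcast to all 16
+          // lanes with _mm512_set1_ps and combined with the 16 LHS rows in a
+          // single FMA, so one depth step issues 8 independent FMAs, one per
+          // destination column handled by this half-block.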
+ const float* rhs_data = rhs_ptr; + lhs_ptr += 16; + rhs_ptr += 16; + + { + const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + } + { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + { + const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + { + float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; + accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); + _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0); + accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); + _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1); + accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); + _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2); + accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); + _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3); + accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); + _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4); + accum_data_v5 = 
_mm512_min_ps(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); + _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5); + accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); + _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6); + accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); + _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); + } + } + } // Inner half-block loop, unrolled, first iteration. + { + constexpr int mmm = 1; + + __m512 accum_data_v0 = initial_accum_data; + __m512 accum_data_v1 = initial_accum_data; + __m512 accum_data_v2 = initial_accum_data; + __m512 accum_data_v3 = initial_accum_data; + __m512 accum_data_v4 = initial_accum_data; + __m512 accum_data_v5 = initial_accum_data; + __m512 accum_data_v6 = initial_accum_data; + __m512 accum_data_v7 = initial_accum_data; + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr + 8 * mmm; + for (int d = 0; d < (params.depth - 1); ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + lhs_ptr += 16; + rhs_ptr += 16; + { + const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + } + { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + { + const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, 
accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + { + float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; + accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); + _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0); + accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); + _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1); + accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); + _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2); + accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); + _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3); + accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); + _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4); + accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); + _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5); + accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); + _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6); + accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); + _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); + } + } + } // Inner half-block loop, unrolled, second iteration. + } // End row-block loop. + + // The unrolling within this conditional may be somewhat pointless. It + // depends on the kinds of models. + if (row < end_row) { + const int residual_rows = end_row - row; + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + const __mmask16 row_mask = + (static_cast(1) << residual_rows) - 1; + const __m512 initial_accum_data = + _mm512_maskz_loadu_ps(row_mask, bias_ptr); + + // Process block in two halves, split by columns. 
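      // Concrete example of the residual-row mask computed above (values
      // hypothetical): with residual_rows == 5,
      //   row_mask = (1 << 5) - 1 = 0b0000000000011111 = 0x001F,
      // so _mm512_maskz_loadu_ps reads only the first 5 floats and zeroes
      // lanes 5..15, and the masked stores further below never touch memory
      // past the last valid destination row of this block.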
+ for (int mmm = 0; mmm < 2; ++mmm) { + __m512 accum_data_v0 = initial_accum_data; + __m512 accum_data_v1 = initial_accum_data; + __m512 accum_data_v2 = initial_accum_data; + __m512 accum_data_v3 = initial_accum_data; + __m512 accum_data_v4 = initial_accum_data; + __m512 accum_data_v5 = initial_accum_data; + __m512 accum_data_v6 = initial_accum_data; + __m512 accum_data_v7 = initial_accum_data; + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr + 8 * mmm; + for (int d = 0; d < (params.depth - 1); ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + lhs_ptr += 16; + rhs_ptr += 16; + { + const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + } + { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + { + const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + { + float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; + accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask, + accum_data_v0); + accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 1 * 
dst_stride, row_mask, + accum_data_v1); + accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask, + accum_data_v2); + accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask, + accum_data_v3); + accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask, + accum_data_v4); + accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask, + accum_data_v5); + accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask, + accum_data_v6); + accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask, + accum_data_v7); + } + } + } // Inner half-block loop. + } // Residual rows, main col-block loop. + } // End col-block loop. + + if (col < end_col) { + RUY_DCHECK_GE(end_col - col, 0); + RUY_DCHECK_LT(end_col - col, 16); + + __m512 accum_data_v[8]; + + const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; + float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; + + for (int row = params.start_row; row < end_row; row += 16) { + const int residual_rows = std::min(end_row - row, 16); + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + const __mmask16 row_mask = + (static_cast(1) << residual_rows) - 1; + const __m512 initial_accum_data = + _mm512_maskz_loadu_ps(row_mask, bias_ptr); + + // Process block in two halves, split by columns. 
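      // Sketch of the packed-RHS indexing consumed by the half-block ("mmm")
      // loop below (hypothetical helper, not part of this change). Each
      // packed RHS depth step holds 16 column values; half mmm covers
      // destination columns 8*mmm .. 8*mmm+7 of the block (fewer at the right
      // edge, per residual_cols below), so for depth d and column j within
      // the half:
      //
      //   inline float PackedRhsAt(const float* rhs_col_ptr, int d, int mmm,
      //                            int j) {
      //     return rhs_col_ptr[16 * d + 8 * mmm + j];
      //   }
      //
      // Splitting the 16 columns into two halves of 8 presumably keeps the
      // number of live accumulator vectors in each inner loop at 8.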
+ for (int mmm = 0; mmm < 2; ++mmm) { + for (int j = 0; j < 8; ++j) { + accum_data_v[j] = initial_accum_data; + } + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr + 8 * mmm; + for (int d = 0; d < params.depth; ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + + for (int j = 0; j < 8; ++j) { + const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]); + accum_data_v[j] = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); + } + lhs_ptr += 16; + rhs_ptr += 16; + } + + const int residual_cols = std::min(end_col - col - 8 * mmm, 8); + + if (residual_rows == 16) { + if (residual_cols == 8) { + for (int j = 0; j < 8; ++j) { + float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; + accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); + _mm512_storeu_ps(block_ptr, accum_data_v[j]); + } + } else { + for (int j = 0; j < residual_cols; ++j) { + float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; + accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); + _mm512_storeu_ps(block_ptr, accum_data_v[j]); + } + } + } else { + for (int j = 0; j < residual_cols; ++j) { + float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; + accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); + _mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]); + } + } + } // Inner half-block loop. + } // End row-block loop. + } // Residual cols. +} + +void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) { + profiler::ScopeLabel label("Kernel kAvx512 float GEMV"); + + RUY_DCHECK_EQ(params.dst_cols, 1); + RUY_DCHECK_EQ(params.last_col, 0); + RUY_DCHECK_EQ(params.start_col, 0); + + // As parameters are defined, we need to scale by sizeof(float). + const std::int64_t lhs_stride = params.lhs_stride >> 2; + + int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; + const int end_row = std::min(params.dst_rows, params.last_row + 16); + + float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row; + const float* adj_lhs_col_ptr = + params.lhs_base_ptr - params.start_row * lhs_stride; + const float* bias_col_ptr = params.bias; + + const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max); + const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min); + + __m512 accum_data_v; + + const float* rhs_col_ptr = params.rhs_base_ptr; + float* dst_col_ptr = adj_dst_col_ptr; + + int row = params.start_row; + for (; row <= end_row - 16; row += 16) { + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + accum_data_v = _mm512_loadu_ps(bias_ptr); + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float rhs_data = *rhs_ptr; + + const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); + accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); + lhs_ptr += 16; + rhs_ptr += 16; + } + + accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); + accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); + _mm512_storeu_ps(dst_ptr, accum_data_v); + } // End row-block loop. 
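  // Scalar reference for the GEMV row-block loop above, written against the
  // packed layout implied by the pointer arithmetic (a sketch; names
  // hypothetical, not part of this change):
  //
  //   for (int r = 0; r < 16; ++r) {
  //     float acc = bias_ptr[r];  // points at zero data when there is no bias
  //     for (int d = 0; d < params.depth; ++d) {
  //       // Packed LHS stores 16 rows contiguously per depth step; the packed
  //       // RHS block advances by 16 floats per depth step even though GEMV
  //       // only consumes its first lane.
  //       acc += lhs_col_ptr[16 * d + r] * rhs_col_ptr[16 * d];
  //     }
  //     dst_ptr[r] = std::min(params.clamp_max, std::max(params.clamp_min, acc));
  //   }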
+ + if (row < end_row) { + const int residual_rows = end_row - row; + RUY_CHECK_GE(residual_rows, 1); + RUY_CHECK_LT(residual_rows, 16); + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + const __mmask16 row_mask = + (static_cast(1) << residual_rows) - 1; + accum_data_v = _mm512_maskz_loadu_ps(row_mask, bias_ptr); + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float rhs_data = *rhs_ptr; + + const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); + accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); + lhs_ptr += 16; + rhs_ptr += 16; + } + + accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); + accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); + _mm512_mask_storeu_ps(dst_ptr, row_mask, accum_data_v); + } // End handling of residual rows. +} + +#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy diff --git a/ruy/kernel_avxvnni.cc b/ruy/kernel_avxvnni.cc new file mode 100644 index 0000000..4513b20 --- /dev/null +++ b/ruy/kernel_avxvnni.cc @@ -0,0 +1,435 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/kernel.h" +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) +#include // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +static constexpr int kAvxFloatBlockSize = 16; +static constexpr int kAvx8bitBlockSize = 16; +static constexpr int kAvx8bitInnerSize = 4; + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. +void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params) { + profiler::ScopeLabel label("Kernel kAvxVnni 8-bit (UNFINISHED)"); + + std::int32_t accum_data[kAvx8bitBlockSize][kAvx8bitBlockSize]; + + int bias_ptr_block_increment = + params.flags & RUY_ASM_FLAG_HAS_BIAS ? 
kAvx8bitBlockSize : 0; + + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + void* dst_col_ptr = params.dst_base_ptr; + const std::int32_t* bias_col_ptr = params.bias; + if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { + bias_col_ptr += params.start_row; + } + + for (int col = params.start_col; col <= params.last_col; + col += kAvx8bitBlockSize) { + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + void* dst_ptr = dst_col_ptr; + const std::int32_t* bias_ptr = bias_col_ptr; + + for (int row = params.start_row; row <= params.last_row; + row += kAvx8bitBlockSize) { + const int residual_rows = + std::min(params.dst_rows - row, kAvx8bitBlockSize); + const int residual_cols = + std::min(params.dst_cols - col, kAvx8bitBlockSize); + + // Initialize with bias. + std::int32_t initial_accum_data[kAvx8bitBlockSize]; + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + initial_accum_data[i] = 0; + } + for (int i = 0; i < residual_rows; ++i) { + initial_accum_data[i] = bias_ptr[i]; + } + + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] = initial_accum_data[i]; + } + } + bias_ptr += bias_ptr_block_increment; + + std::int8_t lhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; + std::int8_t rhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + for (int x = 0; x < kAvx8bitInnerSize; ++x) { + lhs_data[i][x] = lhs_ptr[i * kAvx8bitInnerSize + x]; + rhs_data[i][x] = rhs_ptr[i * kAvx8bitInnerSize + x]; + } + } + + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + for (int x = 0; x < kAvx8bitInnerSize; ++x) { + accum_data[j][i] += lhs_data[i][x] * rhs_data[j][x]; + } + } + } + + lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + } + + if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] -= + params.rhs_zero_point * params.lhs_sums[row + i]; + } + } + } + if ((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && params.lhs_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] -= + params.lhs_zero_point * params.rhs_sums[col + j]; + } + } + } + if (params.lhs_zero_point && params.rhs_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] += params.prod_zp_depth; + } + } + } + + if (params.dst_type_id != DstTypeId::kValue) { + std::int32_t m_vector[kAvx8bitBlockSize]; + std::int32_t e_vector[kAvx8bitBlockSize]; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. + if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { + int i = 0; + for (; i < residual_rows; ++i) { + m_vector[i] = params.multiplier_fixedpoint[row + i]; + e_vector[i] = params.multiplier_exponent[row + i]; + } + for (; i < kAvx8bitBlockSize; ++i) { + m_vector[i] = m_vector[0]; + e_vector[i] = e_vector[0]; + } + } else { + // These arrays have size LhsCols, and are pre-filled. 
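      // Worked example of the requantization applied further below (numbers
      // hypothetical): a fixed-point multiplier m = 1073741824 (0.5 in Q31)
      // with exponent e = -1 realizes an effective scale of 0.25, so
      //
      //   MultiplyByQuantizedMultiplier(1000, 1073741824, -1) == 250
      //
      // (helper defined in kernel_common.h later in this change). The result
      // then gets dst_zero_point added and is clamped to
      // [clamp_min, clamp_max] before narrowing to the destination type.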
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) { + m_vector[i] = params.multiplier_fixedpoint[i]; + e_vector[i] = params.multiplier_exponent[i]; + } + } + + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] = MultiplyByQuantizedMultiplier( + accum_data[j][i], m_vector[i], e_vector[i]); + } + } + + if (params.dst_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] += params.dst_zero_point; + } + } + } + + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] = + std::min(accum_data[j][i], params.clamp_max); + accum_data[j][i] = + std::max(accum_data[j][i], params.clamp_min); + } + } + } + + const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && + (residual_cols == kAvx8bitBlockSize); + + if (params.dst_type_id == DstTypeId::kValue) { + std::int8_t* tmp_ptr = + store_full_block + ? static_cast(dst_ptr) + : const_cast( + reinterpret_cast(params.dst_tmp_buf)); + const int block_col_offset = + store_full_block ? params.dst_stride / sizeof(std::int8_t) + : kAvx8bitBlockSize; + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + + if (!store_full_block) { + const std::int8_t* block_ptr = + reinterpret_cast(params.dst_tmp_buf); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + static_cast( + dst_ptr)[j * params.dst_stride / sizeof(std::int8_t) + i] = + block_ptr[i]; + } + block_ptr += kAvx8bitBlockSize; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::uint8_t* tmp_ptr = store_full_block + ? static_cast(dst_ptr) + : const_cast( + reinterpret_cast( + params.dst_tmp_buf)); + const int block_col_offset = + store_full_block ? 
params.dst_stride : kAvx8bitBlockSize; + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + + if (!store_full_block) { + const std::uint8_t* block_ptr = + reinterpret_cast(params.dst_tmp_buf); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + static_cast( + dst_ptr)[j * params.dst_stride / sizeof(std::uint8_t) + i] = + block_ptr[i]; + } + block_ptr += kAvx8bitBlockSize; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + if (store_full_block) { + std::int16_t* tmp_ptr = static_cast(dst_ptr); + const int block_col_offset = params.dst_stride / sizeof(std::int16_t); + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + } else { + std::int16_t* tmp_ptr = const_cast( + reinterpret_cast(params.dst_tmp_buf)); + const int block_col_offset = kAvx8bitBlockSize; + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + const std::int16_t* block_ptr = + reinterpret_cast(params.dst_tmp_buf); + std::int16_t* dst_block_ptr = static_cast(dst_ptr); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + dst_block_ptr[i] = block_ptr[i]; + } + dst_block_ptr += params.dst_stride / sizeof(std::int16_t); + block_ptr += kAvx8bitBlockSize; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + if (store_full_block) { + std::int32_t* tmp_ptr = static_cast(dst_ptr); + const int block_col_offset = params.dst_stride / sizeof(std::int32_t); + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + } else { + std::int32_t* dst_block_ptr = static_cast(dst_ptr); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + dst_block_ptr[i] = accum_data[j][i]; + } + dst_block_ptr += params.dst_stride / sizeof(std::int32_t); + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; + } // End row-block loop. + + dst_col_ptr = static_cast(static_cast(dst_col_ptr) + + kAvx8bitBlockSize * params.dst_stride); + rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; + } // End col-block loop. +} // NOLINT(readability/fn_size) + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. +void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params) { + profiler::ScopeLabel label("Kernel kAvxVnni float (UNFINISHED)"); + + float lhs_data[kAvxFloatBlockSize]; + float rhs_data[kAvxFloatBlockSize]; + float accum_data[kAvxFloatBlockSize][kAvxFloatBlockSize]; + int bias_ptr_block_increment = + params.flags & RUY_ASM_FLAG_HAS_BIAS ? 
kAvxFloatBlockSize : 0; + + const float* rhs_col_ptr = params.rhs_base_ptr; + float* dst_col_ptr = params.dst_base_ptr; + const float* bias_col_ptr = params.bias; + if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { + bias_col_ptr += params.start_row; + } + + for (int col = params.start_col; col <= params.last_col; + col += kAvxFloatBlockSize) { + const float* lhs_col_ptr = params.lhs_base_ptr; + float* dst_ptr = dst_col_ptr; + const float* bias_ptr = bias_col_ptr; + + for (int row = params.start_row; row <= params.last_row; + row += kAvxFloatBlockSize) { + const int residual_rows = + std::min(params.dst_rows - row, kAvxFloatBlockSize); + const int residual_cols = + std::min(params.dst_cols - col, kAvxFloatBlockSize); + + // Initialize with bias. + float initial_accum_data[kAvxFloatBlockSize]; + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + initial_accum_data[i] = 0.0f; + } + for (int i = 0; i < residual_rows; ++i) { + initial_accum_data[i] = bias_ptr[i]; + } + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + accum_data[j][i] = initial_accum_data[i]; + } + } + bias_ptr += bias_ptr_block_increment; + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + lhs_data[i] = lhs_ptr[i]; + rhs_data[i] = rhs_ptr[i]; + } + + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + accum_data[j][i] += lhs_data[i] * rhs_data[j]; + } + } + + lhs_ptr += kAvxFloatBlockSize; + rhs_ptr += kAvxFloatBlockSize; + } + + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + accum_data[j][i] = + std::min(accum_data[j][i], params.clamp_max); + accum_data[j][i] = + std::max(accum_data[j][i], params.clamp_min); + } + } + + const bool store_full_block = (residual_rows == kAvxFloatBlockSize) && + (residual_cols == kAvxFloatBlockSize); + + { + float* block_ptr = + store_full_block ? dst_ptr : const_cast(params.dst_tmp_buf); + const int block_col_offset = store_full_block + ? params.dst_stride / sizeof(float) + : kAvxFloatBlockSize; + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + block_ptr[i] = accum_data[j][i]; + } + block_ptr += block_col_offset; + } + } + if (!store_full_block) { + const float* block_ptr = params.dst_tmp_buf; + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + dst_ptr[j * params.dst_stride / sizeof(float) + i] = block_ptr[i]; + } + block_ptr += kAvxFloatBlockSize; + } + } + + lhs_col_ptr += kAvxFloatBlockSize * params.lhs_stride / sizeof(float); + dst_ptr += kAvxFloatBlockSize; + } // End row-block loop. + + dst_col_ptr += kAvxFloatBlockSize * params.dst_stride / sizeof(float); + rhs_col_ptr += kAvxFloatBlockSize * params.rhs_stride / sizeof(float); + } // End col-block loop. +} + +#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy diff --git a/ruy/kernel_common.h b/ruy/kernel_common.h new file mode 100644 index 0000000..0cd123f --- /dev/null +++ b/ruy/kernel_common.h @@ -0,0 +1,481 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_ + +#include +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/common.h" +#include "ruy/internal_matrix.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/side_pair.h" +#include "ruy/size_util.h" +#include "ruy/spec.h" +#include "ruy/tune.h" + +namespace ruy { + +template +struct Kernel {}; + +template +void RunKernelTyped(Tuning tuning, const PackedMatrix& lhs, + const PackedMatrix& rhs, const Spec& spec, + int start_row, int start_col, int end_row, int end_col, + Matrix* dst) { + using Kernel = Kernel; + Kernel kernel(tuning); +#if !defined(NDEBUG) || !RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL) + using LhsLayout = typename Kernel::LhsLayout; + using RhsLayout = typename Kernel::RhsLayout; +#endif + // end_row and end_col may be larger than dst dimensions. + // that is because kernels write directly to the destination matrix, whose + // dimensions may not be a multiple of the kernel dimensions, and we try to + // keep this annoyance localized as an implementation detail in kernels, + // by allowing to pass rounded-up values down as far as possible. + // These assertions encode the contract. + RUY_DCHECK_LE(0, start_row); + RUY_DCHECK_LE(start_row, end_row); + RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols); + RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0); + RUY_DCHECK_LE(0, start_col); + RUY_DCHECK_LE(start_col, end_col); + RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols); + RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0); +#if RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL) + kernel.Run(lhs, rhs, spec, start_row, start_col, end_row, end_col, dst); +#else + for (int col = start_col; col < end_col; col += RhsLayout::kCols) { + int block_end_col = std::min(col + RhsLayout::kCols, end_col); + for (int row = start_row; row < end_row; row += LhsLayout::kCols) { + int block_end_row = std::min(row + LhsLayout::kCols, end_row); + kernel.Run(lhs, rhs, spec, row, col, block_end_row, block_end_col, dst); + } + } +#endif +} + +// Main entry point for kernels. +template +void RunKernel(Tuning tuning, const SidePair& src, void* spec, + const SidePair& start, const SidePair& end, + DMatrix* dst) { + Matrix mdst = ToMatrix(*dst); + RunKernelTyped( + tuning, ToPackedMatrix(src[Side::kLhs]), + ToPackedMatrix(src[Side::kRhs]), + *static_cast(spec), start[Side::kLhs], start[Side::kRhs], + end[Side::kLhs], end[Side::kRhs], &mdst); +} + +// Copied from gemmlowp/fixedpoint. +inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, + std::int32_t b) { + bool overflow = a == b && a == std::numeric_limits::min(); + std::int64_t a_64(a); + std::int64_t b_64(b); + std::int64_t ab_64 = a_64 * b_64; + std::int32_t nudge = ab_64 >= 0 ? 
(1 << 30) : (1 - (1 << 30)); + std::int32_t ab_x2_high32 = + static_cast((ab_64 + nudge) / (1ll << 31)); + return overflow ? std::numeric_limits::max() : ab_x2_high32; +} + +inline std::int32_t RoundingDivideByPOT(std::int32_t numerator, int exponent) { + std::int32_t sign = numerator >= 0 ? 1 : -1; + std::int32_t abs_numerator = std::abs(numerator); + std::int32_t mask = (1LL << exponent) - 1; + std::int32_t remainder = abs_numerator & mask; + std::int32_t threshold = mask >> 1; + std::int32_t abs_result = + (abs_numerator >> exponent) + (remainder > threshold ? 1 : 0); + return sign * abs_result; +} + +// Copied from TF Lite code. +inline std::int32_t MultiplyByQuantizedMultiplier( + std::int32_t x, std::int32_t quantized_multiplier, int shift) { + int left_shift = shift > 0 ? shift : 0; + int right_shift = shift > 0 ? 0 : -shift; + return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( + x * (1 << left_shift), quantized_multiplier), + right_shift); +} + +// Helper to apply a fixed-point multiplier. Only 'applicable' if AccumScalar +// is int32 (i.e. in all cases except floating-point) and if the destination is +// not int32 (i.e. unless the user wants to get raw accumulators). +template ::value && + !std::is_same::value> +struct ApplyMultiplierImpl {}; + +// Specialization in non-applicable case: do nothing, just check that values +// are default. +template +struct ApplyMultiplierImpl { + using AccumScalar = typename Spec::AccumScalar; + using DstScalar = typename Spec::DstScalar; + static void Run(const Spec& spec, int row, AccumScalar* accum) { + RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0); + RUY_DCHECK_EQ(spec.multiplier_exponent, 0); + } +}; + +template +struct ApplyMultiplierImpl { + using AccumScalar = typename Spec::AccumScalar; + using DstScalar = typename Spec::DstScalar; + static void Run(const Spec& spec, int row, AccumScalar* accum) { + AccumScalar m = spec.multiplier_fixedpoint_perchannel + ? spec.multiplier_fixedpoint_perchannel[row] + : spec.multiplier_fixedpoint; + int e = spec.multiplier_exponent_perchannel + ? spec.multiplier_exponent_perchannel[row] + : spec.multiplier_exponent; + *accum = MultiplyByQuantizedMultiplier(*accum, m, e); + } +}; + +template +void ApplyMultiplier(const Spec& spec, int row, + typename Spec::AccumScalar* accum) { + ApplyMultiplierImpl::Run(spec, row, accum); +} + +template +struct Kernel { + using AccumScalar = typename Spec::AccumScalar; + using LhsLayout = typename Spec::StandardCppKernelLhsLayout; + using RhsLayout = typename Spec::StandardCppKernelRhsLayout; + explicit Kernel(Tuning) {} + void Run(const PackedMatrix& lhs, + const PackedMatrix& rhs, const Spec& spec, int start_row, + int start_col, int end_row, int end_col, + Matrix* dst) const { + // See the comment in RunKernelTyped. end_row may be larger than + // dst->layout.rows. It's the responsibility of the kernel to avoid + // overrunning dst boundaries, which we do here by computing + // clamped_end_row. 
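    // For example (hypothetical sizes): with dst->layout.rows == 10 and
    // LhsLayout::kCols == 8, the caller may pass end_row == 16, rounded up to
    // the kernel block size; clamped_end_row is then min(16, 10) == 10 and
    // the loops below write only the 10 valid destination rows.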
+ int clamped_end_row = std::min(end_row, dst->layout.rows); + int clamped_end_col = std::min(end_col, dst->layout.cols); + RUY_DCHECK_LE(0, start_row); + RUY_DCHECK_LE(start_row, clamped_end_row); + RUY_DCHECK_LE(clamped_end_row, dst->layout.rows); + RUY_DCHECK_LE(clamped_end_row, end_row); + RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols); + RUY_DCHECK_LE(0, start_col); + RUY_DCHECK_LE(start_col, clamped_end_col); + RUY_DCHECK_LE(clamped_end_col, dst->layout.cols); + RUY_DCHECK_LE(clamped_end_col, end_col); + RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols); + profiler::ScopeLabel label("Kernel (Standard Cpp)"); + const int depth = lhs.layout.rows; + for (int i = start_row; i < clamped_end_row; i++) { + for (int j = start_col; j < clamped_end_col; j++) { + using AccumScalar = typename Spec::AccumScalar; + AccumScalar accum = 0; + for (int k = 0; k < depth; k++) { + AccumScalar lhs_val = Element(lhs, k, i); + AccumScalar rhs_val = Element(rhs, k, j); + accum += lhs_val * rhs_val; + } + if (spec.bias) { + accum += spec.bias[i]; + } + if (lhs.zero_point) { + accum -= lhs.zero_point * rhs.sums[j]; + } + if (rhs.zero_point) { + accum -= rhs.zero_point * lhs.sums[i]; + } + if (lhs.zero_point && rhs.zero_point) { + accum += lhs.zero_point * rhs.zero_point * depth; + } + ApplyMultiplier(spec, i, &accum); + accum += dst->zero_point; + accum = std::min(accum, spec.clamp_max); + accum = std::max(accum, spec.clamp_min); + *ElementPtr(dst, i, j) = static_cast(accum); + } + } + } +}; + +#define RUY_INHERIT_KERNEL(PARENT, CHILD) \ + template \ + struct Kernel \ + : Kernel { \ + explicit Kernel(Tuning tuning) \ + : Kernel(tuning) {} \ + }; + +#if RUY_PLATFORM(NEON) +RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon) +RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod) +#elif RUY_PLATFORM(X86) +RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kSse42) +RUY_INHERIT_KERNEL(Path::kSse42, Path::kAvx2) +RUY_INHERIT_KERNEL(Path::kAvx2, Path::kAvx512) +RUY_INHERIT_KERNEL(Path::kAvx512, Path::kAvxVnni) +#endif + +// KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code. +// +// In other cases, we still define (empty) versions, so that dummy kernels +// can use the classes in function signatures. 
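// The zero-point corrections applied by the standard C++ kernel above, and by
// the x86 kernels earlier in this change, follow from expanding the quantized
// product (zp_l / zp_r denote the LHS / RHS zero points):
//
//   sum_k (lhs[i][k] - zp_l) * (rhs[k][j] - zp_r)
//     = sum_k lhs[i][k] * rhs[k][j]
//       - zp_r * sum_k lhs[i][k]     // rhs_zero_point * lhs_sums[i]
//       - zp_l * sum_k rhs[k][j]     // lhs_zero_point * rhs_sums[j]
//       + zp_l * zp_r * depth        // prod_zp_depth
//
// which is what the lhs_sums / rhs_sums / prod_zp_depth fields of
// KernelParams8bit below carry.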
+#if ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \ + RUY_OPT_ENABLED(RUY_OPT_ASM)) || \ + RUY_PLATFORM(X86) + +#define RUY_ASM_FLAG_HAS_BIAS 0x1 +#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2 +#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4 +#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8 +#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10 + +#define RUY_ASM_TYPE_ID_UINT8 1 +#define RUY_ASM_TYPE_ID_INT8 2 +#define RUY_ASM_TYPE_ID_INT16 3 +#define RUY_ASM_TYPE_ID_INT32 4 + +template +struct DstTypeId {}; + +template <> +struct DstTypeId { + static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8; +}; + +template <> +struct DstTypeId { + static constexpr int kValue = RUY_ASM_TYPE_ID_INT8; +}; + +template <> +struct DstTypeId { + static constexpr int kValue = RUY_ASM_TYPE_ID_INT16; +}; + +template <> +struct DstTypeId { + static constexpr int kValue = RUY_ASM_TYPE_ID_INT32; +}; + +template +struct KernelParams8bit { + static constexpr int kMaxDstTypeSize = 4; + + const std::int32_t* bias; + const std::int32_t* lhs_sums; + const std::int32_t* rhs_sums; + const std::int8_t* lhs_base_ptr; + const std::int32_t* multiplier_fixedpoint; + const std::int32_t* multiplier_exponent; + const std::int8_t* rhs_base_ptr; + void* dst_base_ptr; + std::int32_t lhs_zero_point; + std::int32_t rhs_zero_point; + std::int32_t dst_zero_point; + std::int32_t prod_zp_depth; + std::int32_t start_row; + std::int32_t start_col; + std::int32_t last_row; + std::int32_t last_col; + std::int32_t dst_rows; + std::int32_t dst_cols; + std::int32_t lhs_stride; + std::int32_t rhs_stride; + std::int32_t dst_stride; + std::int32_t depth; + std::int32_t clamp_min; + std::int32_t clamp_max; + std::uint8_t flags; + std::uint8_t dst_type_id; + const std::int32_t zero_data[LhsCols] = {0}; + std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize]; + std::int32_t multiplier_fixedpoint_buf[LhsCols]; + std::int32_t multiplier_exponent_buf[LhsCols]; +}; + +template +void MakeKernelParams8bit(const PackedMatrix& lhs, + const PackedMatrix& rhs, + const BasicSpec& spec, + int start_row, int start_col, int end_row, + int end_col, Matrix* dst, + KernelParams8bit* params) { + using Params = KernelParams8bit; + + static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, ""); + + const int depth = lhs.layout.rows; + RUY_DCHECK_EQ(start_row % LhsCols, 0); + RUY_DCHECK_EQ(start_col % RhsCols, 0); + RUY_DCHECK_EQ(end_row % LhsCols, 0); + RUY_DCHECK_EQ(end_col % RhsCols, 0); + + params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; + params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; + params->flags = 0; + params->bias = params->zero_data; + if (spec.bias) { + params->bias = spec.bias; + params->flags |= RUY_ASM_FLAG_HAS_BIAS; + } + if (lhs.sums) { + params->lhs_sums = lhs.sums; + params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS; + } + if (rhs.sums) { + params->rhs_sums = rhs.sums; + params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS; + } + params->start_row = start_row; + params->start_col = start_col; + params->last_row = end_row - LhsCols; + params->last_col = end_col - RhsCols; + params->lhs_stride = lhs.layout.stride; + params->rhs_stride = rhs.layout.stride; + params->dst_stride = sizeof(DstScalar) * dst->layout.stride; + params->lhs_zero_point = lhs.zero_point; + params->rhs_zero_point = rhs.zero_point; + params->dst_zero_point = dst->zero_point; + params->depth = depth; + params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth; + if (spec.multiplier_fixedpoint_perchannel) { + params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT; + params->flags |= 
RUY_ASM_FLAG_HAS_PERCHANNEL; + params->multiplier_fixedpoint = spec.multiplier_fixedpoint_perchannel; + params->multiplier_exponent = spec.multiplier_exponent_perchannel; + } else { + if (spec.multiplier_exponent > 0) { + params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT; + } + params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf; + params->multiplier_exponent = params->multiplier_exponent_buf; + for (int i = 0; i < LhsCols; i++) { + params->multiplier_fixedpoint_buf[i] = spec.multiplier_fixedpoint; + params->multiplier_exponent_buf[i] = spec.multiplier_exponent; + } + } + params->clamp_min = spec.clamp_min; + params->clamp_max = spec.clamp_max; + params->dst_rows = dst->layout.rows; + params->dst_cols = dst->layout.cols; + + RUY_DCHECK_LT(params->last_row, params->dst_rows); + RUY_DCHECK_LT(params->last_col, params->dst_cols); + + params->dst_type_id = DstTypeId::kValue; + params->dst_base_ptr = + dst->data.get() + start_col * dst->layout.stride + start_row; +} + +template +struct KernelParamsFloat { + const float* lhs_base_ptr; + const float* rhs_base_ptr; + float* dst_base_ptr; + const float* bias; + std::int32_t start_row; + std::int32_t start_col; + std::int32_t last_row; + std::int32_t last_col; + std::int32_t dst_rows; + std::int32_t dst_cols; + std::int32_t lhs_stride; + std::int32_t rhs_stride; + std::int32_t dst_stride; + std::int32_t depth; + float clamp_min; + float clamp_max; + std::uint8_t flags; + const float zero_data[LhsCols] = {0}; + float dst_tmp_buf[LhsCols * RhsCols]; +}; + +template +inline void MakeKernelParamsFloat(const PackedMatrix& lhs, + const PackedMatrix& rhs, + const BasicSpec& spec, + int start_row, int start_col, int end_row, + int end_col, Matrix* dst, + KernelParamsFloat* params) { + const int depth = lhs.layout.rows; + RUY_DCHECK_EQ(start_row % LhsCols, 0); + RUY_DCHECK_EQ(start_col % RhsCols, 0); + RUY_DCHECK_EQ(end_row % LhsCols, 0); + RUY_DCHECK_EQ(end_col % RhsCols, 0); + + params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; + params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; + params->dst_base_ptr = + dst->data.get() + start_col * dst->layout.stride + start_row; + + std::uint8_t flags = 0; + params->bias = params->zero_data; + if (spec.bias) { + params->bias = spec.bias; + flags |= RUY_ASM_FLAG_HAS_BIAS; + } + params->flags = flags; + params->start_row = start_row; + params->start_col = start_col; + params->last_row = end_row - LhsCols; + params->last_col = end_col - RhsCols; + params->lhs_stride = sizeof(float) * lhs.layout.stride; + params->rhs_stride = sizeof(float) * rhs.layout.stride; + params->dst_stride = sizeof(float) * dst->layout.stride; + params->depth = depth; + params->clamp_min = spec.clamp_min; + params->clamp_max = spec.clamp_max; + params->dst_rows = dst->layout.rows; + params->dst_cols = dst->layout.cols; + + RUY_DCHECK_LT(params->last_row, params->dst_rows); + RUY_DCHECK_LT(params->last_col, params->dst_cols); +} + +#else // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && + // RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86) + +template +struct KernelParams8bit {}; + +template +struct KernelParamsFloat {}; + +#endif // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && + // RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86) + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_ diff --git a/ruy/kernel_sse42.cc b/ruy/kernel_sse42.cc new file mode 100644 index 0000000..747ca1c --- /dev/null +++ b/ruy/kernel_sse42.cc @@ -0,0 +1,428 @@ +/* Copyright 2019 
Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/kernel.h" +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) +#include // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +void Kernel8bitSse42(const KernelParams8bit<8, 8>& params) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void KernelFloatSse42(const KernelParamsFloat<8, 8>& params) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +static constexpr int kAvxFloatBlockSize = 8; +static constexpr int kAvx8bitBlockSize = 8; +static constexpr int kAvx8bitInnerSize = 4; + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. +void Kernel8bitSse42(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label("Kernel kSse42 8-bit (UNFINISHED)"); + std::int32_t accum_data[kAvx8bitBlockSize][kAvx8bitBlockSize]; + + int bias_ptr_block_increment = + params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; + + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + void* dst_col_ptr = params.dst_base_ptr; + const std::int32_t* bias_col_ptr = params.bias; + if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { + bias_col_ptr += params.start_row; + } + + for (int col = params.start_col; col <= params.last_col; + col += kAvx8bitBlockSize) { + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + void* dst_ptr = dst_col_ptr; + const std::int32_t* bias_ptr = bias_col_ptr; + + for (int row = params.start_row; row <= params.last_row; + row += kAvx8bitBlockSize) { + const int residual_rows = + std::min(params.dst_rows - row, kAvx8bitBlockSize); + const int residual_cols = + std::min(params.dst_cols - col, kAvx8bitBlockSize); + + // Initialize with bias. 
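      // Sketch of the packed-operand indexing consumed by the depth loop
      // further below, inferred from the pointer arithmetic (hypothetical
      // helper, not part of this change): the packed LHS stores an
      // 8-row x 4-depth micro-panel contiguously, so element (block row i,
      // depth d) lives at
      //
      //   inline std::int8_t PackedLhsAt(const std::int8_t* lhs_col_ptr,
      //                                  int i, int d) {
      //     return lhs_col_ptr[(d / 4) * (8 * 4) + i * 4 + (d % 4)];
      //   }
      //
      // which is why lhs_ptr and rhs_ptr advance by 8 * 4 bytes per group of
      // 4 depth levels.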
+ std::int32_t initial_accum_data[kAvx8bitBlockSize]; + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + initial_accum_data[i] = 0; + } + for (int i = 0; i < residual_rows; ++i) { + initial_accum_data[i] = bias_ptr[i]; + } + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] = initial_accum_data[i]; + } + } + bias_ptr += bias_ptr_block_increment; + + std::int8_t lhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; + std::int8_t rhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + for (int x = 0; x < kAvx8bitInnerSize; ++x) { + lhs_data[i][x] = lhs_ptr[i * kAvx8bitInnerSize + x]; + rhs_data[i][x] = rhs_ptr[i * kAvx8bitInnerSize + x]; + } + } + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + for (int x = 0; x < kAvx8bitInnerSize; ++x) { + accum_data[j][i] += lhs_data[i][x] * rhs_data[j][x]; + } + } + } + lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + } + + if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] -= + params.rhs_zero_point * params.lhs_sums[row + i]; + } + } + } + if ((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && params.lhs_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] -= + params.lhs_zero_point * params.rhs_sums[col + j]; + } + } + } + if (params.lhs_zero_point && params.rhs_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] += params.prod_zp_depth; + } + } + } + + if (params.dst_type_id != DstTypeId::kValue) { + std::int32_t m_vector[kAvx8bitBlockSize]; + std::int32_t e_vector[kAvx8bitBlockSize]; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. + if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { + int i = 0; + for (; i < residual_rows; ++i) { + m_vector[i] = params.multiplier_fixedpoint[row + i]; + e_vector[i] = params.multiplier_exponent[row + i]; + } + for (; i < kAvx8bitBlockSize; ++i) { + m_vector[i] = m_vector[0]; + e_vector[i] = e_vector[0]; + } + } else { + // These arrays have size LhsCols, and are pre-filled. + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + m_vector[i] = params.multiplier_fixedpoint[i]; + e_vector[i] = params.multiplier_exponent[i]; + } + } + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] = MultiplyByQuantizedMultiplier( + accum_data[j][i], m_vector[i], e_vector[i]); + } + } + + if (params.dst_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] += params.dst_zero_point; + } + } + } + + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] = + std::min(accum_data[j][i], params.clamp_max); + accum_data[j][i] = + std::max(accum_data[j][i], params.clamp_min); + } + } + } + + const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && + (residual_cols == kAvx8bitBlockSize); + + if (params.dst_type_id == DstTypeId::kValue) { + std::int8_t* tmp_ptr = + store_full_block + ? 
static_cast(dst_ptr) + : const_cast( + reinterpret_cast(params.dst_tmp_buf)); + const int block_col_offset = + store_full_block ? params.dst_stride / sizeof(std::int8_t) + : kAvx8bitBlockSize; + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + + if (!store_full_block) { + const std::int8_t* block_ptr = + reinterpret_cast(params.dst_tmp_buf); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + static_cast( + dst_ptr)[j * params.dst_stride / sizeof(std::int8_t) + i] = + block_ptr[i]; + } + block_ptr += kAvx8bitBlockSize; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::uint8_t* tmp_ptr = store_full_block + ? static_cast(dst_ptr) + : const_cast( + reinterpret_cast( + params.dst_tmp_buf)); + const int block_col_offset = + store_full_block ? params.dst_stride : kAvx8bitBlockSize; + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + + if (!store_full_block) { + const std::uint8_t* block_ptr = + reinterpret_cast(params.dst_tmp_buf); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + static_cast( + dst_ptr)[j * params.dst_stride / sizeof(std::uint8_t) + i] = + block_ptr[i]; + } + block_ptr += kAvx8bitBlockSize; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + if (store_full_block) { + std::int16_t* tmp_ptr = static_cast(dst_ptr); + const int block_col_offset = params.dst_stride / sizeof(std::int16_t); + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + } else { + std::int16_t* tmp_ptr = const_cast( + reinterpret_cast(params.dst_tmp_buf)); + const int block_col_offset = kAvx8bitBlockSize; + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + const std::int16_t* block_ptr = + reinterpret_cast(params.dst_tmp_buf); + std::int16_t* dst_block_ptr = static_cast(dst_ptr); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + dst_block_ptr[i] = block_ptr[i]; + } + dst_block_ptr += params.dst_stride / sizeof(std::int16_t); + block_ptr += kAvx8bitBlockSize; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + if (store_full_block) { + std::int32_t* tmp_ptr = static_cast(dst_ptr); + const int block_col_offset = params.dst_stride / sizeof(std::int32_t); + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + } else { + std::int32_t* dst_block_ptr = static_cast(dst_ptr); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + dst_block_ptr[i] = accum_data[j][i]; + } + dst_block_ptr += params.dst_stride / sizeof(std::int32_t); + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; + } // End row-block loop. 
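      // The store-back pattern used inside the row-block loop above,
      // summarized (sketch; names hypothetical): full 8x8 blocks are written
      // straight to the destination, while partial edge blocks are first
      // written to the fixed-size params.dst_tmp_buf and only the valid
      // residual_rows x residual_cols sub-block is copied out, e.g.
      //
      //   if (!store_full_block) {
      //     for (int j = 0; j < residual_cols; ++j) {
      //       for (int i = 0; i < residual_rows; ++i) {
      //         dst[j * dst_stride_elems + i] = tmp_buf[j * 8 + i];
      //       }
      //     }
      //   }
      //
      // trading an extra copy at the right/bottom edges for a simpler
      // full-block store path.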
+ + dst_col_ptr = static_cast(static_cast(dst_col_ptr) + + kAvx8bitBlockSize * params.dst_stride); + rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; + } // End col-block loop. +} // NOLINT(readability/fn_size) + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. +void KernelFloatSse42(const KernelParamsFloat<8, 8>& params) { + profiler::ScopeLabel label("Kernel kSse42 float (UNFINISHED)"); + + float lhs_data[kAvxFloatBlockSize]; + float rhs_data[kAvxFloatBlockSize]; + float accum_data[kAvxFloatBlockSize][kAvxFloatBlockSize]; + int bias_ptr_block_increment = + params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvxFloatBlockSize : 0; + + const float* rhs_col_ptr = params.rhs_base_ptr; + float* dst_col_ptr = params.dst_base_ptr; + const float* bias_col_ptr = params.bias; + if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { + bias_col_ptr += params.start_row; + } + + for (int col = params.start_col; col <= params.last_col; + col += kAvxFloatBlockSize) { + const float* lhs_col_ptr = params.lhs_base_ptr; + float* dst_ptr = dst_col_ptr; + const float* bias_ptr = bias_col_ptr; + + for (int row = params.start_row; row <= params.last_row; + row += kAvxFloatBlockSize) { + const int residual_rows = + std::min(params.dst_rows - row, kAvxFloatBlockSize); + const int residual_cols = + std::min(params.dst_cols - col, kAvxFloatBlockSize); + + // Initialize with bias. + float initial_accum_data[kAvxFloatBlockSize]; + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + initial_accum_data[i] = 0.0f; + } + for (int i = 0; i < residual_rows; ++i) { + initial_accum_data[i] = bias_ptr[i]; + } + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + accum_data[j][i] = initial_accum_data[i]; + } + } + bias_ptr += bias_ptr_block_increment; + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + lhs_data[i] = lhs_ptr[i]; + rhs_data[i] = rhs_ptr[i]; + } + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + accum_data[j][i] += lhs_data[i] * rhs_data[j]; + } + } + lhs_ptr += kAvxFloatBlockSize; + rhs_ptr += kAvxFloatBlockSize; + } + + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + accum_data[j][i] = + std::min(accum_data[j][i], params.clamp_max); + accum_data[j][i] = + std::max(accum_data[j][i], params.clamp_min); + } + } + + const bool store_full_block = (residual_rows == kAvxFloatBlockSize) && + (residual_cols == kAvxFloatBlockSize); + + { + float* block_ptr = + store_full_block ? dst_ptr : const_cast(params.dst_tmp_buf); + const int block_col_offset = store_full_block + ? 
params.dst_stride / sizeof(float) + : kAvxFloatBlockSize; + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + block_ptr[i] = accum_data[j][i]; + } + block_ptr += block_col_offset; + } + } + if (!store_full_block) { + const float* block_ptr = params.dst_tmp_buf; + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + dst_ptr[j * params.dst_stride / sizeof(float) + i] = block_ptr[i]; + } + block_ptr += kAvxFloatBlockSize; + } + } + + lhs_col_ptr += kAvxFloatBlockSize * params.lhs_stride / sizeof(float); + dst_ptr += kAvxFloatBlockSize; + } // End row-block loop. + + dst_col_ptr += kAvxFloatBlockSize * params.dst_stride / sizeof(float); + rhs_col_ptr += kAvxFloatBlockSize * params.rhs_stride / sizeof(float); + } // End col-block loop. +} + +#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h new file mode 100644 index 0000000..dbcf42b --- /dev/null +++ b/ruy/kernel_x86.h @@ -0,0 +1,222 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_ + +#include + +#include "ruy/common.h" +#include "ruy/internal_matrix.h" +#include "ruy/kernel_common.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/spec.h" +#include "ruy/tune.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. 
+// +void Kernel8bitSse42(const KernelParams8bit<8, 8>& params); + +template +struct Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, + const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, + int start_col, int end_row, int end_col, + Matrix* dst) const { + KernelParams8bit params; + MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, + dst, ¶ms); + Kernel8bitSse42(params); + } +}; + +void KernelFloatSse42(const KernelParamsFloat<8, 8>& params); + +template <> +struct Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, int start_col, + int end_row, int end_col, Matrix* dst) const { + KernelParamsFloat params; + MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, + end_col, dst, ¶ms); + KernelFloatSse42(params); + } +}; + +void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params); +void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params); + +template +struct Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, + const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, + int start_col, int end_row, int end_col, + Matrix* dst) const { + KernelParams8bit params; + MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, + dst, ¶ms); + if (dst->layout.cols == 1) { + Kernel8bitAvx512SingleCol(params); + } else { + Kernel8bitAvx512(params); + } + } +}; + +void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params); +void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& param); + +template <> +struct Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, int start_col, + int end_row, int end_col, Matrix* dst) const { + KernelParamsFloat params; + MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (dst->layout.cols == 1) { + KernelFloatAvx512SingleCol(params); + } else { + KernelFloatAvx512(params); + } + } +}; + +void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params); +void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params); + +template +struct Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, + const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, + int start_col, int end_row, int end_col, + Matrix* dst) const { + KernelParams8bit params; + MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, + dst, ¶ms); + if (dst->layout.cols == 1) { + Kernel8bitAvx2SingleCol(params); + } else { + Kernel8bitAvx2(params); + } + } +}; + +void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params); +void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params); + +template <> +struct 
Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, int start_col, + int end_row, int end_col, Matrix* dst) const { + KernelParamsFloat params; + MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (dst->layout.cols == 1) { + KernelFloatAvx2SingleCol(params); + } else { + KernelFloatAvx2(params); + } + } +}; + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params); + +template +struct Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, + const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, + int start_col, int end_row, int end_col, + Matrix* dst) const { + KernelParams8bit params; + MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, + dst, ¶ms); + Kernel8bitAvxVnni(params); + } +}; + +void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params); + +template <> +struct Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, int start_col, + int end_row, int end_col, Matrix* dst) const { + KernelParamsFloat params; + MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, + end_col, dst, ¶ms); + KernelFloatAvxVnni(params); + } +}; + +#endif // RUY_PLATFORM(X86) + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_ diff --git a/ruy/matrix.h b/ruy/matrix.h new file mode 100644 index 0000000..2dcb081 --- /dev/null +++ b/ruy/matrix.h @@ -0,0 +1,182 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_ + +#include +#include // IWYU pragma: keep +#include + +#include "ruy/check_macros.h" + +namespace ruy { + +// Layout storage order. Here and elsewhere, 'col' is short for 'column'. +// 'column-major' means that each column is contiguous in memory. +enum class Order : std::uint8_t { kColMajor, kRowMajor }; + +// Describes the shape and storage layout of a matrix. +struct Layout final { + std::int32_t rows = 0; + std::int32_t cols = 0; + // Stride is the offset between two adjacent matrix elements + // in the non-contiguous direction. 
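+  // For example, with Order::kColMajor the element at (row, col) lives at
+  // offset row + col * stride in the data buffer, so stride must be >= rows;
+  // with Order::kRowMajor it lives at offset col + row * stride, so stride
+  // must be >= cols. MakeSimpleLayout() below sets the tightest (contiguous)
+  // stride.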
+ std::int32_t stride = 0; + Order order = Order::kColMajor; +}; + +namespace detail { + +// Thin wrapper around a pointer that tracks its constness dynamically. +// +// This is our take on the C++ problem of enforcing constness of data +// wrapped in a containers class: it's not worth the hassle of trying to +// make it fully work at compile-time. +// Instead, we only enforce constness at runtime, and to make it +// zero-overhead, we only enforce it in debug builds. +template +class ConstCheckingPtr final { + public: + using element_type = T; + + // Convenience methods. Most `set` calls go through these. + ConstCheckingPtr& operator=(T* ptr) { + set(ptr); + return *this; + } + ConstCheckingPtr& operator=(const T* ptr) { + set(ptr); + return *this; + } + ConstCheckingPtr& operator=(std::nullptr_t) { + set(static_cast(nullptr)); + return *this; + } + + // Core accessors. These encapsulate the main logic: + // - for `set`, the constness of the argument determines whether internal + // pointer should be tracked as const/mutable. + // - for `get`, the constness of `this` determines whether the call + // counts as a const or mutable use of the internal pointer. + void set(T* ptr) { + ptr_ = ptr; + set_mutable(true); + } + void set(const T* ptr) { + ptr_ = ptr; + set_mutable(false); + } + T* get() /* NOT const */ { + assert_mutable(); + return const_cast(ptr_); + } + const T* get() const { return ptr_; } + + private: + static_assert(!std::is_const::value, ""); + const T* ptr_ = nullptr; +#ifndef NDEBUG + bool is_mutable_ = true; + void set_mutable(bool val) { is_mutable_ = val; } + void assert_mutable() { RUY_DCHECK(is_mutable_); } +#else + void set_mutable(bool) {} + void assert_mutable() {} +#endif +}; + +} // namespace detail + +// A Matrix is really what Eigen and gemmlowp would have called a 'matrix map': +// it merely wraps existing data as a matrix. It doesn't own any buffer. +// Scalar may be any floating-point or integral type. When integral, it may be +// signed or unsigned. +template +struct Matrix final { + Matrix& operator=(const Matrix& other) { + data = other.data; + cacheable = other.cacheable; + layout = other.layout; + zero_point = other.zero_point; + return *this; + } + + // The underlying buffer wrapped by this matrix. + detail::ConstCheckingPtr data; + // The shape and data layout of this matrix. + Layout layout; + // The zero_point, i.e. which Scalar value is to be interpreted as zero. + // When Scalar is floating-point, this must be 0. + Scalar zero_point = 0; + // Clients of Ruy must set this flag to enable any caching behavior. Doesn't + // impact numerical results, but caching can impact observable metrics like + // latency, memory usage, power, etc. + bool cacheable = false; +}; + +inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) { + layout->rows = rows; + layout->cols = cols; + layout->order = order; + layout->stride = order == Order::kColMajor ? rows : cols; +} + +// Opaque data structure representing a pre-packed matrix, as obtained from +// Ruy's advanced API. 
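+
+// Illustrative usage sketch (hypothetical helper, not part of ruy's API):
+// wrap an existing caller-owned buffer as a 3x4 column-major Matrix<float>.
+// The Matrix stores only a pointer plus the layout; it never owns the buffer.
+inline void ExampleWrapExistingBuffer(float* buffer) {
+  Matrix<float> m;
+  MakeSimpleLayout(3, 4, Order::kColMajor, &m.layout);  // stride == rows == 3
+  m.data = buffer;  // tracked as mutable by detail::ConstCheckingPtr
+  // Assigning via a pointer-to-const instead, e.g.
+  //   m.data = static_cast<const float*>(buffer);
+  // marks the data as const, and a later mutable m.data.get() would trip
+  // RUY_DCHECK in debug builds.
+}
+
+// The struct below is the opaque pre-packed matrix type described in the
+// comment above.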
+struct PrepackedMatrix { + void* data = nullptr; + std::size_t data_size = 0; + void* sums = nullptr; + std::size_t sums_size = 0; +}; + +template +StreamType& operator<<(StreamType& stream, const Matrix& mat) { + for (int row = 0; row < mat.layout.rows; row++) { + for (int col = 0; col < mat.layout.cols; col++) { + stream << static_cast(Element(mat, row, col)) << " "; + } + stream << "\n"; + } + return stream; +} + +// Compile-time version of KernelLayout, used to declare kernel layouts in a +// way that can be consumed by compile-time logic. +// See how partial specializations of Kernel use it to declare their layouts. +// The only reason why this is currently part of the public API is to +// allow testing various layouts for the Path::kStandardCpp kernel, as a +// testing-only feature. See Spec::StandardCppKernelLhsLayout. +template +struct FixedKernelLayout { + static constexpr Order kOrder = tOrder; + static constexpr int kRows = tRows; + static constexpr int kCols = tCols; +}; + +#if (__cplusplus < 201703L) +// A static constexpr data member is automatically inline and should not require +// redeclaration without an initializer. This is actually deprecated from C++17 +// onwards. Clang with -O0 without this can fail to link. +template +constexpr int FixedKernelLayout::kCols; +template +constexpr int FixedKernelLayout::kRows; +#endif + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_ diff --git a/ruy/opt_set.h b/ruy/opt_set.h new file mode 100644 index 0000000..fef0107 --- /dev/null +++ b/ruy/opt_set.h @@ -0,0 +1,51 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_ + +// RUY_OPT_SET is a compile-time API that Ruy provides for enabling/disabling +// certain optimizations. It should be used by defining that macro on the +// compiler command line. +// +// Each bit in RUY_OPT_SET controls a particular optimization done in Ruy. +#define RUY_OPT_INTRINSICS 0x1 +#define RUY_OPT_ASM 0x2 +#define RUY_OPT_TUNING 0x4 +#define RUY_OPT_FAT_KERNEL 0x8 +#define RUY_OPT_NATIVE_ROUNDING 0x10 +#define RUY_OPT_AVOID_ALIASING 0x20 +#define RUY_OPT_MAX_STREAMING 0x40 +#define RUY_OPT_PACK_AHEAD 0x80 +#define RUY_OPT_PREFETCH_LOAD 0x100 +#define RUY_OPT_PREFETCH_STORE 0x200 +#define RUY_OPT_FRACTAL_Z 0x400 +#define RUY_OPT_FRACTAL_U 0x800 +#define RUY_OPT_FRACTAL_HILBERT 0x1000 + +#if !defined(RUY_OPT_SET) +#ifdef RUY_OPTIMIZE_FOR_MATMUL_BENCHMARK +// Load prefetching is detrimental in matrix multiplication benchmarks. +// Store prefetching is not. +#define RUY_OPT_SET (~RUY_OPT_PREFETCH_LOAD) +#else +// Default to all optimizations. 
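+// Illustrative build-flag sketch (hypothetical values, not part of this
+// patch): a build that wants to compile out the hand-written assembly
+// kernels while keeping every other optimization could pass, on the compiler
+// command line,
+//   -DRUY_OPT_SET="(~RUY_OPT_ASM)"
+// so that code guarded by
+//   #if RUY_OPT_ENABLED(RUY_OPT_ASM) ... #endif
+// (as used throughout kernel_*.cc and pack_*.cc) is compiled out. When
+// RUY_OPT_SET is not defined at all, the definition below enables every
+// optimization: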
+#define RUY_OPT_SET (~0) +#endif +#endif + +#define RUY_OPT_ENABLED(ruy_opt) ((RUY_OPT_SET & ruy_opt) != 0) + +#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_ diff --git a/ruy/pack.h b/ruy/pack.h new file mode 100644 index 0000000..e066663 --- /dev/null +++ b/ruy/pack.h @@ -0,0 +1,98 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// # What is "packing"? +// +// Before feeding data to the gemm kernels (the parts of Ruy that do lots +// of multiply-add operations), Ruy first performs a data transformation (which +// we call "packing") on the input matrices. This transformation has two main +// goals: +// - rearrange data into blocks that are a convenient size/layout for the gemm +// kernels to consume. This helps make the memory access pattern of the gemm +// kernel simpler and more contiguous, and puts the data in a layout most +// convenient for specific arithmetic instructions in the gemm kernel. +// - compute row/column sums needed for handling quantization with non-symmetric +// zero points. +// +// # Simplified algorithmic analysis of packing +// +// Packing is a relatively simple transformation which does a small constant +// amount of work on each element of an input matrix, and hence for an NxM +// matrix performs O(N*M) work. If N and M are of the same order, then this is +// O(N^2) work. +// +// An NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. +// Note that if N, K, and M are all the same order, then the number of +// multiply-accumulate operations is O(N^3). +// +// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the +// case of all dimensions being roughly the same order. +// +// # Packing cost can be significant +// +// When matrix * matrix multiplications begin to look more like matrix * vector +// multiplications, packing cost can become significant. We sometimes call these +// cases "gemv-like". +// +// Continuing the algorithmic analysis above, if we consider a case where an +// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the +// situation is different. In this case, the multiply-accumulate work is only +// quadratic, so the quadratic cost of packing can become significant. +// +// Another way to say this is that the cost of packing an input matrix (either +// the LHS or RHS) is amortized across the non-depth dimension of the opposite +// input matrix. Thus, when the LHS has very few rows or the RHS has very few +// columns, the cost of packing the opposite input matrix can become +// significant. +// +// As a rough rule of thumb, the cost of packing starts to become significant +// when either N or M is below 32 (and other dimensions are hundreds), with very +// significant packing costs at 8 or below. This varies by data type, Path, and +// tuning, so these numbers are only rough guides.
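+//
+// As an illustrative back-of-the-envelope calculation (numbers added here for
+// concreteness, not from measurements): with N = K = M = 1024, packing touches
+// about N*K + K*M = 2 * 2^20 elements while the kernels perform N*K*M = 2^30
+// multiply-accumulates, so packing is roughly 0.2% of the work. Keeping the
+// same LHS but shrinking to M = 4 (a gemv-like shape), packing the LHS still
+// touches N*K = 2^20 elements against only N*K*M = 4 * 2^20
+// multiply-accumulates, i.e. on the order of a quarter of the arithmetic
+// count.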
+// +// One practical use case that is affected by this is inference of +// fully connected neural network layers with a low batch size. The weight +// matrix (which is a constant for inference) is the one affected by significant +// packing cost. +// +// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack +// input matrices that are affected by significant packing costs. +// +// # Implementation notes +// +// Ruy's packing routines always operate on a range of columns and can be +// applied to either the LHS or RHS. This is possible because Ruy internally +// implements a TrMul, so the accumulation along depth is done along columns of +// both the LHS and RHS (whereas for a normal Mul the accumulation along depth +// for the LHS is along rows). As another example, we are always computing +// column sums for quantization (and never row sums, since the LHS is +// transposed). + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_ + +#include "ruy/platform.h" + +// IWYU pragma: begin_exports +#if RUY_PLATFORM(NEON) +#include "ruy/pack_arm.h" +#elif RUY_PLATFORM(X86) +#include "ruy/pack_x86.h" +#else +#include "ruy/pack_common.h" +#endif +// IWYU pragma: end_exports + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_ diff --git a/ruy/pack_arm.cc b/ruy/pack_arm.cc new file mode 100644 index 0000000..8b68a39 --- /dev/null +++ b/ruy/pack_arm.cc @@ -0,0 +1,1936 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include "ruy/common.h" +#include "ruy/opt_set.h" +#include "ruy/pack.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +namespace ruy { + +#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + std::int8_t* packed_ptr, int start_col, int end_col, + std::int32_t* sums_ptr, int input_xor) { + profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); + asm volatile( + // clang-format off + "dup v26.16b, %w[input_xor]\n" + "mov w1, #0\n" + "dup v28.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v31.4s, wzr\n" + + "and w2, %w[rows], #-16\n" + "cmp w1, w2\n" + "beq 3f\n" + + "add w1, w1, #16\n" + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "cmp w1, w2\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + "beq 2f\n" + + "1:\n" + + "add w1, w1, #16\n" + "eor v4.16b, v0.16b, v26.16b\n" + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "eor v5.16b, v1.16b, v26.16b\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "eor v6.16b, v2.16b, v26.16b\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "eor v7.16b, v3.16b, v26.16b\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + + "saddlp v16.8h, v4.16b\n" + "str q4, [%[packed_ptr], #0]\n" + "saddlp v17.8h, v5.16b\n" + "str q5, [%[packed_ptr], #16]\n" + "saddlp v18.8h, v6.16b\n" + "str q6, [%[packed_ptr], #32]\n" + "saddlp v19.8h, v7.16b\n" + "str q7, [%[packed_ptr], #48]\n" + "sadalp v28.4s, v16.8h\n" + "cmp w1, w2\n" + "sadalp v29.4s, v17.8h\n" + "add %[packed_ptr], %[packed_ptr], #64\n" + "sadalp v30.4s, v18.8h\n" + "sadalp v31.4s, v19.8h\n" + + "bne 1b\n" + + "2:\n" + + "eor v4.16b, v0.16b, v26.16b\n" + "eor v5.16b, v1.16b, v26.16b\n" + "eor v6.16b, v2.16b, v26.16b\n" + "eor v7.16b, v3.16b, v26.16b\n" + + "saddlp v16.8h, v4.16b\n" + "str q4, [%[packed_ptr], #0]\n" + "saddlp v17.8h, v5.16b\n" + "str q5, [%[packed_ptr], #16]\n" + "saddlp v18.8h, v6.16b\n" + "str q6, [%[packed_ptr], #32]\n" + "saddlp v19.8h, v7.16b\n" + "str q7, [%[packed_ptr], #48]\n" + "sadalp v28.4s, v16.8h\n" + "sadalp v29.4s, v17.8h\n" + "sadalp v30.4s, v18.8h\n" + "sadalp v31.4s, v19.8h\n" + + "add %[packed_ptr], %[packed_ptr], #64\n" + + "3:\n" + + "ands w2, %w[rows], #15\n" + "beq 4f\n" + "dup v0.16b, %w[src_zero_point]\n" + "dup v1.16b, %w[src_zero_point]\n" + "dup v2.16b, %w[src_zero_point]\n" + "dup v3.16b, %w[src_zero_point]\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ + "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ + "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ + "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) + RUY_LOAD_ONE_ROW(4) + RUY_LOAD_ONE_ROW(5) + RUY_LOAD_ONE_ROW(6) + RUY_LOAD_ONE_ROW(7) + RUY_LOAD_ONE_ROW(8) + RUY_LOAD_ONE_ROW(9) + RUY_LOAD_ONE_ROW(10) + RUY_LOAD_ONE_ROW(11) + RUY_LOAD_ONE_ROW(12) + RUY_LOAD_ONE_ROW(13) + RUY_LOAD_ONE_ROW(14) + RUY_LOAD_ONE_ROW(15) +#undef RUY_LOAD_ONE_ROW + "5:\n" + + "eor v4.16b, v0.16b, v26.16b\n" + "eor v5.16b, v1.16b, v26.16b\n" + "eor v6.16b, v2.16b, v26.16b\n" + "eor v7.16b, v3.16b, v26.16b\n" + + 
"saddlp v16.8h, v4.16b\n" + "saddlp v17.8h, v5.16b\n" + "saddlp v18.8h, v6.16b\n" + "saddlp v19.8h, v7.16b\n" + "sadalp v28.4s, v16.8h\n" + "sadalp v29.4s, v17.8h\n" + "sadalp v30.4s, v18.8h\n" + "sadalp v31.4s, v19.8h\n" + + "str q4, [%[packed_ptr], #0]\n" + "str q5, [%[packed_ptr], #16]\n" + "str q6, [%[packed_ptr], #32]\n" + "str q7, [%[packed_ptr], #48]\n" + "add %[packed_ptr], %[packed_ptr], #64\n" + + "4:\n" + + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + "cmp %[sums_ptr], #0\n" + "beq 6f\n" + "st1 {v28.4s}, [%[sums_ptr]], #16\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), + [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), + [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) + : [ src_inc0 ] "r"(static_cast(src_inc0)), + [ src_inc1 ] "r"(static_cast(src_inc1)), + [ src_inc2 ] "r"(static_cast(src_inc2)), + [ src_inc3 ] "r"(static_cast(src_inc3)), + [ rows ] "r"(src_rows), [ src_zero_point ] "r"(src_zero_point), + [ input_xor ] "r"(input_xor) + : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31"); +} +#endif + +#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +#define RUY_OFFSET_SRC_PTR0 0 +#define RUY_OFFSET_SRC_PTR1 4 +#define RUY_OFFSET_SRC_PTR2 8 +#define RUY_OFFSET_SRC_PTR3 12 +#define RUY_OFFSET_SUMS_PTR 16 +#define RUY_OFFSET_PACKED_PTR 20 +#define RUY_OFFSET_SRC_INC0 24 +#define RUY_OFFSET_SRC_INC1 28 +#define RUY_OFFSET_SRC_INC2 32 +#define RUY_OFFSET_SRC_INC3 36 +#define RUY_OFFSET_SRC_ROWS 40 +#define RUY_OFFSET_SRC_ZERO_POINT 44 +#define RUY_OFFSET_INPUT_XOR 48 + +template +void CheckOffsetsInPackParams8bit(const Params&) { + static_assert(offsetof(Params, src_ptr0) == RUY_OFFSET_SRC_PTR0, ""); + static_assert(offsetof(Params, src_ptr1) == RUY_OFFSET_SRC_PTR1, ""); + static_assert(offsetof(Params, src_ptr2) == RUY_OFFSET_SRC_PTR2, ""); + static_assert(offsetof(Params, src_ptr3) == RUY_OFFSET_SRC_PTR3, ""); + static_assert(offsetof(Params, sums_ptr) == RUY_OFFSET_SUMS_PTR, ""); + static_assert(offsetof(Params, packed_ptr) == RUY_OFFSET_PACKED_PTR, ""); + static_assert(offsetof(Params, src_inc0) == RUY_OFFSET_SRC_INC0, ""); + static_assert(offsetof(Params, src_inc1) == RUY_OFFSET_SRC_INC1, ""); + static_assert(offsetof(Params, src_inc2) == RUY_OFFSET_SRC_INC2, ""); + static_assert(offsetof(Params, src_inc3) == RUY_OFFSET_SRC_INC3, ""); + static_assert(offsetof(Params, src_rows) == RUY_OFFSET_SRC_ROWS, ""); + static_assert(offsetof(Params, src_zero_point) == RUY_OFFSET_SRC_ZERO_POINT, + ""); + static_assert(offsetof(Params, input_xor) == RUY_OFFSET_INPUT_XOR, ""); +} + +// Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9. +// No attempt made at making this code efficient on in-order cores yet. 
+void Pack8bitNeonOutOfOrder4Cols(const PackParams8bit& params) { + CheckOffsetsInPackParams8bit(params); + profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); + const void* src_ptr0 = params.src_ptr0; + const void* src_ptr1 = params.src_ptr1; + const void* src_ptr2 = params.src_ptr2; + const void* src_ptr3 = params.src_ptr3; + const int src_inc0 = params.src_inc0; + const int src_inc1 = params.src_inc1; + const int src_inc2 = params.src_inc2; + const int src_inc3 = params.src_inc3; + const std::int8_t* packed_ptr = params.packed_ptr; + + asm volatile( + // clang-format off + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n" + "vdup.8 q11, r2\n" + "mov r1, #0\n" + // Zero-out the accumulators + "vmov.i32 q12, #0\n" + "vmov.i32 q13, #0\n" + "vmov.i32 q14, #0\n" + "vmov.i32 q15, #0\n" + + // Round down src_rows to nearest multiple of 16. + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" + "and r2, r3, #-16\n" + "cmp r1, r2\n" + "beq 3f\n" + + "1:\n" + "add r1, r1, #16\n" + /* Load q0 */ + "vld1.8 {d0, d1}, [%[src_ptr0]]\n" + "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n" + RUY_PREFETCH_LOAD("pld [%[src_ptr0]]\n") + + /* Load q1 */ + "vld1.8 {d2, d3}, [%[src_ptr1]]\n" + "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n" + RUY_PREFETCH_LOAD("pld [%[src_ptr1]]\n") + + "veor.8 q4, q0, q11\n" + "veor.8 q5, q1, q11\n" + + // Pairwise add in to 16b accumulators. + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + // Pairwise add accumulate into 32b accumulators. + // q12 and q13 contain 4x32b accumulators + "vpadal.s16 q12, q8\n" + "vpadal.s16 q13, q9\n" + + // Now do the same for src_ptr2 and src_ptr3. + "vld1.8 {d0, d1}, [%[src_ptr2]]\n" + "add %[src_ptr2], %[src_ptr2], %[src_inc2]\n" + RUY_PREFETCH_LOAD("pld [%[src_ptr2]]\n") + + "vld1.8 {d2, d3}, [%[src_ptr3]]\n" + "add %[src_ptr3], %[src_ptr3], %[src_inc3]\n" + RUY_PREFETCH_LOAD("pld [%[src_ptr3]]\n") + + "veor.8 q4, q0, q11\n" + "veor.8 q5, q1, q11\n" + + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + // Pairwise add accumulate into 32b accumulators. + // q14 and q15 contain 4x32b accumulators + "vpadal.s16 q14, q8\n" + "vpadal.s16 q15, q9\n" + + "cmp r1, r2\n" + "bne 1b\n" + + "3:\n" + + // Now pack the last (num_rows % 16) rows. + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" + "ands r2, r3, #15\n" + "beq 4f\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n" + "vdup.8 q0, r3\n" + "vdup.8 q1, r3\n" + +// First, read/accumulate/write for src_ptr0 and src_ptr1. 
+#define RUY_LOAD_ONE_ROW1(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \ + "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \ + + RUY_LOAD_ONE_ROW1(0, 0) + RUY_LOAD_ONE_ROW1(1, 1) + RUY_LOAD_ONE_ROW1(2, 2) + RUY_LOAD_ONE_ROW1(3, 3) + RUY_LOAD_ONE_ROW1(4, 4) + RUY_LOAD_ONE_ROW1(5, 5) + RUY_LOAD_ONE_ROW1(6, 6) + RUY_LOAD_ONE_ROW1(7, 7) +#undef RUY_LOAD_ONE_ROW1 + +#define RUY_LOAD_ONE_ROW2(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \ + "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \ + + RUY_LOAD_ONE_ROW2(8, 0) + RUY_LOAD_ONE_ROW2(9, 1) + RUY_LOAD_ONE_ROW2(10, 2) + RUY_LOAD_ONE_ROW2(11, 3) + RUY_LOAD_ONE_ROW2(12, 4) + RUY_LOAD_ONE_ROW2(13, 5) + RUY_LOAD_ONE_ROW2(14, 6) + RUY_LOAD_ONE_ROW2(15, 7) +#undef RUY_LOAD_ONE_ROW2 + + "5:\n" + + "veor.16 q4, q0, q11\n" + "veor.16 q5, q1, q11\n" + + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + // Pairwise add accumulate to 4x32b accumulators. + "vpadal.s16 q12, q8\n" + "vpadal.s16 q13, q9\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + // Reset to src_zero for src_ptr2 and src_ptr3. + "vdup.8 q0, r3\n" + "vdup.8 q1, r3\n" + +// Next, read/accumulate/write for src_ptr2 and src_ptr3. +#define RUY_LOAD_ONE_ROW1(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d0[" #R "]}, [%[src_ptr2]]!\n" \ + "vld1.8 { d2[" #R "]}, [%[src_ptr3]]!\n" \ + + RUY_LOAD_ONE_ROW1(0, 0) + RUY_LOAD_ONE_ROW1(1, 1) + RUY_LOAD_ONE_ROW1(2, 2) + RUY_LOAD_ONE_ROW1(3, 3) + RUY_LOAD_ONE_ROW1(4, 4) + RUY_LOAD_ONE_ROW1(5, 5) + RUY_LOAD_ONE_ROW1(6, 6) + RUY_LOAD_ONE_ROW1(7, 7) +#undef RUY_LOAD_ONE_ROW1 + +#define RUY_LOAD_ONE_ROW2(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d1[" #R "]}, [%[src_ptr2]]!\n" \ + "vld1.8 { d3[" #R "]}, [%[src_ptr3]]!\n" \ + + RUY_LOAD_ONE_ROW2(8, 0) + RUY_LOAD_ONE_ROW2(9, 1) + RUY_LOAD_ONE_ROW2(10, 2) + RUY_LOAD_ONE_ROW2(11, 3) + RUY_LOAD_ONE_ROW2(12, 4) + RUY_LOAD_ONE_ROW2(13, 5) + RUY_LOAD_ONE_ROW2(14, 6) + RUY_LOAD_ONE_ROW2(15, 7) +#undef RUY_LOAD_ONE_ROW2 + + "5:\n" + + "veor.16 q4, q0, q11\n" + "veor.16 q5, q1, q11\n" + + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + // Pairwise add accumulate to 4x32b accumulators. + "vpadal.s16 q14, q8\n" + "vpadal.s16 q15, q9\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + "4:\n" + // Pairwise add 32-bit accumulators + "vpadd.i32 d24, d24, d25\n" + "vpadd.i32 d26, d26, d27\n" + "vpadd.i32 d28, d28, d29\n" + "vpadd.i32 d30, d30, d31\n" + // Final 32-bit values per row + "vpadd.i32 d25, d24, d26\n" + "vpadd.i32 d27, d28, d30\n" + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n" + "cmp r3, #0\n" + "beq 6f\n" + "vst1.32 {d25}, [r3]!\n" + "vst1.32 {d27}, [r3]!\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), + [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3) + : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1), + [ src_inc2 ] "r"(src_inc2), [ src_inc3 ] "r"(src_inc3), + [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(¶ms) + : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3", + "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13"); +} + +// Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9. +// No attempt made at making this code efficient on in-order cores yet. +// This version differs from the above in that we only handle two columns +// at a time. 
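+
+// A scalar sketch of what the 8-bit packing kernels in this file compute
+// (illustrative only: the name, signature and loop structure are made up for
+// this example, it relies only on <cstdint> types, and it ignores the
+// src_inc/pointer-replication tricks the real kernels use for partial blocks).
+// For `cols` source columns it emits 16 rows of column 0, then 16 rows of
+// column 1, and so on, chunk after chunk. Each byte is XORed with input_xor
+// (0x80 converts uint8 to int8, 0 leaves int8 data unchanged), the final
+// partial chunk is padded with the source zero point, and a per-column sum of
+// the packed (signed) values is accumulated for the zero-point corrections.
+inline void PackScalarSketch(const std::uint8_t* const* src, int cols,
+                             int src_rows, std::uint8_t src_zero_point,
+                             std::uint8_t input_xor, std::int8_t* packed,
+                             std::int32_t* sums) {
+  for (int c = 0; c < cols; ++c) sums[c] = 0;
+  const int padded_rows = (src_rows + 15) & ~15;  // round up to multiple of 16
+  for (int r0 = 0; r0 < padded_rows; r0 += 16) {
+    for (int c = 0; c < cols; ++c) {
+      for (int r = r0; r < r0 + 16; ++r) {
+        const std::uint8_t raw = r < src_rows ? src[c][r] : src_zero_point;
+        const std::int8_t value = static_cast<std::int8_t>(raw ^ input_xor);
+        *packed++ = value;
+        sums[c] += value;
+      }
+    }
+  }
+}
+// The hand-written two-column NEON version follows: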
+void Pack8bitNeonOutOfOrder2Cols(const PackParams8bit& params) { + CheckOffsetsInPackParams8bit(params); + profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); + const void* src_ptr0 = params.src_ptr0; + const void* src_ptr1 = params.src_ptr1; + const int src_inc0 = params.src_inc0; + const int src_inc1 = params.src_inc1; + const std::int8_t* packed_ptr = params.packed_ptr; + + asm volatile( + // clang-format off + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n" + "vdup.8 q11, r2\n" + "mov r1, #0\n" + // Zero-out the accumulators + "vmov.i32 q12, #0\n" + "vmov.i32 q13, #0\n" + + // Round down src_rows to nearest multiple of 16. + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" + "and r2, r3, #-16\n" + "cmp r1, r2\n" + "beq 3f\n" + + "1:\n" + "add r1, r1, #16\n" + /* Load q0 */ + "vld1.8 {d0, d1}, [%[src_ptr0]]\n" + "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n" + + /* Load q1 */ + "vld1.8 {d2, d3}, [%[src_ptr1]]\n" + "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n" + + "veor.8 q4, q0, q11\n" + "veor.8 q5, q1, q11\n" + + // Pairwise add in to 16b accumulators. + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + // Pairwise add accumulate into 32b accumulators. + // q12 and q13 contain 4x32b accumulators + "vpadal.s16 q12, q8\n" + "vpadal.s16 q13, q9\n" + + "cmp r1, r2\n" + + "bne 1b\n" + + "3:\n" + + // Now pack the last (num_rows % 16) rows. + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" + "ands r2, r3, #15\n" + "beq 4f\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n" + "vdup.8 q0, r3\n" + "vdup.8 q1, r3\n" + +// Read/accumulate/write for src_ptr0 and src_ptr1. +#define RUY_LOAD_ONE_ROW1(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \ + "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \ + + RUY_LOAD_ONE_ROW1(0, 0) + RUY_LOAD_ONE_ROW1(1, 1) + RUY_LOAD_ONE_ROW1(2, 2) + RUY_LOAD_ONE_ROW1(3, 3) + RUY_LOAD_ONE_ROW1(4, 4) + RUY_LOAD_ONE_ROW1(5, 5) + RUY_LOAD_ONE_ROW1(6, 6) + RUY_LOAD_ONE_ROW1(7, 7) +#undef RUY_LOAD_ONE_ROW1 + +#define RUY_LOAD_ONE_ROW2(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \ + "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \ + + RUY_LOAD_ONE_ROW2(8, 0) + RUY_LOAD_ONE_ROW2(9, 1) + RUY_LOAD_ONE_ROW2(10, 2) + RUY_LOAD_ONE_ROW2(11, 3) + RUY_LOAD_ONE_ROW2(12, 4) + RUY_LOAD_ONE_ROW2(13, 5) + RUY_LOAD_ONE_ROW2(14, 6) + RUY_LOAD_ONE_ROW2(15, 7) +#undef RUY_LOAD_ONE_ROW2 + + "5:\n" + + "veor.16 q4, q0, q11\n" + "veor.16 q5, q1, q11\n" + + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + + // Pairwise add accumulate to 4x32b accumulators. 
+ "vpadal.s16 q12, q8\n" + "vpadal.s16 q13, q9\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + "4:\n" + + // Pairwise add 32-bit accumulators + "vpadd.i32 d24, d24, d25\n" + "vpadd.i32 d26, d26, d27\n" + // Final 32-bit values per row + "vpadd.i32 d25, d24, d26\n" + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n" + "cmp r3, #0\n" + "beq 6f\n" + "vst1.32 {d25}, [r3]!\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1) + : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1), + [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(¶ms) + : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3", + "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13"); +} + +#undef RUY_OFFSET_SRC_PTR0 +#undef RUY_OFFSET_SRC_PTR1 +#undef RUY_OFFSET_SRC_PTR2 +#undef RUY_OFFSET_SRC_PTR32 +#undef RUY_OFFSET_SUMS_PTR +#undef RUY_OFFSET_PACKED_PTR0 +#undef RUY_OFFSET_SRC_INC0 +#undef RUY_OFFSET_SRC_INC1 +#undef RUY_OFFSET_SRC_INC2 +#undef RUY_OFFSET_SRC_INC3 +#undef RUY_OFFSET_SRC_ROWS +#undef RUY_OFFSET_SRC_ZERO_POINT +#undef RUY_OFFSET_INPUT_XOR + +#endif // RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, int src_inc3, + int src_rows, int src_zero_point, + std::int8_t* packed_ptr, int start_col, int end_col, + std::int32_t* sums_ptr, int input_xor) { + profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)"); + asm volatile( + // clang-format off + "dup v26.16b, %w[input_xor]\n" + "mov w1, #0\n" + "dup v28.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v31.4s, wzr\n" + + "and w2, %w[rows], #-16\n" + "cmp w1, w2\n" + "beq 3f\n" + "ldr x10, [%[src_ptr0], #8]\n" + "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" + "ldr x11, [%[src_ptr1], #8]\n" + "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" + "ldr x12, [%[src_ptr2], #8]\n" + "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" + "ldr x13, [%[src_ptr3], #8]\n" + "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") + "add w1, w1, #16\n" + "cmp w1, w2\n" + + "beq 2f\n" + + "1:\n" + "add w1, w1, #16\n" + "ins v0.d[1], x10\n" + "ldr x10, [%[src_ptr0], #8]\n" + "ins v1.d[1], x11\n" + "ldr x11, [%[src_ptr1], #8]\n" + "ins v2.d[1], x12\n" + "ldr x12, [%[src_ptr2], #8]\n" + "ins v3.d[1], x13\n" + "ldr x13, [%[src_ptr3], #8]\n" + "eor v4.16b, v0.16b, v26.16b\n" + "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" + "eor v5.16b, v1.16b, v26.16b\n" + "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" + "eor v6.16b, v2.16b, v26.16b\n" + "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" + "eor v7.16b, v3.16b, v26.16b\n" + "ld1 {v3.8b}, [%[src_ptr3]], 
%[src_inc3]\n" + "saddlp v16.8h, v4.16b\n" + "str q4, [%[packed_ptr], #0]\n" + "saddlp v17.8h, v5.16b\n" + "str q5, [%[packed_ptr], #16]\n" + "saddlp v18.8h, v6.16b\n" + "str q6, [%[packed_ptr], #32]\n" + "saddlp v19.8h, v7.16b\n" + "str q7, [%[packed_ptr], #48]\n" + "sadalp v28.4s, v16.8h\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") + "cmp w1, w2\n" + "sadalp v29.4s, v17.8h\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") + "add %[packed_ptr], %[packed_ptr], #64\n" + "sadalp v30.4s, v18.8h\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") + "sadalp v31.4s, v19.8h\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") + + "bne 1b\n" + + "2:\n" + "ins v0.d[1], x10\n" + "ins v1.d[1], x11\n" + "ins v2.d[1], x12\n" + "ins v3.d[1], x13\n" + "eor v4.16b, v0.16b, v26.16b\n" + "eor v5.16b, v1.16b, v26.16b\n" + "eor v6.16b, v2.16b, v26.16b\n" + "eor v7.16b, v3.16b, v26.16b\n" + + "saddlp v16.8h, v4.16b\n" + "str q4, [%[packed_ptr], #0]\n" + "saddlp v17.8h, v5.16b\n" + "str q5, [%[packed_ptr], #16]\n" + "saddlp v18.8h, v6.16b\n" + "str q6, [%[packed_ptr], #32]\n" + "saddlp v19.8h, v7.16b\n" + "str q7, [%[packed_ptr], #48]\n" + "sadalp v28.4s, v16.8h\n" + "sadalp v29.4s, v17.8h\n" + "sadalp v30.4s, v18.8h\n" + "sadalp v31.4s, v19.8h\n" + + "add %[packed_ptr], %[packed_ptr], #64\n" + + "3:\n" + + "ands w2, %w[rows], #15\n" + "beq 4f\n" + "dup v0.16b, %w[src_zero_point]\n" + "dup v1.16b, %w[src_zero_point]\n" + "dup v2.16b, %w[src_zero_point]\n" + "dup v3.16b, %w[src_zero_point]\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ + "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ + "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ + "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) + RUY_LOAD_ONE_ROW(4) + RUY_LOAD_ONE_ROW(5) + RUY_LOAD_ONE_ROW(6) + RUY_LOAD_ONE_ROW(7) + RUY_LOAD_ONE_ROW(8) + RUY_LOAD_ONE_ROW(9) + RUY_LOAD_ONE_ROW(10) + RUY_LOAD_ONE_ROW(11) + RUY_LOAD_ONE_ROW(12) + RUY_LOAD_ONE_ROW(13) + RUY_LOAD_ONE_ROW(14) + RUY_LOAD_ONE_ROW(15) +#undef RUY_LOAD_ONE_ROW + "5:\n" + + "eor v4.16b, v0.16b, v26.16b\n" + "eor v5.16b, v1.16b, v26.16b\n" + "eor v6.16b, v2.16b, v26.16b\n" + "eor v7.16b, v3.16b, v26.16b\n" + + "saddlp v16.8h, v4.16b\n" + "saddlp v17.8h, v5.16b\n" + "saddlp v18.8h, v6.16b\n" + "saddlp v19.8h, v7.16b\n" + "sadalp v28.4s, v16.8h\n" + "sadalp v29.4s, v17.8h\n" + "sadalp v30.4s, v18.8h\n" + "sadalp v31.4s, v19.8h\n" + + "str q4, [%[packed_ptr], #0]\n" + "str q5, [%[packed_ptr], #16]\n" + "str q6, [%[packed_ptr], #32]\n" + "str q7, [%[packed_ptr], #48]\n" + "add %[packed_ptr], %[packed_ptr], #64\n" + + "4:\n" + + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + "cmp %[sums_ptr], #0\n" + "beq 6f\n" + "st1 {v28.4s}, [%[sums_ptr]], #16\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), + [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), + [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) + : [ src_inc0 ] "r"(static_cast(src_inc0)), [ src_inc1 ] "r"(static_cast(src_inc1)), + [ src_inc2 ] "r"(static_cast(src_inc2)), [ src_inc3 ] "r"(static_cast(src_inc3)), + [ rows ] "r"(src_rows), + [ src_zero_point ] "r"(src_zero_point), + [input_xor] "r"(input_xor) + : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", 
"v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} + +void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + std::int8_t* packed_ptr, int start_col, + int end_col, std::int32_t* sums_ptr, + int input_xor) { + profiler::ScopeLabel label( + "Pack (kNeonDotprod, optimized for in-order cores)"); + asm volatile( + // clang-format off + "dup v26.16b, %w[input_xor]\n" + "mov w1, #1\n" + "dup v27.16b, w1\n" + "mov w1, #0\n" + "dup v28.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v31.4s, wzr\n" + + "and w2, %w[rows], #-16\n" + "cmp w1, w2\n" + "beq 3f\n" + "ldr x10, [%[src_ptr0], #8]\n" + "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" + "ldr x11, [%[src_ptr1], #8]\n" + "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" + "ldr x12, [%[src_ptr2], #8]\n" + "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" + "ldr x13, [%[src_ptr3], #8]\n" + "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") + "add w1, w1, #16\n" + "cmp w1, w2\n" + + "beq 2f\n" + + "1:\n" + "add w1, w1, #16\n" + "ins v0.d[1], x10\n" + "ldr x10, [%[src_ptr0], #8]\n" + "ins v1.d[1], x11\n" + "ldr x11, [%[src_ptr1], #8]\n" + "ins v2.d[1], x12\n" + "ldr x12, [%[src_ptr2], #8]\n" + "ins v3.d[1], x13\n" + "ldr x13, [%[src_ptr3], #8]\n" + + "eor v4.16b, v0.16b, v26.16b\n" + "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" + "eor v5.16b, v1.16b, v26.16b\n" + "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" + "eor v6.16b, v2.16b, v26.16b\n" + "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" + "eor v7.16b, v3.16b, v26.16b\n" + "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" + + "trn1 v16.4s, v4.4s, v5.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") + "trn2 v17.4s, v4.4s, v5.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") + "trn1 v18.4s, v6.4s, v7.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") + "trn2 v19.4s, v6.4s, v7.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + "cmp w1, w2\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + "str q20, [%[packed_ptr], #0]\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + "str q21, [%[packed_ptr], #32]\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + "str q22, [%[packed_ptr], #64]\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + "str q23, [%[packed_ptr], #96]\n" + + "add %[packed_ptr], %[packed_ptr], #128\n" + + "bne 1b\n" + + "2:\n" + "ins v0.d[1], x10\n" + "ins v1.d[1], x11\n" + "ins v2.d[1], 
x12\n" + "ins v3.d[1], x13\n" + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + "str q20, [%[packed_ptr], #0]\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + "str q21, [%[packed_ptr], #32]\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + "str q22, [%[packed_ptr], #64]\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "3:\n" + + "ands w2, %w[rows], #15\n" + "beq 4f\n" + "dup v0.16b, %w[src_zero_point]\n" + "dup v1.16b, %w[src_zero_point]\n" + "dup v2.16b, %w[src_zero_point]\n" + "dup v3.16b, %w[src_zero_point]\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ + "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ + "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ + "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) + RUY_LOAD_ONE_ROW(4) + RUY_LOAD_ONE_ROW(5) + RUY_LOAD_ONE_ROW(6) + RUY_LOAD_ONE_ROW(7) + RUY_LOAD_ONE_ROW(8) + RUY_LOAD_ONE_ROW(9) + RUY_LOAD_ONE_ROW(10) + RUY_LOAD_ONE_ROW(11) + RUY_LOAD_ONE_ROW(12) + RUY_LOAD_ONE_ROW(13) + RUY_LOAD_ONE_ROW(14) + RUY_LOAD_ONE_ROW(15) +#undef RUY_LOAD_ONE_ROW + "5:\n" + + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + "str q20, [%[packed_ptr], #0]\n" + "cmp w2, #4\n" + "ble 4f\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + "str q21, [%[packed_ptr], #32]\n" + "cmp w2, #8\n" + "ble 4f\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + "str q22, [%[packed_ptr], #64]\n" + "cmp w2, #12\n" + "ble 4f\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "4:\n" + + "add v28.4s, v28.4s, v29.4s\n" + "add v30.4s, v30.4s, v31.4s\n" + "add v28.4s, v28.4s, v30.4s\n" + + "cmp %[sums_ptr], #0\n" + "beq 6f\n" + "st1 {v28.4s}, [%[sums_ptr]], #16\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2), + [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr) + : [ src_inc0 ] "r"(static_cast(src_inc0)), [ src_inc1 ] "r"(static_cast(src_inc1)), + [ src_inc2 ] "r"(static_cast(src_inc2)), [ src_inc3 ] "r"(static_cast(src_inc3)), + [rows] "r"(src_rows), + [src_zero_point] "r"(static_cast(src_zero_point)), + [input_xor] "r"(input_xor) + : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", 
"v30", "v31"); +} + +void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, + int src_zero_point, std::int8_t* packed_ptr, + int start_col, int end_col, + std::int32_t* sums_ptr, int input_xor) { + profiler::ScopeLabel label( + "Pack (kNeonDotprod, optimized for out-of-order cores)"); + asm volatile( + // clang-format off + "dup v26.16b, %w[input_xor]\n" + "mov w1, #1\n" + "dup v27.16b, w1\n" + "mov w1, #0\n" + "dup v28.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v31.4s, wzr\n" + +#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) + "and w2, %w[rows], #-64\n" + "cmp w1, w2\n" + "beq 9f\n" + + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n" + "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n" + "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #64\n" + "cmp w1, w2\n" + "beq 8f\n" + + "7:\n" + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "eor v4.16b, v4.16b, v26.16b\n" + "eor v5.16b, v5.16b, v26.16b\n" + "eor v6.16b, v6.16b, v26.16b\n" + "eor v7.16b, v7.16b, v26.16b\n" + + "trn1 v16.4s, v4.4s, v5.4s\n" + "trn2 v17.4s, v4.4s, v5.4s\n" + "trn1 v18.4s, v6.4s, v7.4s\n" + "trn2 v19.4s, v6.4s, v7.4s\n" + + "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], 
#64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "eor v8.16b, v8.16b, v26.16b\n" + "eor v9.16b, v9.16b, v26.16b\n" + "eor v10.16b, v10.16b, v26.16b\n" + "eor v11.16b, v11.16b, v26.16b\n" + + "trn1 v16.4s, v8.4s, v9.4s\n" + "trn2 v17.4s, v8.4s, v9.4s\n" + "trn1 v18.4s, v10.4s, v11.4s\n" + "trn2 v19.4s, v10.4s, v11.4s\n" + + "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "eor v12.16b, v12.16b, v26.16b\n" + "eor v13.16b, v13.16b, v26.16b\n" + "eor v14.16b, v14.16b, v26.16b\n" + "eor v15.16b, v15.16b, v26.16b\n" + + "trn1 v16.4s, v12.4s, v13.4s\n" + "trn2 v17.4s, v12.4s, v13.4s\n" + "trn1 v18.4s, v14.4s, v15.4s\n" + "trn2 v19.4s, v14.4s, v15.4s\n" + + "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "cmp w1, w2\n" + "bne 7b\n" + + "8:\n" + + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "eor v4.16b, v4.16b, v26.16b\n" + "eor v5.16b, v5.16b, v26.16b\n" + "eor v6.16b, v6.16b, v26.16b\n" + "eor v7.16b, v7.16b, v26.16b\n" + + "trn1 v16.4s, v4.4s, v5.4s\n" + "trn2 v17.4s, v4.4s, v5.4s\n" + "trn1 v18.4s, v6.4s, v7.4s\n" + "trn2 v19.4s, v6.4s, v7.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, 
v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "eor v8.16b, v8.16b, v26.16b\n" + "eor v9.16b, v9.16b, v26.16b\n" + "eor v10.16b, v10.16b, v26.16b\n" + "eor v11.16b, v11.16b, v26.16b\n" + + "trn1 v16.4s, v8.4s, v9.4s\n" + "trn2 v17.4s, v8.4s, v9.4s\n" + "trn1 v18.4s, v10.4s, v11.4s\n" + "trn2 v19.4s, v10.4s, v11.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "eor v12.16b, v12.16b, v26.16b\n" + "eor v13.16b, v13.16b, v26.16b\n" + "eor v14.16b, v14.16b, v26.16b\n" + "eor v15.16b, v15.16b, v26.16b\n" + + "trn1 v16.4s, v12.4s, v13.4s\n" + "trn2 v17.4s, v12.4s, v13.4s\n" + "trn1 v18.4s, v14.4s, v15.4s\n" + "trn2 v19.4s, v14.4s, v15.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "9:\n" +#endif // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) + "and w2, %w[rows], #-16\n" + "cmp w1, w2\n" + "beq 3f\n" + + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + "cmp w1, w2\n" + "beq 2f\n" + + "1:\n" + + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "cmp w1, w2\n" + "bne 1b\n" + + "2:\n" + + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor 
v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "3:\n" + + "ands w2, %w[rows], #15\n" + "beq 4f\n" + "dup v0.16b, %w[src_zero_point]\n" + "dup v1.16b, %w[src_zero_point]\n" + "dup v2.16b, %w[src_zero_point]\n" + "dup v3.16b, %w[src_zero_point]\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ + "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ + "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ + "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) + RUY_LOAD_ONE_ROW(4) + RUY_LOAD_ONE_ROW(5) + RUY_LOAD_ONE_ROW(6) + RUY_LOAD_ONE_ROW(7) + RUY_LOAD_ONE_ROW(8) + RUY_LOAD_ONE_ROW(9) + RUY_LOAD_ONE_ROW(10) + RUY_LOAD_ONE_ROW(11) + RUY_LOAD_ONE_ROW(12) + RUY_LOAD_ONE_ROW(13) + RUY_LOAD_ONE_ROW(14) + RUY_LOAD_ONE_ROW(15) +#undef RUY_LOAD_ONE_ROW + "5:\n" + + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + "str q20, [%[packed_ptr], #0]\n" + "cmp w2, #4\n" + "ble 4f\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + "str q21, [%[packed_ptr], #32]\n" + "cmp w2, #8\n" + "ble 4f\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + "str q22, [%[packed_ptr], #64]\n" + "cmp w2, #12\n" + "ble 4f\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "4:\n" + + "add v28.4s, v28.4s, v29.4s\n" + "add v30.4s, v30.4s, v31.4s\n" + "add v28.4s, v28.4s, v30.4s\n" + + "cmp %[sums_ptr], #0\n" + "beq 6f\n" + "st1 {v28.4s}, [%[sums_ptr]], #16\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), + [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), + [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) + : [ src_inc0 ] "r"(static_cast(src_inc0)), + [ src_inc1 ] "r"(static_cast(src_inc1)), + [ src_inc2 ] "r"(static_cast(src_inc2)), + [ src_inc3 ] "r"(static_cast(src_inc3)), + [ rows ] "r"(src_rows), + [ src_zero_point ] "r"(static_cast(src_zero_point)), + [ input_xor ] "r"(input_xor) + : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31"); +} + +#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if RUY_PLATFORM(NEON_64) && 
RUY_OPT_ENABLED(RUY_OPT_ASM) +void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + float* packed_ptr, int start_col, int end_col) { + profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); + asm volatile( + // clang-format off + "mov w1, #0\n" + + "and w2, %w[rows], #-4\n" + "cmp w1, w2\n" + "beq 3f\n" + "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #4\n" + "cmp w1, w2\n" + + "beq 2f\n" + + "1:\n" + "add w1, w1, #4\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + "cmp w1, w2\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + + "add %[packed_ptr], %[packed_ptr], #128\n" + + "bne 1b\n" + + "2:\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "3:\n" + + "ands w2, %w[rows], #3\n" + "beq 4f\n" + "dup v0.16b, wzr\n" + "dup v1.16b, wzr\n" + "dup v2.16b, wzr\n" + "dup v3.16b, wzr\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \ + "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \ + "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \ + "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) +#undef RUY_LOAD_ONE_ROW + "5:\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + "mov x1, #32\n" + +#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ + "cmp w2, #" #ROW "\n" \ + "beq 4f\n" \ + "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n" + + RUY_STORE_ONE_ROW(0, v20) + RUY_STORE_ONE_ROW(1, v21) + RUY_STORE_ONE_ROW(2, v22) + RUY_STORE_ONE_ROW(3, v23) + +#undef RUY_STORE_ONE_ROW + + "4:\n" + + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), + [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), + [ packed_ptr ] "+r"(packed_ptr) + : [ src_inc0 ] "r"(static_cast(src_inc0)), + [ src_inc1 ] "r"(static_cast(src_inc1)), + [ src_inc2 ] "r"(static_cast(src_inc2)), + [ src_inc3 ] "r"(static_cast(src_inc3)), + [ rows ] "r"(src_rows) + : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", 
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} +#endif + +#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) +void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc, int src_rows, int src_zero_point, + float* packed_ptr, int start_col, int end_col, + int output_stride) { + profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); + asm volatile( + // clang-format off + "mov r1, #0\n" + "and r2, %[rows], #-4\n" + "cmp r1, r2\n" + "beq 3f\n" +#define RUY_LOAD_FOUR_BY_FOUR() \ + /* Load q0 */ \ + "vld1.32 {d0, d1}, [%[src_ptr0]]\n" \ + /* if src_inc0 != 0, add 16 to src_ptr0 */ \ + "and r3, %[src_inc], #1\n" \ + "add %[src_ptr0], %[src_ptr0], r3, lsl #4\n"\ + /* Load q1 */ \ + "vld1.32 {d2, d3}, [%[src_ptr1]]\n" \ + /* if src_inc1 != 0, add 16 to src_ptr0 */ \ + "and r3, %[src_inc], #2\n" \ + "add %[src_ptr1], %[src_ptr1], r3, lsl #3\n"\ + /* Load q2 */ \ + "vld1.32 {d4, d5}, [%[src_ptr2]]\n" \ + /* if src_inc2 != 0, add 16 to src_ptr0 */ \ + "and r3, %[src_inc], #4\n" \ + "add %[src_ptr2], %[src_ptr2], r3, lsl #2\n"\ + /* Load q3 */ \ + "vld1.32 {d6, d7}, [%[src_ptr3]]\n" \ + /* if src_inc3 != 0, add 16 to src_ptr0 */ \ + "and r3, %[src_inc], #8\n" \ + "add %[src_ptr3], %[src_ptr3], r3, lsl #1\n"\ + + RUY_LOAD_FOUR_BY_FOUR() + "add r1, r1, #4\n" + "cmp r1, r2\n" + + "beq 2f\n" + + "1:\n" + "add r1, r1, #4\n" + + // Transpose 4x4 matrix. + "vzip.32 q0, q1\n" + "vzip.32 q2, q3\n" + + "vtrn.32 q0, q2\n" + "vtrn.32 q1, q3\n" + + "vzip.32 q0, q2\n" + "vzip.32 q1, q3\n" + + "vmov q8, q0\n" + "vmov q9, q1\n" + "vmov q10, q2\n" + "vmov q11, q3\n" + + RUY_LOAD_FOUR_BY_FOUR() +#undef RUY_LOAD_FOUR_BY_FOUR + +#define RUY_STORE_FOUR_BY_FOUR() \ + /* Store q8, q10, q9, q11 */ \ + /* q8 = d16, d17 */ \ + "vst1.32 {d16, d17}, [%[packed_ptr]]\n" \ + /* q10 = d20, d21 */ \ + "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ + "vst1.32 {d20, d21}, [%[packed_ptr]]\n" \ + /* q9 = d18, d19 */ \ + "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ + "vst1.32 {d18, d19}, [%[packed_ptr]]\n" \ + /* q11 = d22, d23 */ \ + "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ + "vst1.32 {d22, d23}, [%[packed_ptr]]\n" \ + "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ + + RUY_STORE_FOUR_BY_FOUR() + "cmp r1, r2\n" + + "bne 1b\n" + + "2:\n" + + // Transpose 4x4 matrix. 
+ "vzip.32 q0, q1\n" + "vzip.32 q2, q3\n" + + "vtrn.32 q0, q2\n" + "vtrn.32 q1, q3\n" + + "vzip.32 q0, q2\n" + "vzip.32 q1, q3\n" + + "vmov q8, q0\n" + "vmov q9, q1\n" + "vmov q10, q2\n" + "vmov q11, q3\n" + + RUY_STORE_FOUR_BY_FOUR() +#undef RUY_STORE_FOUR_BY_FOUR + "3:\n" + + "ands r2, %[rows], #3\n" + "beq 4f\n" + "mov r0, #0\n" + // Zero out q0 - q3 + "vdup.32 q0, r0\n" + "vdup.32 q1, r0\n" + "vdup.32 q2, r0\n" + "vdup.32 q3, r0\n" +#define RUY_LOAD_ONE_ROW_FIRST_HALF(R, I) \ + "cmp r2, #" #R "\n" \ + "beq 5f\n" \ + "vld1.32 { d0[" #I "] }, [%[src_ptr0]]!\n" \ + "vld1.32 { d2[" #I "] }, [%[src_ptr1]]!\n" \ + "vld1.32 { d4[" #I "] }, [%[src_ptr2]]!\n" \ + "vld1.32 { d6[" #I "] }, [%[src_ptr3]]!\n" + +#define RUY_LOAD_ONE_ROW_SECOND_HALF(R, I) \ + "vld1.32 { d1[" #I "] }, [%[src_ptr0]]!\n" \ + "vld1.32 { d3[" #I "] }, [%[src_ptr1]]!\n" \ + "vld1.32 { d5[" #I "] }, [%[src_ptr2]]!\n" \ + "vld1.32 { d7[" #I "] }, [%[src_ptr3]]!\n" + + RUY_LOAD_ONE_ROW_FIRST_HALF(0, 0) + RUY_LOAD_ONE_ROW_FIRST_HALF(1, 1) + RUY_LOAD_ONE_ROW_SECOND_HALF(2, 0) + RUY_LOAD_ONE_ROW_SECOND_HALF(3, 1) +#undef RUY_LOAD_ONE_ROW_SECOND_HALF +#undef RUY_LOAD_ONE_ROW_FIRST_HALF + "5:\n" + + // Transpose 4x4 matrix. + "vzip.32 q0, q1\n" + "vzip.32 q2, q3\n" + + "vtrn.32 q0, q2\n" + "vtrn.32 q1, q3\n" + + "vzip.32 q0, q2\n" + "vzip.32 q1, q3\n" + + "vmov q8, q0\n" + "vmov q9, q1\n" + "vmov q10, q2\n" + "vmov q11, q3\n" + + "mov r1, #32\n" + +#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ + "cmp r2, #" #ROW "\n" \ + "beq 4f\n" \ + "vst1.32 {" #REGISTER "}, [%[packed_ptr]]\n" \ + "add %[packed_ptr], %[packed_ptr], %[stride]\n" + + // Store q8 + RUY_STORE_ONE_ROW(0, q8) + // Store q10 + RUY_STORE_ONE_ROW(1, q10) + // Store q9 + RUY_STORE_ONE_ROW(2, q9) + // Store q11 + RUY_STORE_ONE_ROW(3, q11) + +#undef RUY_STORE_ONE_ROW + + "4:\n" + + // clang-format on + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), + [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), + [ packed_ptr ] "+r"(packed_ptr) + : [ src_inc ] "r"(static_cast(src_inc)), + [ rows ] "r"(src_rows), [ stride ] "r"(output_stride) + : "cc", "memory", "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3", + "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"); +} + +#endif // (RUY_PLATFORM(NEON_32) + +#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) +void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + float* packed_ptr, int start_col, int end_col) { + profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)"); + + asm volatile( + // clang-format off + "mov w1, #0\n" + + "and w2, %w[rows], #-4\n" + "cmp w1, w2\n" + "beq 3f\n" + "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, 
[%[src_ptr1], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") + "add w1, w1, #4\n" + "cmp w1, w2\n" + + "beq 2f\n" + + "1:\n" + "add w1, w1, #4\n" + + "ldr x10, [%[src_ptr0], #8]\n" + "trn1 v16.4s, v0.4s, v1.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") + "ldr x11, [%[src_ptr1], #8]\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") + "ldr x12, [%[src_ptr2], #8]\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") + "ldr x13, [%[src_ptr3], #8]\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") + + "ld1 {v0.2s}, [%[src_ptr0]], %[src_inc0]\n" + "trn1 v20.2d, v16.2d, v18.2d\n" + "ld1 {v1.2s}, [%[src_ptr1]], %[src_inc1]\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "ld1 {v2.2s}, [%[src_ptr2]], %[src_inc2]\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "ld1 {v3.2s}, [%[src_ptr3]], %[src_inc3]\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + "cmp w1, w2\n" + + "ins v0.d[1], x10\n" + "str q20, [%[packed_ptr], #0]\n" + "ins v1.d[1], x11\n" + "str q21, [%[packed_ptr], #32]\n" + "ins v2.d[1], x12\n" + "str q22, [%[packed_ptr], #64]\n" + "ins v3.d[1], x13\n" + "str q23, [%[packed_ptr], #96]\n" + + "add %[packed_ptr], %[packed_ptr], #128\n" + + "bne 1b\n" + + "2:\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "3:\n" + + "ands w2, %w[rows], #3\n" + "beq 4f\n" + "dup v0.16b, wzr\n" + "dup v1.16b, wzr\n" + "dup v2.16b, wzr\n" + "dup v3.16b, wzr\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \ + "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \ + "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \ + "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) +#undef RUY_LOAD_ONE_ROW + "5:\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + "mov x1, #32\n" + +#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ + "cmp w2, #" #ROW "\n" \ + "beq 4f\n" \ + "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n" + + RUY_STORE_ONE_ROW(0, v20) + RUY_STORE_ONE_ROW(1, v21) + RUY_STORE_ONE_ROW(2, v22) + RUY_STORE_ONE_ROW(3, v23) + +#undef RUY_STORE_ONE_ROW + + "4:\n" + + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2), + [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr) + : [ src_inc0 ] "r"(static_cast(src_inc0)), [src_inc1] "r"(static_cast(src_inc1)), [src_inc2] "r"(static_cast(src_inc2)), + [src_inc3] "r"(static_cast(src_inc3)), [rows] "r"(src_rows) + : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", 
"v28", "v29", "v30", "v31"); +} +#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy diff --git a/ruy/pack_arm.h b/ruy/pack_arm.h new file mode 100644 index 0000000..8e7f619 --- /dev/null +++ b/ruy/pack_arm.h @@ -0,0 +1,497 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// # What is "packing"? +// +// Before feeding data to the gemm kernels (the parts of Ruy that do lots +// of multiply-add operations), Ruy first performs a data transformation (which +// we call "packing") on the input matrices. This transformation has two main +// goals: +// - rearrange data into blocks that are a convenient size/layout for the gemm +// kernels to consume. This helps make the memory access pattern of the gemm +// kernel simpler and more contiguous, and puts the data in a layout most +// convenient for specific arithmetic instructions in the gemm kernel. +// - compute row/column sums needed for handling quantization with non-symmetric +// zero points. +// +// # Simplified algorithmic analysis of packing +// +// Packing is a relatively simple transformation which does a small constant +// amount of work on each element of an input matrix, and hence for an NxM +// matrix performs O(N*M) work. If N and M are of the same order, then this is +// O(N^2) work. +// +// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. +// Note that if N, K, and M are all the same order, then the number of +// multiply-accumulate operations is O(N^3). +// +// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the +// case of all dimensions being roughly the same order. +// +// # Packing cost can be significant +// +// When matrix * matrix multiplications begin to look more like matrix * vector +// multiplications, packing cost can become significant. We sometimes call these +// cases "gemv-like". +// +// Continuing the algorithmic analysis above, if we consider a case where an +// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the +// situation is different. In this case, the multiply-accumulate work is only +// quadratic, so the quadratic cost of packing can be come significant. +// +// Another way to say this is that the cost of packing an input matrix (either +// the LHS or RHS) is amortized across the non-depth dimension of the opposite +// input matrix. Thus, when the LHS has very few rows or the RHS has very few +// columns, the cost of packing the opposite input matrix can become +// significant. +// +// As a rough rule of thumb, the cost of packing starts to become significant +// when either N or M is below 32 (and other dimensions are hundreds), with very +// significant packing costs at 8 or below. This varies by data type, Path, and +// tuning, so these numbers are only rough guides. 
+// +// One practical use case that is affected by this is inference of +// fully connected neural network layers with a low batch size. The weight +// matrix (which is a constant for inference) is the one affected by significant +// packing cost. +// +// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack +// input matrices that are affected by significant packing costs. +// +// # Implementation notes +// +// Ruy's packing routines always operate on a range of columns and can be +// applied to either the LHS or RHS. This is possible because Ruy internally +// implements a TrMul, so the accumulation along depth is done along columns of +// both the LHS and RHS (whereas for a normal Mul the accumulation along depth +// for the LHS is along rows). As another example, we are always computing +// column sums for quantization (and never row sums, since the LHS is +// transposed). + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/common.h" +#include "ruy/internal_matrix.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/pack_common.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/tune.h" + +namespace ruy { + +#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) +void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + std::int8_t* packed_ptr, int start_col, int end_col, + std::int32_t* sums_ptr, int input_xor); +void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, int src_inc3, + int src_rows, int src_zero_point, + std::int8_t* packed_ptr, int start_col, int end_col, + std::int32_t* sums_ptr, int input_xor); +void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, + int src_zero_point, std::int8_t* packed_ptr, + int start_col, int end_col, + std::int32_t* sums_ptr, int input_xor); +void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + std::int8_t* packed_ptr, int start_col, + int end_col, std::int32_t* sums_ptr, + int input_xor); + +#elif RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) +void Pack8bitNeonOutOfOrder4Cols(const PackParams8bit& params); +void Pack8bitNeonOutOfOrder2Cols(const PackParams8bit& params); +#endif // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && \ + RUY_OPT_ENABLED(RUY_OPT_ASM) + +template +struct PackImpl, Scalar, + std::int8_t, std::int32_t> { + static_assert(std::is_same::value || + std::is_same::value, + ""); + static constexpr int kInputXor = + std::is_same::value ? 
0 : 0x80; + + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ(start_col % 4, 0); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[16]; + memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); + for (int block_col = start_col; block_col < end_col; block_col += 4) { + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; + const Scalar* src_ptr1 = src_ptr0 + src_stride; + const Scalar* src_ptr2 = src_ptr1 + src_stride; + const Scalar* src_ptr3 = src_ptr2 + src_stride; + int src_inc0 = 16; + int src_inc1 = 16; + int src_inc2 = 16; + int src_inc3 = 16; + if (block_col >= src_matrix.layout.cols - 3) { + if (block_col >= src_matrix.layout.cols - 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (block_col >= src_matrix.layout.cols - 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (block_col >= src_matrix.layout.cols - 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (block_col >= src_matrix.layout.cols - 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + } + std::int8_t* packed_ptr = + packed_matrix->data + packed_matrix->layout.stride * block_col; + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; +#if RUY_PLATFORM(NEON_64) + if (__builtin_expect(tuning == Tuning::kInOrder, true)) { + Pack8bitNeonInOrder( + src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, + src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col, sums_ptr, kInputXor); + } else { + Pack8bitNeonOutOfOrder( + src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, + src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col, sums_ptr, kInputXor); + } +#else + // We have a more limited set of general purpose registers in ARMv7, so + // we use the "params" struct technique from the kernel code to save + // registers. + PackParams8bit params; + MakePackParams8bit(src_ptr0, src_ptr1, src_ptr2, src_ptr3, sums_ptr, + packed_ptr, src_inc0, src_inc1, src_inc2, src_inc3, + src_matrix.layout.rows, src_matrix.zero_point, + kInputXor, ¶ms); + Pack8bitNeonOutOfOrder4Cols(params); +#endif // RUY_PLATFORM(NEON_64) + } + } +}; + +#endif // (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && + // RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) +// The 32-bit float kernel is 4 rows X 2 columns, so we need an additional +// partial specialization for the RHS, which has a FixedKernelLayout with 2 +// columns. +template +struct PackImpl, Scalar, + std::int8_t, std::int32_t> { + static_assert(std::is_same::value || + std::is_same::value, + ""); + static constexpr int kInputXor = + std::is_same::value ? 
0 : 0x80; + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ(start_col % 2, 0); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[16]; + memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); + for (int block_col = start_col; block_col < end_col; block_col += 2) { + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; + const Scalar* src_ptr1 = src_ptr0 + src_stride; + int src_inc0 = 16; + int src_inc1 = 16; + if (block_col >= src_matrix.layout.cols - 2) { + if (block_col >= src_matrix.layout.cols - 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (block_col >= src_matrix.layout.cols - 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + } + std::int8_t* packed_ptr = + packed_matrix->data + packed_matrix->layout.stride * block_col; + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; + PackParams8bit params; + MakePackParams8bit(src_ptr0, src_ptr1, nullptr, nullptr, sums_ptr, + packed_ptr, src_inc0, src_inc1, -1, -1, + src_matrix.layout.rows, src_matrix.zero_point, + kInputXor, ¶ms); + Pack8bitNeonOutOfOrder2Cols(params); + } + } +}; +#endif // (RUY_PLATFORM(NEON_32)) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) +template +struct PackImpl, + Scalar, std::int8_t, std::int32_t> { + static_assert(std::is_same::value || + std::is_same::value, + ""); + static constexpr int kInputXor = + std::is_same::value ? 0 : 0x80; + + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ(start_col % 8, 0); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[16]; + memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); + for (int block_col = start_col; block_col < end_col; block_col += 4) { + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; + const Scalar* src_ptr1 = src_ptr0 + src_stride; + const Scalar* src_ptr2 = src_ptr1 + src_stride; + const Scalar* src_ptr3 = src_ptr2 + src_stride; + std::int64_t src_inc0 = 16; + std::int64_t src_inc1 = 16; + std::int64_t src_inc2 = 16; + std::int64_t src_inc3 = 16; + if (block_col >= src_matrix.layout.cols - 3) { + if (block_col >= src_matrix.layout.cols - 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (block_col >= src_matrix.layout.cols - 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (block_col >= src_matrix.layout.cols - 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (block_col >= src_matrix.layout.cols - 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + } + std::int8_t* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & ~7) + + ((block_col & 4) * 4); + std::int32_t* sums_ptr = sums ? 
sums + block_col : nullptr; + if (__builtin_expect(tuning == Tuning::kInOrder, true)) { + Pack8bitNeonDotprodInOrder( + src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, + src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col, sums_ptr, kInputXor); + } else { + Pack8bitNeonDotprodOutOfOrder( + src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, + src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col, sums_ptr, kInputXor); + } + } + } +}; +#endif // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) +void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + float* packed_ptr, int start_col, int end_col); +void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + float* packed_ptr, int start_col, int end_col); + +#elif RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) +void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc, int src_rows, int src_zero_point, + float* packed_ptr, int start_col, int end_col, + int stride); +#endif // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && \ + RUY_OPT_ENABLED(RUY_OPT_ASM) + +template <> +struct PackImpl, float, + float, float> { + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ(start_col % 8, 0); + const float zerobuf[4] = {0}; + for (int block_col = start_col; block_col < end_col; block_col += 4) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col; + const float* src_ptr1 = src_ptr0 + src_stride; + const float* src_ptr2 = src_ptr1 + src_stride; + const float* src_ptr3 = src_ptr2 + src_stride; + std::int64_t src_inc0 = 16; + std::int64_t src_inc1 = 16; + std::int64_t src_inc2 = 16; + std::int64_t src_inc3 = 16; + if (block_col >= src_matrix.layout.cols - 3) { + if (block_col >= src_matrix.layout.cols - 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (block_col >= src_matrix.layout.cols - 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (block_col >= src_matrix.layout.cols - 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (block_col >= src_matrix.layout.cols - 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + } + float* packed_ptr = packed_matrix->data + + packed_matrix->layout.stride * (block_col & ~7) + + ((block_col & 4)); +#if RUY_PLATFORM(NEON_64) + if (__builtin_expect(tuning == Tuning::kInOrder, true)) { + PackFloatNeonInOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, + src_inc1, src_inc2, src_inc3, + src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col); + } else { + PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, + src_inc0, src_inc1, src_inc2, src_inc3, + src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col); + } +#else + // Encode each of src_inc0, ..., src_inc3 in lowest 4 bits 
of src_inc + // to save on registers (we have fewer general purpose registers in + // 32-bit ARM than in 64-bit ARM). For the 64-bit case, we pass four + // values that are each either 16 or 0 and use them directly. For the + // 32-bit case, bits 0, 1, 2, and 3 are used to determine if we should + // use the value 16 (bit is set) or 0 (bit is not set) for the + // respective increment value. + std::int64_t src_inc = 0; + src_inc += src_inc0 == 16 ? 1 : 0; + src_inc += src_inc1 == 16 ? 2 : 0; + src_inc += src_inc2 == 16 ? 4 : 0; + src_inc += src_inc3 == 16 ? 8 : 0; + const int kOutputStride = 32; + PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc, + src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col, kOutputStride); +#endif // RUY_PLATFORM(NEON_64) + } + } +}; + +#if RUY_PLATFORM(NEON_32) +// The 32-bit float kernel is 8 rows X 4 columns, so we need an additional +// specialization for a FixedKernelLayout with 4 columns. +template <> +struct PackImpl, float, + float, float> { + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ(start_col % 4, 0); + const float zerobuf[4] = {0}; + for (int block_col = start_col; block_col < end_col; block_col += 4) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col; + const float* src_ptr1 = src_ptr0 + src_stride; + const float* src_ptr2 = src_ptr1 + src_stride; + const float* src_ptr3 = src_ptr2 + src_stride; + std::int64_t src_inc0 = 16; + std::int64_t src_inc1 = 16; + std::int64_t src_inc2 = 16; + std::int64_t src_inc3 = 16; + if (block_col >= src_matrix.layout.cols - 3) { + if (block_col >= src_matrix.layout.cols - 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (block_col >= src_matrix.layout.cols - 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (block_col >= src_matrix.layout.cols - 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (block_col >= src_matrix.layout.cols - 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + } + float* packed_ptr = + packed_matrix->data + packed_matrix->layout.stride * (block_col); + // Encode each of src_inc0, ..., src_inc1 in lowest 4 bits of scrc_inc + // to save registers. + std::int64_t src_inc = 0; + src_inc += src_inc0 == 16 ? 1 : 0; + src_inc += src_inc1 == 16 ? 2 : 0; + src_inc += src_inc2 == 16 ? 4 : 0; + src_inc += src_inc3 == 16 ? 8 : 0; + const int kOutputStride = 16; + PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc, + src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col, kOutputStride); + } + } +}; +#endif // (RUY_PLATFORM(NEON_32)) +#endif // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \ + // RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_ diff --git a/ruy/pack_avx2.cc b/ruy/pack_avx2.cc new file mode 100644 index 0000000..013a8c0 --- /dev/null +++ b/ruy/pack_avx2.cc @@ -0,0 +1,816 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/pack.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) +#include // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, + std::int32_t* sums_ptr) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, float* packed_ptr) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +// The first int8_t template parameter is arbitrary: this routine is common to +// all 8-bit source matrix types. +using PackImpl8bitAvx2 = + PackImpl, + std::int8_t, std::int8_t, std::int32_t>; + +using PackImplFloatAvx2 = + PackImpl, float, + float, float>; + +namespace { + +inline __m256i MaskLoadu(int available_src_rows, std::int8_t zero_point, + const std::int8_t* addr) { + RUY_DCHECK_LT(available_src_rows, 32); + __m256i padded_data; + + if (available_src_rows >= 16) { + __m128i load_hi = _mm_set1_epi8(zero_point); + __m128i load_lo = _mm_loadu_si128(reinterpret_cast(addr)); + memcpy(&load_hi, addr + 16, available_src_rows - 16); + padded_data = _mm256_set_m128i(load_hi, load_lo); + } else { + __m128i load_hi = _mm_set1_epi8(zero_point); + __m128i load_lo = load_hi; + memcpy(&load_lo, addr, available_src_rows); + padded_data = _mm256_set_m128i(load_hi, load_lo); + } + return padded_data; +} + +inline void Pack8bitAvx2Packer(const std::int8_t* src_ptr, + std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr, + std::int8_t* trailing_buf) { + using Layout = PackImpl8bitAvx2::Layout; + RUY_DCHECK_EQ(Layout::kCols, 8); + RUY_DCHECK_EQ(Layout::kRows, 4); + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. 
+ constexpr int kNumRowChunks = 8; + constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; + + const std::int8_t* src_ptr0 = src_ptr; + const std::int8_t* src_ptr1 = src_ptr0 + src_stride; + const std::int8_t* src_ptr2 = src_ptr1 + src_stride; + const std::int8_t* src_ptr3 = src_ptr2 + src_stride; + const std::int8_t* src_ptr4 = src_ptr3 + src_stride; + const std::int8_t* src_ptr5 = src_ptr4 + src_stride; + const std::int8_t* src_ptr6 = src_ptr5 + src_stride; + const std::int8_t* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = kNumChunkedSrcRows; + std::int64_t src_inc1 = kNumChunkedSrcRows; + std::int64_t src_inc2 = kNumChunkedSrcRows; + std::int64_t src_inc3 = kNumChunkedSrcRows; + std::int64_t src_inc4 = kNumChunkedSrcRows; + std::int64_t src_inc5 = kNumChunkedSrcRows; + std::int64_t src_inc6 = kNumChunkedSrcRows; + std::int64_t src_inc7 = kNumChunkedSrcRows; + // Handle cases where source does not have Layout::kCols (8) columns. + if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + const std::int8_t zero_point = zerobuf[0]; + + if (sums_ptr) { + // i: Layout::kCols. + for (int i = 0; i < 8; ++i) { + sums_ptr[i] = 0; + } + } + std::int32_t sums_adjustment = 0; + const __m256i ones_16bit = _mm256_set1_epi16(1); + __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0); + __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0); + + // The overall packing effectively pads the source rows to + // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we + // only pack for (src_rows + 31) & ~31. When there is an incomplete + // destination block, this is stored into trailing_buf instead of packed_ptr. + for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) { + // Available source rows. + // If this is less than 0 (for m=1), we skip, having filled trailing + // buffer for m=0. Also, if source rows is zero on m=1, then we filled + // exactly to the end of the column in the packed buffer. + const int available_src_rows = src_rows - k; + // Effectively, + // available rows = std::max(0, std::min(8, src_rows - k)); + // treat each case separately. 
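+    // Illustration of the chunking above (hypothetical src_rows, not from
+    // this code): with kNumChunkedSrcRows = 8 * 4 = 32, a column of
+    // src_rows = 50 is processed as one full 32-row chunk written to
+    // packed_ptr (k = 0, available_src_rows = 50), followed by one partial
+    // chunk (k = 32, available_src_rows = 18) that is padded with zero_point
+    // via MaskLoadu and written to trailing_buf.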
+ if (available_src_rows >= kNumChunkedSrcRows) { + if (sums_ptr) { + __m256i t0, t1, t2, t3, t4, t5, t6, t7; + __m256i r0, r1, r2, r3, r4, r5, r6, r7; + const __m256i input_xor_v = _mm256_set1_epi8(input_xor); + + t0 = _mm256_loadu_si256(reinterpret_cast(src_ptr0)); + t4 = _mm256_loadu_si256(reinterpret_cast(src_ptr4)); + t1 = _mm256_loadu_si256(reinterpret_cast(src_ptr1)); + t5 = _mm256_loadu_si256(reinterpret_cast(src_ptr5)); + t2 = _mm256_loadu_si256(reinterpret_cast(src_ptr2)); + t6 = _mm256_loadu_si256(reinterpret_cast(src_ptr6)); + t3 = _mm256_loadu_si256(reinterpret_cast(src_ptr3)); + t7 = _mm256_loadu_si256(reinterpret_cast(src_ptr7)); + + r0 = _mm256_unpacklo_epi32(t0, t1); + r4 = _mm256_unpacklo_epi32(t4, t5); + r2 = _mm256_unpackhi_epi32(t0, t1); + r6 = _mm256_unpackhi_epi32(t4, t5); + r1 = _mm256_unpacklo_epi32(t2, t3); + r5 = _mm256_unpacklo_epi32(t6, t7); + r3 = _mm256_unpackhi_epi32(t2, t3); + r7 = _mm256_unpackhi_epi32(t6, t7); + + t0 = _mm256_unpacklo_epi64(r0, r1); + t4 = _mm256_unpacklo_epi64(r4, r5); + t2 = _mm256_unpackhi_epi64(r0, r1); + t6 = _mm256_unpackhi_epi64(r4, r5); + t1 = _mm256_unpacklo_epi64(r2, r3); + t5 = _mm256_unpacklo_epi64(r6, r7); + t3 = _mm256_unpackhi_epi64(r2, r3); + t7 = _mm256_unpackhi_epi64(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by + // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, + // t4) are interleaved to create (r0, r1). This complexity follows from + // the way that AVX is centered around MM 128-bit lanes. + r0 = _mm256_permute2x128_si256(t0, t4, 0x20); + r4 = _mm256_permute2x128_si256(t1, t5, 0x20); + r1 = _mm256_permute2x128_si256(t0, t4, 0x31); + r5 = _mm256_permute2x128_si256(t1, t5, 0x31); + r2 = _mm256_permute2x128_si256(t2, t6, 0x20); + r6 = _mm256_permute2x128_si256(t3, t7, 0x20); + r3 = _mm256_permute2x128_si256(t2, t6, 0x31); + r7 = _mm256_permute2x128_si256(t3, t7, 0x31); + + r0 = _mm256_xor_si256(r0, input_xor_v); + r1 = _mm256_xor_si256(r1, input_xor_v); + r2 = _mm256_xor_si256(r2, input_xor_v); + r3 = _mm256_xor_si256(r3, input_xor_v); + r4 = _mm256_xor_si256(r4, input_xor_v); + r5 = _mm256_xor_si256(r5, input_xor_v); + r6 = _mm256_xor_si256(r6, input_xor_v); + r7 = _mm256_xor_si256(r7, input_xor_v); + + __m256i sums_4x4_16bit_lo; + sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); + + // The sums have been performed across columns, and now we have 4x16-bit + // sums packed together. We use madd for pairwise 32-bit sums. 
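+        // Note on the madd step below: because ones_16bit holds 1 in every
+        // 16-bit lane, _mm256_madd_epi16 multiplies each lane by 1 and adds
+        // adjacent pairs, so 16-bit lanes [a0, a1, a2, a3, ...] become 32-bit
+        // lanes [a0+a1, a2+a3, ...], i.e. eight pairwise sums per vector.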
+ const __m256i sums_4x2_32bit_lo_new = + _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); + sums_4x2_32bit_lo = + _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); + + __m256i sums_4x4_16bit_hi; + sums_4x4_16bit_hi = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1)); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1))); + + const __m256i sums_4x2_32bit_hi_new = + _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); + sums_4x2_32bit_hi = + _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), + r0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), + r4); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), + r1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), + r5); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), + r2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), + r6); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), + r3); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), + r7); + } else { + __m256i t0, t1, t2, t3, t4, t5, t6, t7; + __m256i r0, r1, r2, r3, r4, r5, r6, r7; + const __m256i input_xor_v = _mm256_set1_epi8(input_xor); + + t0 = _mm256_loadu_si256(reinterpret_cast(src_ptr0)); + t4 = _mm256_loadu_si256(reinterpret_cast(src_ptr4)); + t1 = _mm256_loadu_si256(reinterpret_cast(src_ptr1)); + t5 = _mm256_loadu_si256(reinterpret_cast(src_ptr5)); + t2 = _mm256_loadu_si256(reinterpret_cast(src_ptr2)); + t6 = _mm256_loadu_si256(reinterpret_cast(src_ptr6)); + t3 = _mm256_loadu_si256(reinterpret_cast(src_ptr3)); + t7 = _mm256_loadu_si256(reinterpret_cast(src_ptr7)); + + r0 = _mm256_unpacklo_epi32(t0, t1); + r4 = _mm256_unpacklo_epi32(t4, t5); + r2 = _mm256_unpackhi_epi32(t0, t1); + r6 = _mm256_unpackhi_epi32(t4, t5); + r1 = _mm256_unpacklo_epi32(t2, t3); + r5 = _mm256_unpacklo_epi32(t6, t7); + r3 = _mm256_unpackhi_epi32(t2, t3); + r7 = _mm256_unpackhi_epi32(t6, t7); + + t0 = _mm256_unpacklo_epi64(r0, r1); + t4 = _mm256_unpacklo_epi64(r4, r5); + t2 = _mm256_unpackhi_epi64(r0, r1); + t6 = _mm256_unpackhi_epi64(r4, r5); + t1 = _mm256_unpacklo_epi64(r2, r3); + t5 = _mm256_unpacklo_epi64(r6, r7); + t3 = _mm256_unpackhi_epi64(r2, r3); + t7 = _mm256_unpackhi_epi64(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by + // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, + // t4) are interleaved to create (r0, r1). This complexity follows from + // the way that AVX is centered around MM 128-bit lanes. 
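+        // Concretely, for 256-bit inputs a and b,
+        // _mm256_permute2x128_si256(a, b, 0x20) yields [a_low128 | b_low128]
+        // and imm 0x31 yields [a_high128 | b_high128], which is how (t0, t4)
+        // and the other pairs are combined into (r0, r1) etc. across lanes.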
+ r0 = _mm256_permute2x128_si256(t0, t4, 0x20); + r4 = _mm256_permute2x128_si256(t1, t5, 0x20); + r1 = _mm256_permute2x128_si256(t0, t4, 0x31); + r5 = _mm256_permute2x128_si256(t1, t5, 0x31); + r2 = _mm256_permute2x128_si256(t2, t6, 0x20); + r6 = _mm256_permute2x128_si256(t3, t7, 0x20); + r3 = _mm256_permute2x128_si256(t2, t6, 0x31); + r7 = _mm256_permute2x128_si256(t3, t7, 0x31); + + r0 = _mm256_xor_si256(r0, input_xor_v); + r1 = _mm256_xor_si256(r1, input_xor_v); + r2 = _mm256_xor_si256(r2, input_xor_v); + r3 = _mm256_xor_si256(r3, input_xor_v); + r4 = _mm256_xor_si256(r4, input_xor_v); + r5 = _mm256_xor_si256(r5, input_xor_v); + r6 = _mm256_xor_si256(r6, input_xor_v); + r7 = _mm256_xor_si256(r7, input_xor_v); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), + r0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), + r4); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), + r1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), + r5); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), + r2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), + r6); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), + r3); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), + r7); + } + } else if (available_src_rows > 0) { + RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows); + // We do not care what goes into the trailing buffer, but we want + // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. + // + // We compensate for padding-with-zero_point by initializing the + // summations with the compensating offset, effectively + // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * + // 4 * (8 - ((available_src_rows + 3) >> 2)). + // + // Note that (zero_point ^ input_xor) is performed in 8-bits and then + // cast. + sums_adjustment += + -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2)); + + __m256i t0, t1, t2, t3, t4, t5, t6, t7; + __m256i r0, r1, r2, r3, r4, r5, r6, r7; + const __m256i input_xor_v = _mm256_set1_epi8(input_xor); + + t0 = MaskLoadu(available_src_rows, zero_point, src_ptr0); + t4 = MaskLoadu(available_src_rows, zero_point, src_ptr4); + t1 = MaskLoadu(available_src_rows, zero_point, src_ptr1); + t5 = MaskLoadu(available_src_rows, zero_point, src_ptr5); + t2 = MaskLoadu(available_src_rows, zero_point, src_ptr2); + t6 = MaskLoadu(available_src_rows, zero_point, src_ptr6); + t3 = MaskLoadu(available_src_rows, zero_point, src_ptr3); + t7 = MaskLoadu(available_src_rows, zero_point, src_ptr7); + + r0 = _mm256_unpacklo_epi32(t0, t1); + r4 = _mm256_unpacklo_epi32(t4, t5); + r2 = _mm256_unpackhi_epi32(t0, t1); + r6 = _mm256_unpackhi_epi32(t4, t5); + r1 = _mm256_unpacklo_epi32(t2, t3); + r5 = _mm256_unpacklo_epi32(t6, t7); + r3 = _mm256_unpackhi_epi32(t2, t3); + r7 = _mm256_unpackhi_epi32(t6, t7); + + t0 = _mm256_unpacklo_epi64(r0, r1); + t4 = _mm256_unpacklo_epi64(r4, r5); + t2 = _mm256_unpackhi_epi64(r0, r1); + t6 = _mm256_unpackhi_epi64(r4, r5); + t1 = _mm256_unpacklo_epi64(r2, r3); + t5 = _mm256_unpacklo_epi64(r6, r7); + t3 = _mm256_unpackhi_epi64(r2, r3); + t7 = _mm256_unpackhi_epi64(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by + // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, + // t4) are interleaved to create (r0, r1). 
This complexity follows from + // the way that AVX is centered around MM 128-bit lanes. + r0 = _mm256_permute2x128_si256(t0, t4, 0x20); + r4 = _mm256_permute2x128_si256(t1, t5, 0x20); + r1 = _mm256_permute2x128_si256(t0, t4, 0x31); + r5 = _mm256_permute2x128_si256(t1, t5, 0x31); + r2 = _mm256_permute2x128_si256(t2, t6, 0x20); + r6 = _mm256_permute2x128_si256(t3, t7, 0x20); + r3 = _mm256_permute2x128_si256(t2, t6, 0x31); + r7 = _mm256_permute2x128_si256(t3, t7, 0x31); + + r0 = _mm256_xor_si256(r0, input_xor_v); + r1 = _mm256_xor_si256(r1, input_xor_v); + r2 = _mm256_xor_si256(r2, input_xor_v); + r3 = _mm256_xor_si256(r3, input_xor_v); + r4 = _mm256_xor_si256(r4, input_xor_v); + r5 = _mm256_xor_si256(r5, input_xor_v); + r6 = _mm256_xor_si256(r6, input_xor_v); + r7 = _mm256_xor_si256(r7, input_xor_v); + + __m256i sums_4x4_16bit_lo; + sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); + + // The sums have been performed across columns, and now we have 4x16-bit + // sums packed together. We use madd for pairwise 32-bit sums. 
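+      // For illustration: if the 16-bit lanes of sums_4x4_16bit_lo hold
+      // {s0, s1, s2, s3, ...}, the madd against ones_16bit below produces the
+      // 32-bit lanes {s0 + s1, s2 + s3, ...}.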
+ const __m256i sums_4x2_32bit_lo_new = + _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); + sums_4x2_32bit_lo = + _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); + + __m256i sums_4x4_16bit_hi; + sums_4x4_16bit_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1)); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1))); + + const __m256i sums_4x2_32bit_hi_new = + _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); + sums_4x2_32bit_hi = + _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4), + r0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4), + r4); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4), + r1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4), + r5); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4), + r2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4), + r6); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4), + r3); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4), + r7); + } + + packed_ptr += 8 * kNumChunkedSrcRows; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } + + if (sums_ptr) { + const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); + + __m256i sums = + _mm256_loadu_si256(reinterpret_cast(sums_ptr)); + const __m256i idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); + + // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the + // neighbours, finshing up by adding them to the stored accumulated sums. 
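+    // Concretely: sums_4x2_32bit_lo holds two 32-bit partial sums per source
+    // column, laid out as {col0a, col0b, col1a, col1b, ...} (and likewise
+    // sums_4x2_32bit_hi for columns 4-7). The permutevar below routes the
+    // even-numbered 32-bit lanes into the low 128-bit half and the odd ones
+    // into the high half; the two permute2x128 ops then line the partial sums
+    // up by column so the final adds produce one sum per column.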
+ const __m256i sums_2x4_32bit_lo = + _mm256_permutevar8x32_epi32(sums_4x2_32bit_lo, idx); + const __m256i sums_2x4_32bit_hi = + _mm256_permutevar8x32_epi32(sums_4x2_32bit_hi, idx); + const __m256i sums_2x4_32bit_a = + _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20); + const __m256i sums_2x4_32bit_b = + _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31); + sums = _mm256_add_epi32(sums, sums_adjustment_v); + sums = _mm256_add_epi32(sums, sums_2x4_32bit_a); + sums = _mm256_add_epi32(sums, sums_2x4_32bit_b); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums); + } +} + +inline __m256 Mm256UnpackloPsx2(const __m256 a, const __m256 b) { + return _mm256_castpd_ps( + _mm256_unpacklo_pd(_mm256_castps_pd(a), _mm256_castps_pd(b))); +} + +inline __m256 Mm256UnpackhiPsx2(const __m256 a, const __m256 b) { + return _mm256_castpd_ps( + _mm256_unpackhi_pd(_mm256_castps_pd(a), _mm256_castps_pd(b))); +} + +inline void PackFloatAvx2Packer(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, float* packed_ptr, + float* trailing_buf) { + RUY_DCHECK_EQ(PackImplFloatAvx2::Layout::kCols, 8); + RUY_DCHECK_EQ(PackImplFloatAvx2::Layout::kRows, 1); + + // This packing amounts to transposition of 8x8 blocks. + static constexpr int kPackCols = 8; // Source cols packed together. + static constexpr int kPackRows = 8; // Short input is padded. + + const float* src_ptr0 = src_ptr; + const float* src_ptr1 = src_ptr0 + src_stride; + const float* src_ptr2 = src_ptr1 + src_stride; + const float* src_ptr3 = src_ptr2 + src_stride; + const float* src_ptr4 = src_ptr3 + src_stride; + const float* src_ptr5 = src_ptr4 + src_stride; + const float* src_ptr6 = src_ptr5 + src_stride; + const float* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = 8; + std::int64_t src_inc1 = 8; + std::int64_t src_inc2 = 8; + std::int64_t src_inc3 = 8; + std::int64_t src_inc4 = 8; + std::int64_t src_inc5 = 8; + std::int64_t src_inc6 = 8; + std::int64_t src_inc7 = 8; + // Handle cases where source does not have kPackDim (8) columns. + if (remaining_src_cols < kPackCols) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + for (int k = 0; k < src_rows; k += kPackRows) { + const int available_src_rows = src_rows - k; + // Effectively, + // available_src_rows = std::max(0, std::min(kPackDim, src_rows - k)); + // but treat each case separately. 
+ if (available_src_rows >= kPackRows) { + __m256 t0, t1, t2, t3, t4, t5, t6, t7; + __m256 r0, r1, r2, r3, r4, r5, r6, r7; + + t0 = _mm256_loadu_ps(src_ptr0); + t4 = _mm256_loadu_ps(src_ptr4); + t1 = _mm256_loadu_ps(src_ptr1); + t5 = _mm256_loadu_ps(src_ptr5); + t2 = _mm256_loadu_ps(src_ptr2); + t6 = _mm256_loadu_ps(src_ptr6); + t3 = _mm256_loadu_ps(src_ptr3); + t7 = _mm256_loadu_ps(src_ptr7); + + r0 = _mm256_unpacklo_ps(t0, t1); + r4 = _mm256_unpacklo_ps(t4, t5); + r2 = _mm256_unpackhi_ps(t0, t1); + r6 = _mm256_unpackhi_ps(t4, t5); + r1 = _mm256_unpacklo_ps(t2, t3); + r5 = _mm256_unpacklo_ps(t6, t7); + r3 = _mm256_unpackhi_ps(t2, t3); + r7 = _mm256_unpackhi_ps(t6, t7); + + t0 = Mm256UnpackloPsx2(r0, r1); + t4 = Mm256UnpackloPsx2(r4, r5); + t2 = Mm256UnpackhiPsx2(r0, r1); + t6 = Mm256UnpackhiPsx2(r4, r5); + t1 = Mm256UnpackloPsx2(r2, r3); + t5 = Mm256UnpackloPsx2(r6, r7); + t3 = Mm256UnpackhiPsx2(r2, r3); + t7 = Mm256UnpackhiPsx2(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by 16 + // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4) + // are interleaved to create (r0, r1). This complexity follows from the + // way that AVX is centered around MM 128-bit lanes. + r0 = _mm256_permute2f128_ps(t0, t4, 0x20); + r4 = _mm256_permute2f128_ps(t1, t5, 0x20); + r1 = _mm256_permute2f128_ps(t0, t4, 0x31); + r5 = _mm256_permute2f128_ps(t1, t5, 0x31); + r2 = _mm256_permute2f128_ps(t2, t6, 0x20); + r6 = _mm256_permute2f128_ps(t3, t7, 0x20); + r3 = _mm256_permute2f128_ps(t2, t6, 0x31); + r7 = _mm256_permute2f128_ps(t3, t7, 0x31); + + _mm256_storeu_ps(packed_ptr + 0 * 8, r0); + _mm256_storeu_ps(packed_ptr + 2 * 8, r4); + _mm256_storeu_ps(packed_ptr + 4 * 8, r1); + _mm256_storeu_ps(packed_ptr + 6 * 8, r5); + _mm256_storeu_ps(packed_ptr + 1 * 8, r2); + _mm256_storeu_ps(packed_ptr + 3 * 8, r6); + _mm256_storeu_ps(packed_ptr + 5 * 8, r3); + _mm256_storeu_ps(packed_ptr + 7 * 8, r7); + } else if (available_src_rows > 0) { + const __m256i series = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + const __m256i row_mask_v = + _mm256_cmpgt_epi32(_mm256_set1_epi32(available_src_rows), series); + + __m256 t0, t1, t2, t3, t4, t5, t6, t7; + __m256 r0, r1, r2, r3, r4, r5, r6, r7; + + t0 = _mm256_maskload_ps(src_ptr0, row_mask_v); + t4 = _mm256_maskload_ps(src_ptr4, row_mask_v); + t1 = _mm256_maskload_ps(src_ptr1, row_mask_v); + t5 = _mm256_maskload_ps(src_ptr5, row_mask_v); + t2 = _mm256_maskload_ps(src_ptr2, row_mask_v); + t6 = _mm256_maskload_ps(src_ptr6, row_mask_v); + t3 = _mm256_maskload_ps(src_ptr3, row_mask_v); + t7 = _mm256_maskload_ps(src_ptr7, row_mask_v); + + r0 = _mm256_unpacklo_ps(t0, t1); + r4 = _mm256_unpacklo_ps(t4, t5); + r2 = _mm256_unpackhi_ps(t0, t1); + r6 = _mm256_unpackhi_ps(t4, t5); + r1 = _mm256_unpacklo_ps(t2, t3); + r5 = _mm256_unpacklo_ps(t6, t7); + r3 = _mm256_unpackhi_ps(t2, t3); + r7 = _mm256_unpackhi_ps(t6, t7); + + t0 = Mm256UnpackloPsx2(r0, r1); + t4 = Mm256UnpackloPsx2(r4, r5); + t2 = Mm256UnpackhiPsx2(r0, r1); + t6 = Mm256UnpackhiPsx2(r4, r5); + t1 = Mm256UnpackloPsx2(r2, r3); + t5 = Mm256UnpackloPsx2(r6, r7); + t3 = Mm256UnpackhiPsx2(r2, r3); + t7 = Mm256UnpackhiPsx2(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by 16 + // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4) + // are interleaved to create (r0, r1). 
This complexity follows from the + // way that AVX is centered around MM 128-bit lanes. + r0 = _mm256_permute2f128_ps(t0, t4, 0x20); + r4 = _mm256_permute2f128_ps(t1, t5, 0x20); + r1 = _mm256_permute2f128_ps(t0, t4, 0x31); + r5 = _mm256_permute2f128_ps(t1, t5, 0x31); + r2 = _mm256_permute2f128_ps(t2, t6, 0x20); + r6 = _mm256_permute2f128_ps(t3, t7, 0x20); + r3 = _mm256_permute2f128_ps(t2, t6, 0x31); + // r7 no longer needed. + + _mm256_storeu_ps(trailing_buf + 0 * 8, r0); + _mm256_storeu_ps(trailing_buf + 2 * 8, r4); + _mm256_storeu_ps(trailing_buf + 4 * 8, r1); + _mm256_storeu_ps(trailing_buf + 6 * 8, r5); + _mm256_storeu_ps(trailing_buf + 1 * 8, r2); + _mm256_storeu_ps(trailing_buf + 3 * 8, r6); + _mm256_storeu_ps(trailing_buf + 5 * 8, r3); + // No store to (trailing_buf + 7 * 8), space not allocated. + } + + packed_ptr += kPackRows * kPackCols; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } +} + +} // namespace. + +void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, + std::int32_t* sums_ptr) { + profiler::ScopeLabel label("Pack kAvx2 8bit"); + + using Layout = PackImpl8bitAvx2::Layout; + RUY_DCHECK_EQ(Layout::kCols, 8); + RUY_DCHECK_EQ(Layout::kRows, 4); + + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. + static constexpr int kNumRowChunks = 8; // Short input is padded. + + // Each packed block is 4*8, and there are normally 8. The trailing block is + // only slightly shorter. + constexpr int kTrailingBufSize = + kNumRowChunks * Layout::kCols * Layout::kRows; + std::int8_t trailing_buf[kTrailingBufSize]; + memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); + + Pack8bitAvx2Packer(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, sums_ptr, + trailing_buf); + + constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; + const bool trailing_data = (src_rows & kChunkedRowMask) > 0; + // If the number of source rows is not a multiple of kChunkedRowMask, there + // will be data in the trailing buffer, + if (trailing_data > 0) { + const int non_trailing_rows = src_rows & ~kChunkedRowMask; + // Destination "rows" are padded to next highest multiple of Layout::kRows. + const int dst_rows = (src_rows + 3) & ~3; + const int trailing_rows = dst_rows - non_trailing_rows; + memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, + Layout::kCols * trailing_rows * sizeof(std::int8_t)); + } +} + +void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, float* packed_ptr) { + profiler::ScopeLabel label("Pack kAvx2 float"); + static constexpr int kPackCols = 8; // Source cols packed together. + static constexpr int kPackRows = 8; // Short input is padded. 
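+  // A full block of kPackRows rows is written straight to packed_ptr, so the
+  // trailing buffer only ever has to hold a final partial block of at most
+  // kPackRows - 1 rows.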
+  float trailing_buf[(kPackRows - 1) * kPackCols];
+  if (remaining_src_cols < 8) {
+    memset(trailing_buf, 0, sizeof(trailing_buf));
+  }
+  PackFloatAvx2Packer(src_ptr, zerobuf, src_stride, remaining_src_cols,
+                      src_rows, packed_ptr, trailing_buf);
+
+  const int trailing_rows = src_rows & (kPackRows - 1);
+  if (trailing_rows > 0) {
+    const int non_trailing_rows = src_rows & ~(kPackRows - 1);
+    memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf,
+           kPackCols * trailing_rows * sizeof(float));
+  }
+}
+
+#endif  // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS)
+
+}  // namespace ruy
diff --git a/ruy/pack_avx512.cc b/ruy/pack_avx512.cc
new file mode 100644
index 0000000..ecad3a2
--- /dev/null
+++ b/ruy/pack_avx512.cc
@@ -0,0 +1,693 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <cstring>
+
+#include "ruy/check_macros.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS)
+#include <immintrin.h>  // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor,
+                    const std::int8_t* zerobuf, int src_stride,
+                    int remaining_src_cols, int src_rows,
+                    std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
+  // CPU-ID-based checks should disable the path that would reach this point.
+  RUY_DCHECK(false);
+}
+
+void PackFloatAvx512(const float* src_ptr, const float* zerobuf,
+                     int src_stride, int remaining_src_cols, int src_rows,
+                     float* packed_ptr) {
+  // CPU-ID-based checks should disable the path that would reach this point.
+  RUY_DCHECK(false);
+}
+
+#else  // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+// The first int8_t template parameter is arbitrary: this routine is common to
+// all 8-bit source matrix types.
+using PackImpl8bitAvx512 =
+    PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
+             std::int8_t, std::int8_t, std::int32_t>;
+
+namespace {
+
+inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point,
+                               std::int8_t* packed_ptr) {
+  using Layout = PackImpl8bitAvx512::Layout;
+  static constexpr int kHalfLayoutCols =
+      PackImpl8bitAvx512::kHalfLayoutCols;  // Half the number of cols in a
+                                            // block.
+  RUY_DCHECK_EQ(kHalfLayoutCols, 8);
+  RUY_DCHECK_EQ(Layout::kCols, 16);
+  RUY_DCHECK_EQ(Layout::kRows, 4);
+
+  const int non_trailing_blocks = (src_rows & ~31) >> 2;
+  // This routine fills half blocks, and typically fills the second halves.
+  // Thus packed_ptr is already offset by 8 * 4.
+ for (int k = 0; k < non_trailing_blocks; ++k) { + for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) { + packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point; + } + } +} + +inline __m512i LoaduTwo(const std::int8_t* addr_lo, + const std::int8_t* addr_hi) { + __m512i lower_filled = _mm512_castsi256_si512(_mm256_loadu_epi8(addr_lo)); + return _mm512_inserti32x8(lower_filled, _mm256_loadu_epi8(addr_hi), 1); +} + +inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v, + const std::int8_t* addr_lo, + const std::int8_t* addr_hi) { + const __m512i lower_filled = _mm512_castsi256_si512( + _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_lo)); + return _mm512_inserti32x8( + lower_filled, _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_hi), + 1); +} + +inline void HalfPack8bitAvx512(const std::int8_t* src_ptr, + std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr, + std::int8_t* trailing_buf) { + using Layout = PackImpl8bitAvx512::Layout; + RUY_DCHECK_EQ(Layout::kCols, 16); + RUY_DCHECK_EQ(Layout::kRows, 4); + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. + constexpr int kNumRowChunks = 8; + constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; + + const std::int8_t* src_ptr0 = src_ptr; + const std::int8_t* src_ptr1 = src_ptr0 + src_stride; + const std::int8_t* src_ptr2 = src_ptr1 + src_stride; + const std::int8_t* src_ptr3 = src_ptr2 + src_stride; + const std::int8_t* src_ptr4 = src_ptr3 + src_stride; + const std::int8_t* src_ptr5 = src_ptr4 + src_stride; + const std::int8_t* src_ptr6 = src_ptr5 + src_stride; + const std::int8_t* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = kNumChunkedSrcRows; + std::int64_t src_inc1 = kNumChunkedSrcRows; + std::int64_t src_inc2 = kNumChunkedSrcRows; + std::int64_t src_inc3 = kNumChunkedSrcRows; + std::int64_t src_inc4 = kNumChunkedSrcRows; + std::int64_t src_inc5 = kNumChunkedSrcRows; + std::int64_t src_inc6 = kNumChunkedSrcRows; + std::int64_t src_inc7 = kNumChunkedSrcRows; + // Handle cases where source does not have kHalfLayoutCols (8) columns. + if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + const std::int8_t zero_point = zerobuf[0]; + + if (sums_ptr) { + // i: kHalfLayoutCols. + for (int i = 0; i < 8; ++i) { + sums_ptr[i] = 0; + } + } + std::int32_t sums_adjustment = 0; + const __m512i ones_16bit = _mm512_set1_epi16(1); + __m512i sums_8x2_32bit = _mm512_set1_epi32(0); + + // The overall packing effectively pads the source rows to + // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we + // only pack for (src_rows + 31) & ~31. When there is an incomplete + // destination block, this is stored into trailing_buf instead of packed_ptr. 
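+  // For example, with src_rows == 70: k takes the values 0 and 64. At k == 0
+  // both m == 0 and m == 1 pack full 32-row chunks; at k == 64, m == 0 packs
+  // the 6 remaining rows into trailing_buf and m == 1 is skipped, so the
+  // packed extent is (70 + 31) & ~31 == 96 rows.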
+ for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) { + // m: {0, 1} for 2 chunks of rows. + for (int m = 0; m < 2; ++m) { + // Available source rows. + // If this is less than 0 (for m=1), we skip, having filled trailing + // buffer for m=0. Also, if source rows is zero on m=1, then we filled + // exactly to the end of the column in the packed buffer. + const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows; + // Effectively, + // available rows = std::max(0, std::min(8, src_rows - k - 8 * 4 * m)); + // treat each case separately. + if (available_src_rows >= kNumChunkedSrcRows) { + // i: chunks, s: Layout::Rows. + if (sums_ptr) { + __m512i t0, t1, t2, t3; + __m512i r0, r1, r2, r3; + const __m512i input_xor_v = _mm512_set1_epi8(input_xor); + + t0 = LoaduTwo(src_ptr0, src_ptr4); + t1 = LoaduTwo(src_ptr1, src_ptr5); + t2 = LoaduTwo(src_ptr2, src_ptr6); + t3 = LoaduTwo(src_ptr3, src_ptr7); + + r0 = _mm512_unpacklo_epi32(t0, t1); + r2 = _mm512_unpackhi_epi32(t0, t1); + r1 = _mm512_unpacklo_epi32(t2, t3); + r3 = _mm512_unpackhi_epi32(t2, t3); + + t0 = _mm512_unpacklo_epi64(r0, r1); + t2 = _mm512_unpackhi_epi64(r0, r1); + t1 = _mm512_unpacklo_epi64(r2, r3); + t3 = _mm512_unpackhi_epi64(r2, r3); + + r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); + r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); + r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); + + r0 = _mm512_xor_si512(r0, input_xor_v); + r1 = _mm512_xor_si512(r1, input_xor_v); + r2 = _mm512_xor_si512(r2, input_xor_v); + r3 = _mm512_xor_si512(r3, input_xor_v); + + const __m256i r0_0 = _mm512_castsi512_si256(r0); + const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); + const __m256i r1_0 = _mm512_castsi512_si256(r1); + const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); + const __m256i r2_0 = _mm512_castsi512_si256(r2); + const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); + const __m256i r3_0 = _mm512_castsi512_si256(r3); + const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); + + __m512i sums_8x4_16bit; + sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1)); + // The sums have been performed across columns, and now we have + // 4x16-bit sums packed together. We use madd for pairwise 32-bit + // sums. 
+ const __m512i sums_8x2_32bit_new = + _mm512_madd_epi16(sums_8x4_16bit, ones_16bit); + sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new); + + _mm256_storeu_epi8(packed_ptr + 0 * 16 * 4, r0_0); + _mm256_storeu_epi8(packed_ptr + 2 * 16 * 4, r0_1); + _mm256_storeu_epi8(packed_ptr + 4 * 16 * 4, r1_0); + _mm256_storeu_epi8(packed_ptr + 6 * 16 * 4, r1_1); + _mm256_storeu_epi8(packed_ptr + 1 * 16 * 4, r2_0); + _mm256_storeu_epi8(packed_ptr + 3 * 16 * 4, r2_1); + _mm256_storeu_epi8(packed_ptr + 5 * 16 * 4, r3_0); + _mm256_storeu_epi8(packed_ptr + 7 * 16 * 4, r3_1); + } else { + __m512i t0, t1, t2, t3; + __m512i r0, r1, r2, r3; + const __m512i input_xor_v = _mm512_set1_epi8(input_xor); + + t0 = LoaduTwo(src_ptr0, src_ptr4); + t1 = LoaduTwo(src_ptr1, src_ptr5); + t2 = LoaduTwo(src_ptr2, src_ptr6); + t3 = LoaduTwo(src_ptr3, src_ptr7); + + r0 = _mm512_unpacklo_epi32(t0, t1); + r2 = _mm512_unpackhi_epi32(t0, t1); + r1 = _mm512_unpacklo_epi32(t2, t3); + r3 = _mm512_unpackhi_epi32(t2, t3); + + t0 = _mm512_unpacklo_epi64(r0, r1); + t2 = _mm512_unpackhi_epi64(r0, r1); + t1 = _mm512_unpacklo_epi64(r2, r3); + t3 = _mm512_unpackhi_epi64(r2, r3); + + r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); + r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); + r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); + + r0 = _mm512_xor_si512(r0, input_xor_v); + r1 = _mm512_xor_si512(r1, input_xor_v); + r2 = _mm512_xor_si512(r2, input_xor_v); + r3 = _mm512_xor_si512(r3, input_xor_v); + + const __m256i r0_0 = _mm512_castsi512_si256(r0); + const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); + const __m256i r1_0 = _mm512_castsi512_si256(r1); + const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); + const __m256i r2_0 = _mm512_castsi512_si256(r2); + const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); + const __m256i r3_0 = _mm512_castsi512_si256(r3); + const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); + _mm256_storeu_epi8(packed_ptr + 0 * 16 * 4, r0_0); + _mm256_storeu_epi8(packed_ptr + 2 * 16 * 4, r0_1); + _mm256_storeu_epi8(packed_ptr + 4 * 16 * 4, r1_0); + _mm256_storeu_epi8(packed_ptr + 6 * 16 * 4, r1_1); + _mm256_storeu_epi8(packed_ptr + 1 * 16 * 4, r2_0); + _mm256_storeu_epi8(packed_ptr + 3 * 16 * 4, r2_1); + _mm256_storeu_epi8(packed_ptr + 5 * 16 * 4, r3_0); + _mm256_storeu_epi8(packed_ptr + 7 * 16 * 4, r3_1); + } + } else if (available_src_rows > 0) { + RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows); + const __mmask32 row_mask = + (static_cast(1) << available_src_rows) - 1; + + // We do not care what goes into the trailing buffer, but we want + // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. + // + // We compensate for padding-with-zero_point by initializing the + // summations with the compensating offset, effectively + // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * + // 4 * (8 - ((available_src_rows + 3) >> 2)). + // + // Note that (zero_point ^ input_xor) is performed in 8-bits and then + // cast. 
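+        // For example, with zero_point == 0, input_xor == -128 and
+        // available_src_rows == 6, only (6 + 3) >> 2 == 2 of the 8 4-row
+        // chunks are (partially) present, so 4 * (8 - 2) == 24 lanes belong to
+        // wholly absent chunks; each of those padded lanes enters the sums as
+        // (zero_point ^ input_xor) == -128, so we pre-add 24 * 128 here.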
+ sums_adjustment += -(zero_point ^ input_xor) * 4 * + (8 - ((available_src_rows + 3) >> 2)); + + __m512i t0, t1, t2, t3; + __m512i r0, r1, r2, r3; + const __m512i input_xor_v = _mm512_set1_epi8(input_xor); + const __m256i zero_point_v = _mm256_set1_epi8(zero_point); + + t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4); + t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5); + t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6); + t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7); + + r0 = _mm512_unpacklo_epi32(t0, t1); + r2 = _mm512_unpackhi_epi32(t0, t1); + r1 = _mm512_unpacklo_epi32(t2, t3); + r3 = _mm512_unpackhi_epi32(t2, t3); + + t0 = _mm512_unpacklo_epi64(r0, r1); + t2 = _mm512_unpackhi_epi64(r0, r1); + t1 = _mm512_unpacklo_epi64(r2, r3); + t3 = _mm512_unpackhi_epi64(r2, r3); + + r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); + r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); + r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); + + r0 = _mm512_xor_si512(r0, input_xor_v); + r1 = _mm512_xor_si512(r1, input_xor_v); + r2 = _mm512_xor_si512(r2, input_xor_v); + r3 = _mm512_xor_si512(r3, input_xor_v); + + const __m256i r0_0 = _mm512_castsi512_si256(r0); + const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); + const __m256i r1_0 = _mm512_castsi512_si256(r1); + const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); + const __m256i r2_0 = _mm512_castsi512_si256(r2); + const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); + const __m256i r3_0 = _mm512_castsi512_si256(r3); + const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); + + __m512i sums_8x4_16bit; + sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1)); + // The sums have been performed across columns, and now we have + // 4x16-bit sums packed together. We use madd for pairwise 32-bit + // sums. 
+ const __m512i sums_8x2_32bit_new = + _mm512_madd_epi16(sums_8x4_16bit, ones_16bit); + sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new); + + _mm256_storeu_epi8(trailing_buf + 0 * 16 * 4, r0_0); + _mm256_storeu_epi8(trailing_buf + 2 * 16 * 4, r0_1); + _mm256_storeu_epi8(trailing_buf + 4 * 16 * 4, r1_0); + _mm256_storeu_epi8(trailing_buf + 6 * 16 * 4, r1_1); + _mm256_storeu_epi8(trailing_buf + 1 * 16 * 4, r2_0); + _mm256_storeu_epi8(trailing_buf + 3 * 16 * 4, r2_1); + _mm256_storeu_epi8(trailing_buf + 5 * 16 * 4, r3_0); + _mm256_storeu_epi8(trailing_buf + 7 * 16 * 4, r3_1); + } + + packed_ptr += 16 * kNumChunkedSrcRows; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } + } + + if (sums_ptr) { + const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); + + __m256i sums = _mm256_loadu_epi32(sums_ptr); + const __m512i idx = + _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0); + + // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the + // neighbours, finshing up by adding them to the stored accumulated sums. + const __m512i sums_2x8_32bit = + _mm512_permutexvar_epi32(idx, sums_8x2_32bit); + sums = _mm256_add_epi32(sums, sums_adjustment_v); + sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit)); + sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1)); + + _mm256_storeu_epi32(sums_ptr, sums); + } +} + +inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) { + const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo)); + return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1); +} + +inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo, + const float* addr_hi) { + const __m512 lower_filled = + _mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo)); + return _mm512_insertf32x8(lower_filled, + _mm256_maskz_loadu_ps(row_mask, addr_hi), 1); +} + +inline __m512 Mm512UnpackloPsx2(const __m512 a, const __m512 b) { + return _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(a), _mm512_castps_pd(b))); +} + +inline __m512 Mm512UnpackhiPsx2(const __m512 a, const __m512 b) { + return _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(a), _mm512_castps_pd(b))); +} + +inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, float* packed_ptr, + float* trailing_buf) { + const float* src_ptr0 = src_ptr; + const float* src_ptr1 = src_ptr0 + src_stride; + const float* src_ptr2 = src_ptr1 + src_stride; + const float* src_ptr3 = src_ptr2 + src_stride; + const float* src_ptr4 = src_ptr3 + src_stride; + const float* src_ptr5 = src_ptr4 + src_stride; + const float* src_ptr6 = src_ptr5 + src_stride; + const float* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = 8; + std::int64_t src_inc1 = 8; + std::int64_t src_inc2 = 8; + std::int64_t src_inc3 = 8; + std::int64_t src_inc4 = 8; + std::int64_t src_inc5 = 8; + std::int64_t src_inc6 = 8; + std::int64_t src_inc7 = 8; + if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if 
(remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + for (int k = 0; k < src_rows; k += 16) { + for (int m = 0; m < 2; ++m) { + const int available_src_rows = src_rows - k - 8 * m; + // Effectively, + // available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); + // but treat each case separately. + if (available_src_rows > 7) { + __m512 t0, t1, t2, t3; + __m512 r0, r1, r2, r3; + + t0 = LoaduTwo(src_ptr0, src_ptr4); + t1 = LoaduTwo(src_ptr1, src_ptr5); + t2 = LoaduTwo(src_ptr2, src_ptr6); + t3 = LoaduTwo(src_ptr3, src_ptr7); + + r0 = _mm512_unpacklo_ps(t0, t1); + r2 = _mm512_unpackhi_ps(t0, t1); + r1 = _mm512_unpacklo_ps(t2, t3); + r3 = _mm512_unpackhi_ps(t2, t3); + + t0 = Mm512UnpackloPsx2(r0, r1); + t2 = Mm512UnpackhiPsx2(r0, r1); + t1 = Mm512UnpackloPsx2(r2, r3); + t3 = Mm512UnpackhiPsx2(r2, r3); + + r0 = _mm512_shuffle_f32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd); + r2 = _mm512_shuffle_f32x4(t2, t3, 0x88); + r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd); + + _mm256_storeu_ps(packed_ptr + 0 * 16, _mm512_castps512_ps256(r0)); + _mm256_storeu_ps(packed_ptr + 2 * 16, _mm512_extractf32x8_ps(r0, 1)); + _mm256_storeu_ps(packed_ptr + 4 * 16, _mm512_castps512_ps256(r1)); + _mm256_storeu_ps(packed_ptr + 6 * 16, _mm512_extractf32x8_ps(r1, 1)); + _mm256_storeu_ps(packed_ptr + 1 * 16, _mm512_castps512_ps256(r2)); + _mm256_storeu_ps(packed_ptr + 3 * 16, _mm512_extractf32x8_ps(r2, 1)); + _mm256_storeu_ps(packed_ptr + 5 * 16, _mm512_castps512_ps256(r3)); + _mm256_storeu_ps(packed_ptr + 7 * 16, _mm512_extractf32x8_ps(r3, 1)); + } else if (available_src_rows > 0) { + const __mmask8 row_mask = + (static_cast(1) << available_src_rows) - 1; + + __m512 t0, t1, t2, t3; + __m512 r0, r1, r2, r3; + + t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4); + t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5); + t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6); + t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7); + + r0 = _mm512_unpacklo_ps(t0, t1); + r2 = _mm512_unpackhi_ps(t0, t1); + r1 = _mm512_unpacklo_ps(t2, t3); + r3 = _mm512_unpackhi_ps(t2, t3); + + t0 = Mm512UnpackloPsx2(r0, r1); + t2 = Mm512UnpackhiPsx2(r0, r1); + t1 = Mm512UnpackloPsx2(r2, r3); + t3 = Mm512UnpackhiPsx2(r2, r3); + + r0 = _mm512_shuffle_f32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd); + r2 = _mm512_shuffle_f32x4(t2, t3, 0x88); + r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd); + + _mm256_storeu_ps(trailing_buf + 0 * 16, _mm512_castps512_ps256(r0)); + _mm256_storeu_ps(trailing_buf + 2 * 16, _mm512_extractf32x8_ps(r0, 1)); + _mm256_storeu_ps(trailing_buf + 4 * 16, _mm512_castps512_ps256(r1)); + _mm256_storeu_ps(trailing_buf + 6 * 16, _mm512_extractf32x8_ps(r1, 1)); + _mm256_storeu_ps(trailing_buf + 1 * 16, _mm512_castps512_ps256(r2)); + _mm256_storeu_ps(trailing_buf + 3 * 16, _mm512_extractf32x8_ps(r2, 1)); + _mm256_storeu_ps(trailing_buf + 5 * 16, _mm512_castps512_ps256(r3)); + // Do not store _mm512_extractf32x8_ps(r3, 1). 
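+        // (trailing_buf holds only 7 rows of 16 floats, so there is no space
+        // for this last 8-float group, and a trailing block never has more
+        // than 7 rows anyway.)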
+ } + + packed_ptr += 16 * 8; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } + } +} + +inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) { + const int non_trailing_rows = src_rows & ~7; + for (int k = 0; k < non_trailing_rows; ++k) { + for (int j = 0; j < 8; ++j) { + packed_ptr[j] = 0.0f; + } + packed_ptr += 16; + } +} + +} // namespace. + +void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr) { + profiler::ScopeLabel label("Pack kAvx512 8bit"); + + using Layout = PackImpl8bitAvx512::Layout; + constexpr int kHalfBlockOffset = 32; + RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols); + static constexpr int kHalfLayoutCols = + PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a + // block. + RUY_DCHECK_EQ(kHalfLayoutCols, 8); + RUY_DCHECK_EQ(Layout::kCols, 16); + RUY_DCHECK_EQ(Layout::kRows, 4); + + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. + constexpr int kNumRowChunks = 8; + + // Each packed block is 4*16, and there are normally 8. The trailing block is + // only slightly shorter. + constexpr int kTrailingBufSize = + kNumRowChunks * Layout::kCols * Layout::kRows; + std::int8_t trailing_buf[kTrailingBufSize]; + memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); + + std::int32_t* second_sums_ptr = + sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr; + if (remaining_src_cols > kHalfLayoutCols) { + HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, sums_ptr, + trailing_buf); + HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor, + zerobuf, src_stride, + remaining_src_cols - kHalfLayoutCols, src_rows, + packed_ptr + kHalfBlockOffset, second_sums_ptr, + trailing_buf + kHalfBlockOffset); + } else { + HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, sums_ptr, + trailing_buf); + ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor, + packed_ptr + kHalfBlockOffset); + // The kernel may not need the second half-blocks sums to be set. + if (second_sums_ptr) { + for (int i = 0; i < kHalfLayoutCols; ++i) { + second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3); + } + } + } + constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; + const bool trailing_data = (src_rows & kChunkedRowMask) > 0; + // If the number of source rows is not a multiple of kChunkedRowMask, there + // will be data in the trailing buffer, + if (trailing_data > 0) { + const int non_trailing_rows = src_rows & ~kChunkedRowMask; + // Destination "rows" are padded to next highest multiple of Layout::kRows. 
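+    // For example, with src_rows == 70: non_trailing_rows == 64 and
+    // dst_rows == 72, so trailing_rows == 8 and the memcpy below copies
+    // Layout::kCols * 8 == 128 bytes from trailing_buf.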
+ const int dst_rows = (src_rows + 3) & ~3; + const int trailing_rows = dst_rows - non_trailing_rows; + memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, + Layout::kCols * trailing_rows * sizeof(std::int8_t)); + } +} + +void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, float* packed_ptr) { + profiler::ScopeLabel label("Pack kAvx512 float"); + float trailing_buf[7 * 16]; + if (remaining_src_cols > 8) { + HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_rows, packed_ptr, trailing_buf); + HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride, + remaining_src_cols - 8, src_rows, packed_ptr + 8, + trailing_buf + 8); + } else { + memset(trailing_buf, 0, sizeof(trailing_buf)); + HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_rows, packed_ptr, trailing_buf); + ZeroHalfFloatAvx512(src_rows, packed_ptr + 8); + } + const int trailing_rows = src_rows & 7; + if (trailing_rows > 0) { + const int non_trailing_rows = src_rows & ~7; + memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, + 16 * trailing_rows * sizeof(float)); + } +} + +#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) + +} // namespace ruy diff --git a/ruy/pack_avxvnni.cc b/ruy/pack_avxvnni.cc new file mode 100644 index 0000000..bb9a730 --- /dev/null +++ b/ruy/pack_avxvnni.cc @@ -0,0 +1,478 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/pack.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) +#include // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, int src_rows, + float* packed_ptr) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +// The first int8_t template parameter is arbitrary: this routine is common to +// all 8-bit source matrix types. 
+using PackImpl8bitAvxVnni = + PackImpl, + std::int8_t, std::int8_t, std::int32_t>; + +namespace { + +inline void ZeroHalf8bitAvxVnni(int src_rows, std::int8_t packed_zero_point, + std::int8_t* packed_ptr) { + const int non_trailing_blocks = (src_rows & ~31) >> 2; + // This routine fills half blocks, and typically fills the second halves. Thus + // packed_ptr is already offset by 8*4. + for (int k = 0; k < non_trailing_blocks; ++k) { + for (int j = 0; j < (8 * 4); ++j) { + packed_ptr[16 * 4 * k + j] = packed_zero_point; + } + } +} + +inline void HalfPack8bitAvxVnni(const std::int8_t* src_ptr, + std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr, + std::int8_t* trailing_buf) { + std::int8_t in_data[8][8][4]; + + const std::int8_t* src_ptr0 = src_ptr; + const std::int8_t* src_ptr1 = src_ptr0 + src_stride; + const std::int8_t* src_ptr2 = src_ptr1 + src_stride; + const std::int8_t* src_ptr3 = src_ptr2 + src_stride; + const std::int8_t* src_ptr4 = src_ptr3 + src_stride; + const std::int8_t* src_ptr5 = src_ptr4 + src_stride; + const std::int8_t* src_ptr6 = src_ptr5 + src_stride; + const std::int8_t* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = 8 * 4; + std::int64_t src_inc1 = 8 * 4; + std::int64_t src_inc2 = 8 * 4; + std::int64_t src_inc3 = 8 * 4; + std::int64_t src_inc4 = 8 * 4; + std::int64_t src_inc5 = 8 * 4; + std::int64_t src_inc6 = 8 * 4; + std::int64_t src_inc7 = 8 * 4; + if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + const std::int8_t zero_point = zerobuf[0]; + + if (sums_ptr) { + for (int i = 0; i < 8; ++i) { + sums_ptr[i] = 0; + } + } + + // The overall packing effectively pads the source rows to + // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we + // only pack for (src_rows + 31) & ~31. When there is an incomplete + // destination block, this is stored into trailing_buf instead of packed_ptr. + for (int k = 0; k < src_rows; k += 16 * 4) { + for (int m = 0; m < 2; ++m) { + // Available source rows. + // If this is less than 0 (for m=1), we skip, having filled trailing + // buffer for m=0. Also, if source rows is zero on m=1, then we filled + // exactly to the end of the column in the packed buffer. + const int packed_rows = src_rows - k - 8 * m * 4; + // Effectively, + // packed_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); + // but treat each case separately. 
+ if (packed_rows >= (8 * 4)) { + for (int i = 0; i < 8; ++i) { + for (int s = 0; s < 4; ++s) { + in_data[0][i][s] = src_ptr0[i * 4 + s]; + in_data[1][i][s] = src_ptr1[i * 4 + s]; + in_data[2][i][s] = src_ptr2[i * 4 + s]; + in_data[3][i][s] = src_ptr3[i * 4 + s]; + in_data[4][i][s] = src_ptr4[i * 4 + s]; + in_data[5][i][s] = src_ptr5[i * 4 + s]; + in_data[6][i][s] = src_ptr6[i * 4 + s]; + in_data[7][i][s] = src_ptr7[i * 4 + s]; + } + } + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + for (int s = 0; s < 4; ++s) { + packed_ptr[(16 * i + j) * 4 + s] = + static_cast(in_data[j][i][s] ^ input_xor); + } + if (sums_ptr) { + for (int s = 0; s < 4; ++s) { + sums_ptr[j] += in_data[j][i][s] ^ input_xor; + } + } + } + } + } else if (packed_rows > 0) { + RUY_DCHECK_LT(packed_rows >> 2, 8); + int i = 0; + for (; i < (packed_rows >> 2); ++i) { + for (int s = 0; s < 4; ++s) { + in_data[0][i][s] = src_ptr0[i * 4 + s]; + in_data[1][i][s] = src_ptr1[i * 4 + s]; + in_data[2][i][s] = src_ptr2[i * 4 + s]; + in_data[3][i][s] = src_ptr3[i * 4 + s]; + in_data[4][i][s] = src_ptr4[i * 4 + s]; + in_data[5][i][s] = src_ptr5[i * 4 + s]; + in_data[6][i][s] = src_ptr6[i * 4 + s]; + in_data[7][i][s] = src_ptr7[i * 4 + s]; + } + } + if (i < ((packed_rows + 3) >> 2)) { + int s = 0; + for (; s < (packed_rows & 3); ++s) { + in_data[0][i][s] = src_ptr0[i * 4 + s]; + in_data[1][i][s] = src_ptr1[i * 4 + s]; + in_data[2][i][s] = src_ptr2[i * 4 + s]; + in_data[3][i][s] = src_ptr3[i * 4 + s]; + in_data[4][i][s] = src_ptr4[i * 4 + s]; + in_data[5][i][s] = src_ptr5[i * 4 + s]; + in_data[6][i][s] = src_ptr6[i * 4 + s]; + in_data[7][i][s] = src_ptr7[i * 4 + s]; + } + RUY_DCHECK_LE(s, 4); + for (; s < 4; ++s) { + for (int j = 0; j < 8; ++j) { + in_data[j][i][s] = zero_point; + } + } + ++i; + } + // We do not care what goes into the trailing buffer, but we want + // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. + // + // It might prove better in optimized code to pad uniformly with + // zero_point, and compensate by initializing the summations with the + // compensating offset, effectively + // ((input_xor - zero_point) ^ input_xor) * + // 4 * (8 - ((packed_rows + 3) >> 2)). + for (; i < 8; ++i) { + for (int s = 0; s < 4; ++s) { + for (int j = 0; j < 8; ++j) { + in_data[j][i][s] = input_xor; + } + } + } + // We loop through [0, 8) rather than [0, (packed_rows + 3) >> 2), since + // that emulates what we might do in fully-optimized code. 
+ if (sums_ptr) { + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + for (int s = 0; s < 4; ++s) { + trailing_buf[(16 * i + j) * 4 + s] = + static_cast(in_data[j][i][s] ^ input_xor); + sums_ptr[j] += in_data[j][i][s] ^ input_xor; + } + } + } + } else { + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + for (int s = 0; s < 4; ++s) { + trailing_buf[(16 * i + j) * 4 + s] = + static_cast(in_data[j][i][s] ^ input_xor); + } + } + } + } + } + + packed_ptr += 16 * 8 * 4; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } + } +} + +inline void HalfPackFloatAvxVnni(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, float* packed_ptr, + float* trailing_buf) { + float in_data[8][8]; + + const float* src_ptr0 = src_ptr; + const float* src_ptr1 = src_ptr0 + src_stride; + const float* src_ptr2 = src_ptr1 + src_stride; + const float* src_ptr3 = src_ptr2 + src_stride; + const float* src_ptr4 = src_ptr3 + src_stride; + const float* src_ptr5 = src_ptr4 + src_stride; + const float* src_ptr6 = src_ptr5 + src_stride; + const float* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = 8; + std::int64_t src_inc1 = 8; + std::int64_t src_inc2 = 8; + std::int64_t src_inc3 = 8; + std::int64_t src_inc4 = 8; + std::int64_t src_inc5 = 8; + std::int64_t src_inc6 = 8; + std::int64_t src_inc7 = 8; + if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + for (int k = 0; k < src_rows; k += 16) { + for (int m = 0; m < 2; ++m) { + const int packed_rows = src_rows - k - 8 * m; + // Effectively, + // packed_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); + // but treat each case separately. + if (packed_rows > 7) { + for (int i = 0; i < 8; ++i) { + in_data[0][i] = src_ptr0[i]; + in_data[1][i] = src_ptr1[i]; + in_data[2][i] = src_ptr2[i]; + in_data[3][i] = src_ptr3[i]; + in_data[4][i] = src_ptr4[i]; + in_data[5][i] = src_ptr5[i]; + in_data[6][i] = src_ptr6[i]; + in_data[7][i] = src_ptr7[i]; + } + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + packed_ptr[16 * i + j] = in_data[j][i]; + } + } + } else if (packed_rows > 0) { + for (int i = 0; i < packed_rows; ++i) { + in_data[0][i] = src_ptr0[i]; + in_data[1][i] = src_ptr1[i]; + in_data[2][i] = src_ptr2[i]; + in_data[3][i] = src_ptr3[i]; + in_data[4][i] = src_ptr4[i]; + in_data[5][i] = src_ptr5[i]; + in_data[6][i] = src_ptr6[i]; + in_data[7][i] = src_ptr7[i]; + } + for (int i = packed_rows; i < 8; ++i) { + in_data[0][i] = 0.0f; + in_data[1][i] = 0.0f; + in_data[2][i] = 0.0f; + in_data[3][i] = 0.0f; + in_data[4][i] = 0.0f; + in_data[5][i] = 0.0f; + in_data[6][i] = 0.0f; + in_data[7][i] = 0.0f; + } + // We loop through [0, 7) rather than [0, packed_rows), since that + // emulates what we might do in fully-optimized code. 
+ for (int i = 0; i < 7; ++i) { + for (int j = 0; j < 8; ++j) { + trailing_buf[16 * i + j] = in_data[j][i]; + } + } + } + + packed_ptr += 16 * 8; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } + } +} + +inline void ZeroHalfFloatAvxVnni(int src_rows, float* packed_ptr) { + const int non_trailing_rows = src_rows & ~7; + for (int k = 0; k < non_trailing_rows; ++k) { + for (int j = 0; j < 8; ++j) { + packed_ptr[j] = 0.0f; + } + packed_ptr += 16; + } +} + +} // namespace. + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. +void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr) { + profiler::ScopeLabel label("Pack kAvxVnni 8bit (UNFINISHED)"); + + // Each packed block is 4*16, and there are normally 8. The trailing block is + // only slightly shorter. + std::int8_t trailing_buf[8 * 16 * 4]; + memset(trailing_buf, 0, 8 * 16 * 4 * sizeof(std::int8_t)); + + std::int32_t* second_sums_ptr = sums_ptr ? sums_ptr + 8 : nullptr; + if (remaining_src_cols > 8) { + HalfPack8bitAvxVnni(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, sums_ptr, + trailing_buf); + HalfPack8bitAvxVnni(src_ptr + src_stride * 8, input_xor, zerobuf, + src_stride, remaining_src_cols - 8, src_rows, + packed_ptr + 8 * 4, second_sums_ptr, + trailing_buf + 8 * 4); + } else { + HalfPack8bitAvxVnni(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, sums_ptr, + trailing_buf); + ZeroHalf8bitAvxVnni(src_rows, zerobuf[0] ^ input_xor, packed_ptr + 8 * 4); + // The kernel may not need the second half-blocks sums to be set. + if (second_sums_ptr) { + for (int i = 0; i < 8; ++i) { + second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3); + } + } + } + const bool trailing_data = (src_rows & 31) > 0; + // If the number of source rows is not a multiple of 32, there will be data in + // the trailing buffer, + if (trailing_data > 0) { + const int non_trailing_rows = src_rows & ~31; + // Destination "rows" are padded to next highest multiple of 4. + const int dst_rows = (src_rows + 3) & ~3; + const int trailing_rows = dst_rows - non_trailing_rows; + memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, + 16 * trailing_rows * sizeof(std::int8_t)); + } +} + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. 
+void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, int src_rows, + float* packed_ptr) { + profiler::ScopeLabel label("Pack kAvxVnni float (UNFINISHED)"); + float trailing_buf[7 * 16]; + if (remaining_src_cols > 8) { + HalfPackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_rows, packed_ptr, trailing_buf); + HalfPackFloatAvxVnni(src_ptr + src_stride * 8, zerobuf, src_stride, + remaining_src_cols - 8, src_rows, packed_ptr + 8, + trailing_buf + 8); + } else { + memset(trailing_buf, 0, sizeof(trailing_buf)); + HalfPackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_rows, packed_ptr, trailing_buf); + ZeroHalfFloatAvxVnni(src_rows, packed_ptr + 8); + } + const int trailing_rows = src_rows & 7; + if (trailing_rows > 0) { + const int non_trailing_rows = src_rows & ~7; + memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, + 16 * trailing_rows * sizeof(float)); + } +} + +#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) + +} // namespace ruy diff --git a/ruy/pack_common.h b/ruy/pack_common.h new file mode 100644 index 0000000..5c03afd --- /dev/null +++ b/ruy/pack_common.h @@ -0,0 +1,246 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// # What is "packing"? +// +// Before feeding data to the gemm kernels (the parts of Ruy that do lots +// of multiply-add operations), Ruy first performs a data transformation (which +// we call "packing") on the input matrices. This transformation has two main +// goals: +// - rearrange data into blocks that are a convenient size/layout for the gemm +// kernels to consume. This helps make the memory access pattern of the gemm +// kernel simpler and more contiguous, and puts the data in a layout most +// convenient for specific arithmetic instructions in the gemm kernel. +// - compute row/column sums needed for handling quantization with non-symmetric +// zero points. +// +// # Simplified algorithmic analysis of packing +// +// Packing is a relatively simple transformation which does a small constant +// amount of work on each element of an input matrix, and hence for an NxM +// matrix performs O(N*M) work. If N and M are of the same order, then this is +// O(N^2) work. +// +// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. +// Note that if N, K, and M are all the same order, then the number of +// multiply-accumulate operations is O(N^3). +// +// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the +// case of all dimensions being roughly the same order. +// +// # Packing cost can be significant +// +// When matrix * matrix multiplications begin to look more like matrix * vector +// multiplications, packing cost can become significant. We sometimes call these +// cases "gemv-like". 
+// +// Continuing the algorithmic analysis above, if we consider a case where an +// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the +// situation is different. In this case, the multiply-accumulate work is only +// quadratic, so the quadratic cost of packing can be come significant. +// +// Another way to say this is that the cost of packing an input matrix (either +// the LHS or RHS) is amortized across the non-depth dimension of the opposite +// input matrix. Thus, when the LHS has very few rows or the RHS has very few +// columns, the cost of packing the opposite input matrix can become +// significant. +// +// As a rough rule of thumb, the cost of packing starts to become significant +// when either N or M is below 32 (and other dimensions are hundreds), with very +// significant packing costs at 8 or below. This varies by data type, Path, and +// tuning, so these numbers are only rough guides. +// +// One practical use case that is affected by this is inference of +// fully connected neural network layers with a low batch size. The weight +// matrix (which is a constant for inference) is the one affected by significant +// packing cost. +// +// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack +// input matrices that are affected by significant packing costs. +// +// # Implementation notes +// +// Ruy's packing routines always operate on a range of columns and can be +// applied to either the LHS or RHS. This is possible because Ruy internally +// implements a TrMul, so the accumulation along depth is done along columns of +// both the LHS and RHS (whereas for a normal Mul the accumulation along depth +// for the LHS is along rows). As another example, we are always computing +// column sums for quantization (and never row sums, since the LHS is +// transposed). 
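To make the above concrete, here is a simplified, self-contained sketch of the generic packing transform. This is not ruy's code: real packed layouts are kernel-specific block layouts (see the pack_*.cc files in this change), and all names below are invented for illustration. It shows the two jobs listed at the top of this comment: rearranging a block of columns into a padded destination, and computing the per-column sums needed for asymmetric zero points.

#include <cstdint>

// Sketch only: packs `num_cols` columns of a column-major int8 matrix,
// padding rows up to a multiple of 4 (the cell depth used by most of ruy's
// 8-bit kernels) and accumulating per-column sums. Padding uses the zero
// point so that it contributes a known, uniform amount to each sum.
void PackColumnsWithSums(const std::int8_t* src, int src_rows, int src_cols,
                         int src_stride, int start_col, int num_cols,
                         std::int8_t zero_point, std::int8_t* packed,
                         std::int32_t* sums) {
  const int packed_rows = (src_rows + 3) & ~3;
  for (int c = 0; c < num_cols; ++c) {
    const int col = start_col + c;
    std::int32_t sum = 0;
    for (int r = 0; r < packed_rows; ++r) {
      const std::int8_t v = (col < src_cols && r < src_rows)
                                ? src[col * src_stride + r]
                                : zero_point;
      packed[c * packed_rows + r] = v;
      sum += v;
    }
    if (sums) sums[col] = sum;
  }
}

The optimized pack_*.cc files in this change do the same job, but stage data through small fixed-size buffers and interleave rows and columns to match each kernel's block layout.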
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_ + +#include + +#include "ruy/check_macros.h" +#include "ruy/common.h" +#include "ruy/internal_matrix.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/tune.h" + +namespace ruy { + +template +struct PackedTypeImpl { + using Type = Scalar; +}; + +#if RUY_PLATFORM(NEON_32) +struct PackParams8bit { + const void* src_ptr0; + const void* src_ptr1; + const void* src_ptr2; + const void* src_ptr3; + const std::int32_t* sums_ptr; + const std::int8_t* packed_ptr; + int src_inc0; + int src_inc1; + int src_inc2; + int src_inc3; + int src_rows; + int src_zero_point; + int input_xor; +}; + +inline void MakePackParams8bit(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + const std::int32_t* sums_ptr, + const std::int8_t* packed_ptr, int src_inc0, + int src_inc1, int src_inc2, int src_inc3, + int src_rows, int src_zero_point, int input_xor, + PackParams8bit* params) { + params->src_ptr0 = src_ptr0; + params->src_ptr1 = src_ptr1; + params->src_ptr2 = src_ptr2; + params->src_ptr3 = src_ptr3; + params->sums_ptr = sums_ptr; + params->packed_ptr = packed_ptr; + params->src_inc0 = src_inc0; + params->src_inc1 = src_inc1; + params->src_inc2 = src_inc2; + params->src_inc3 = src_inc3; + params->src_rows = src_rows; + params->src_zero_point = src_zero_point; + params->input_xor = input_xor; +} +#endif + +#if RUY_PLATFORM(NEON) +template <> +struct PackedTypeImpl { + using Type = std::int8_t; +}; +template <> +struct PackedTypeImpl { + using Type = std::int8_t; +}; +#elif RUY_PLATFORM(X86) +template <> +struct PackedTypeImpl { + using Type = std::int8_t; +}; +template <> +struct PackedTypeImpl { + using Type = std::int8_t; +}; +template <> +struct PackedTypeImpl { + using Type = std::int8_t; +}; +template <> +struct PackedTypeImpl { + using Type = std::int8_t; +}; +#endif + +template +using PackedType = typename PackedTypeImpl::Type; + +template +PackedScalar Pack(Scalar x) { + return x - SymmetricZeroPoint() + SymmetricZeroPoint(); +} + +template +struct PackImpl {}; + +#define RUY_INHERIT_PACK(PARENT, CHILD) \ + template \ + struct PackImpl \ + : PackImpl { \ + }; + +template +struct PackImpl { + static void Run(Tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (generic)"); + RUY_DCHECK_EQ((end_col - start_col) % FixedKernelLayout::kCols, 0); + SumsType* sums = packed_matrix->sums; + for (int col = start_col; col < end_col; col++) { + SumsType accum = 0; + for (int row = 0; row < packed_matrix->layout.rows; row++) { + PackedScalar packed_val; + if (col < src_matrix.layout.cols && row < src_matrix.layout.rows) { + packed_val = Pack(Element(src_matrix, row, col)); + } else { + packed_val = packed_matrix->zero_point; + } + accum += packed_val; + *ElementPtr(packed_matrix, row, col) = packed_val; + } + if (sums) { + sums[col] = accum; + } + } + } +}; + +#if RUY_PLATFORM(NEON) +RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon) +RUY_INHERIT_PACK(Path::kNeon, Path::kNeonDotprod) +#elif RUY_PLATFORM(X86) +RUY_INHERIT_PACK(Path::kStandardCpp, Path::kSse42) +RUY_INHERIT_PACK(Path::kSse42, Path::kAvx2) +RUY_INHERIT_PACK(Path::kAvx2, Path::kAvx512) +RUY_INHERIT_PACK(Path::kAvx512, Path::kAvxVnni) +#endif + +// Main entry point for packing. 
+template +void RunPack(Tuning tuning, const DMatrix& src_matrix, PMatrix* packed_matrix, + int start_col, int end_col) { + using SumsType = typename PackedMatrix::SumsType; + Matrix src = ToMatrix(src_matrix); + PackedMatrix packed = + ToPackedMatrix(*packed_matrix); + PackImpl::Run( + tuning, src, &packed, start_col, end_col); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_ diff --git a/ruy/pack_sse42.cc b/ruy/pack_sse42.cc new file mode 100644 index 0000000..90c7250 --- /dev/null +++ b/ruy/pack_sse42.cc @@ -0,0 +1,471 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/pack.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) +#include // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, float* packed_ptr) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +// The first int8_t template parameter is arbitrary: this routine is common to +// all 8-bit source matrix types. +using PackImpl8bitSse42 = + PackImpl, + std::int8_t, std::int8_t, std::int32_t>; + +using PackImplFloatSse42 = + PackImpl, float, + float, float>; + +namespace { + +inline void Pack8bitSse42Packer(const std::int8_t* src_ptr, + std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr, + std::int8_t* trailing_buf) { + using Layout = PackImpl8bitSse42::Layout; + RUY_DCHECK_EQ(Layout::kCols, 8); + RUY_DCHECK_EQ(Layout::kRows, 4); + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. 
+ constexpr int kNumRowChunks = 8; + constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; + + std::int8_t in_data[Layout::kCols][kNumRowChunks][Layout::kRows]; + + const std::int8_t* src_ptr0 = src_ptr; + const std::int8_t* src_ptr1 = src_ptr0 + src_stride; + const std::int8_t* src_ptr2 = src_ptr1 + src_stride; + const std::int8_t* src_ptr3 = src_ptr2 + src_stride; + const std::int8_t* src_ptr4 = src_ptr3 + src_stride; + const std::int8_t* src_ptr5 = src_ptr4 + src_stride; + const std::int8_t* src_ptr6 = src_ptr5 + src_stride; + const std::int8_t* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = kNumChunkedSrcRows; + std::int64_t src_inc1 = kNumChunkedSrcRows; + std::int64_t src_inc2 = kNumChunkedSrcRows; + std::int64_t src_inc3 = kNumChunkedSrcRows; + std::int64_t src_inc4 = kNumChunkedSrcRows; + std::int64_t src_inc5 = kNumChunkedSrcRows; + std::int64_t src_inc6 = kNumChunkedSrcRows; + std::int64_t src_inc7 = kNumChunkedSrcRows; + // Handle cases where source does not have Layout::kCols (8) columns. + if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + const std::int8_t zero_point = zerobuf[0]; + + if (sums_ptr) { + // i: Layout::kCols. + for (int i = 0; i < 8; ++i) { + sums_ptr[i] = 0; + } + } + + // The overall packing effectively pads the source rows to + // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we + // only pack for (src_rows + 31) & ~31. When there is an incomplete + // destination block, this is stored into trailing_buf instead of packed_ptr. + for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) { + // Available source rows. + // If this is less than 0 (for m=1), we skip, having filled trailing + // buffer for m=0. Also, if source rows is zero on m=1, then we filled + // exactly to the end of the column in the packed buffer. + const int available_src_rows = src_rows - k; + // Effectively, + // available rows = std::max(0, std::min(8, src_rows - k)); + // treat each case separately. + if (available_src_rows >= kNumChunkedSrcRows) { + // i: chunks, s: Layout::Rows. + for (int i = 0; i < 8; ++i) { + for (int s = 0; s < 4; ++s) { + in_data[0][i][s] = src_ptr0[i * 4 + s]; + in_data[1][i][s] = src_ptr1[i * 4 + s]; + in_data[2][i][s] = src_ptr2[i * 4 + s]; + in_data[3][i][s] = src_ptr3[i * 4 + s]; + in_data[4][i][s] = src_ptr4[i * 4 + s]; + in_data[5][i][s] = src_ptr5[i * 4 + s]; + in_data[6][i][s] = src_ptr6[i * 4 + s]; + in_data[7][i][s] = src_ptr7[i * 4 + s]; + } + } + // i: chunks, j: Layout::kCols, s: Layout::Rows. 
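+ // Full blocks: element (chunk i, column j, row-within-chunk s) is stored at
+ // packed_ptr[(8 * i + j) * 4 + s], XOR-ed with input_xor, and each column's
+ // running total is accumulated into sums_ptr[j].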
+ for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + for (int s = 0; s < 4; ++s) { + // 8 * 4 * i is offset for each block, that is + // (Layout::kCols * Layout::kRows * i) + packed_ptr[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor; + } + if (sums_ptr) { + for (int s = 0; s < 4; ++s) { + sums_ptr[j] += in_data[j][i][s] ^ input_xor; + } + } + } + } + } else if (available_src_rows > 0) { + RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows); + int i = 0; + // Consume chunks of 4 rows that are complete. + for (; i < (available_src_rows >> 2); ++i) { + for (int s = 0; s < 4; ++s) { + in_data[0][i][s] = src_ptr0[i * 4 + s]; + in_data[1][i][s] = src_ptr1[i * 4 + s]; + in_data[2][i][s] = src_ptr2[i * 4 + s]; + in_data[3][i][s] = src_ptr3[i * 4 + s]; + in_data[4][i][s] = src_ptr4[i * 4 + s]; + in_data[5][i][s] = src_ptr5[i * 4 + s]; + in_data[6][i][s] = src_ptr6[i * 4 + s]; + in_data[7][i][s] = src_ptr7[i * 4 + s]; + } + } + // Consume any incomplete chunk. + if (i < ((available_src_rows + 3) >> 2)) { + int s = 0; + for (; s < (available_src_rows & 3); ++s) { + in_data[0][i][s] = src_ptr0[i * 4 + s]; + in_data[1][i][s] = src_ptr1[i * 4 + s]; + in_data[2][i][s] = src_ptr2[i * 4 + s]; + in_data[3][i][s] = src_ptr3[i * 4 + s]; + in_data[4][i][s] = src_ptr4[i * 4 + s]; + in_data[5][i][s] = src_ptr5[i * 4 + s]; + in_data[6][i][s] = src_ptr6[i * 4 + s]; + in_data[7][i][s] = src_ptr7[i * 4 + s]; + } + RUY_DCHECK_LE(s, 4); + for (; s < 4; ++s) { + // j: Layout::kCols. + for (int j = 0; j < 8; ++j) { + in_data[j][i][s] = zero_point; + } + } + ++i; + } + // We do not care what goes into the trailing buffer, but we want + // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. + // + // It might prove better in optimized code to pad uniformly with + // zero_point, and compensate by initializing the summations with the + // compensating offset, effectively + // ((input_xor - zero_point) ^ input_xor) * + // 4 * (8 - ((available_src_rows + 3) >> 2)). + for (; i < 8; ++i) { + for (int s = 0; s < 4; ++s) { + for (int j = 0; j < 8; ++j) { + in_data[j][i][s] = input_xor; + } + } + } + // We loop through [0, 8) rather than + // [0, (available_src_rows + 3) >> 2), since that emulates what we might + // do in fully-optimized code. + // + // i: chunks, j: Layout::kCols, s: Layout::Rows. + if (sums_ptr) { + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + for (int s = 0; s < 4; ++s) { + trailing_buf[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor; + sums_ptr[j] = sums_ptr[j] + (in_data[j][i][s] ^ input_xor); + } + } + } + } else { + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + for (int s = 0; s < 4; ++s) { + trailing_buf[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor; + } + } + } + } + } + + packed_ptr += 8 * kNumChunkedSrcRows; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } +} + +inline void PackFloatSse42Packer(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, float* packed_ptr, + float* trailing_buf) { + using Layout = PackImplFloatSse42::Layout; + RUY_DCHECK_EQ(Layout::kCols, 8); + RUY_DCHECK_EQ(Layout::kRows, 1); + + // This packing amounts to tranposition of 8x8 blocks. + static constexpr int kPackCols = 8; // Source cols packed together. + static constexpr int kPackRows = 8; // Short input is padded. 
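+ // The loop below stages an 8x8 block in in_data, one row per source column,
+ // then writes its transpose to packed_ptr for full blocks, or to trailing_buf
+ // (zero-padded) for the final short block.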
+ + float in_data[kPackCols][kPackRows]; + + const float* src_ptr0 = src_ptr; + const float* src_ptr1 = src_ptr0 + src_stride; + const float* src_ptr2 = src_ptr1 + src_stride; + const float* src_ptr3 = src_ptr2 + src_stride; + const float* src_ptr4 = src_ptr3 + src_stride; + const float* src_ptr5 = src_ptr4 + src_stride; + const float* src_ptr6 = src_ptr5 + src_stride; + const float* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = 8; + std::int64_t src_inc1 = 8; + std::int64_t src_inc2 = 8; + std::int64_t src_inc3 = 8; + std::int64_t src_inc4 = 8; + std::int64_t src_inc5 = 8; + std::int64_t src_inc6 = 8; + std::int64_t src_inc7 = 8; + // Handle cases where source does not have kPackDim (8) columns. + if (remaining_src_cols < kPackCols) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + for (int k = 0; k < src_rows; k += kPackRows) { + const int available_src_rows = src_rows - k; + // Effectively, + // available_src_rows = std::max(0, std::min(kPackDim, src_rows - k)); + // but treat each case separately. + if (available_src_rows >= kPackRows) { + for (int i = 0; i < 8; ++i) { + in_data[0][i] = src_ptr0[i]; + in_data[1][i] = src_ptr1[i]; + in_data[2][i] = src_ptr2[i]; + in_data[3][i] = src_ptr3[i]; + in_data[4][i] = src_ptr4[i]; + in_data[5][i] = src_ptr5[i]; + in_data[6][i] = src_ptr6[i]; + in_data[7][i] = src_ptr7[i]; + } + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + packed_ptr[8 * i + j] = in_data[j][i]; + } + } + } else if (available_src_rows > 0) { + for (int i = 0; i < available_src_rows; ++i) { + in_data[0][i] = src_ptr0[i]; + in_data[1][i] = src_ptr1[i]; + in_data[2][i] = src_ptr2[i]; + in_data[3][i] = src_ptr3[i]; + in_data[4][i] = src_ptr4[i]; + in_data[5][i] = src_ptr5[i]; + in_data[6][i] = src_ptr6[i]; + in_data[7][i] = src_ptr7[i]; + } + for (int i = available_src_rows; i < kPackRows; ++i) { + in_data[0][i] = 0.0f; + in_data[1][i] = 0.0f; + in_data[2][i] = 0.0f; + in_data[3][i] = 0.0f; + in_data[4][i] = 0.0f; + in_data[5][i] = 0.0f; + in_data[6][i] = 0.0f; + in_data[7][i] = 0.0f; + } + // We loop through [0, 7) rather than [0, packed_rows), since that + // emulates what we might do in fully-optimized code. + // i: (kPackRows - 1), j: kPackCols. + for (int i = 0; i < 7; ++i) { + for (int j = 0; j < 8; ++j) { + trailing_buf[kPackRows * i + j] = in_data[j][i]; + } + } + } + + packed_ptr += kPackRows * kPackCols; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } +} + +} // namespace. + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. 
+void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr) { + profiler::ScopeLabel label("Pack kSse42 8bit (UNFINISHED)"); + + using Layout = PackImpl8bitSse42::Layout; + RUY_DCHECK_EQ(Layout::kCols, 8); + RUY_DCHECK_EQ(Layout::kRows, 4); + + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. + static constexpr int kNumRowChunks = 8; // Short input is padded. + + // Each packed block is 4*8, and there are normally 8. The trailing block is + // only slightly shorter. + constexpr int kTrailingBufSize = + kNumRowChunks * Layout::kCols * Layout::kRows; + std::int8_t trailing_buf[kTrailingBufSize]; + memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); + + Pack8bitSse42Packer(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, sums_ptr, + trailing_buf); + + constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; + const bool trailing_data = (src_rows & kChunkedRowMask) > 0; + // If the number of source rows is not a multiple of kChunkedRowMask, there + // will be data in the trailing buffer, + if (trailing_data > 0) { + const int non_trailing_rows = src_rows & ~kChunkedRowMask; + // Destination "rows" are padded to next highest multiple of Layout::kRows. + const int dst_rows = (src_rows + 3) & ~3; + const int trailing_rows = dst_rows - non_trailing_rows; + memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, + Layout::kCols * trailing_rows * sizeof(std::int8_t)); + } +} + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. +void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, float* packed_ptr) { + profiler::ScopeLabel label("Pack kSse42 float (UNFINISHED)"); + static constexpr int kPackCols = 8; // Source cols packed together. + static constexpr int kPackRows = 8; // Short input is padded. + float trailing_buf[(kPackRows - 1) * kPackCols]; + if (remaining_src_cols < 8) { + memset(trailing_buf, 0, sizeof(trailing_buf)); + } + PackFloatSse42Packer(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_rows, packed_ptr, trailing_buf); + + const int trailing_rows = src_rows & (kPackRows - 1); + if (trailing_rows > 0) { + const int non_trailing_rows = src_rows & ~(kPackRows - 1); + memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf, + kPackCols * trailing_rows * sizeof(float)); + } +} + +#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) + +} // namespace ruy diff --git a/ruy/pack_x86.h b/ruy/pack_x86.h new file mode 100644 index 0000000..b777cc1 --- /dev/null +++ b/ruy/pack_x86.h @@ -0,0 +1,461 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// # What is "packing"? +// +// Before feeding data to the gemm kernels (the parts of Ruy that do lots +// of multiply-add operations), Ruy first performs a data transformation (which +// we call "packing") on the input matrices. This transformation has two main +// goals: +// - rearrange data into blocks that are a convenient size/layout for the gemm +// kernels to consume. This helps make the memory access pattern of the gemm +// kernel simpler and more contiguous, and puts the data in a layout most +// convenient for specific arithmetic instructions in the gemm kernel. +// - compute row/column sums needed for handling quantization with non-symmetric +// zero points. +// +// # Simplified algorithmic analysis of packing +// +// Packing is a relatively simple transformation which does a small constant +// amount of work on each element of an input matrix, and hence for an NxM +// matrix performs O(N*M) work. If N and M are of the same order, then this is +// O(N^2) work. +// +// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. +// Note that if N, K, and M are all the same order, then the number of +// multiply-accumulate operations is O(N^3). +// +// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the +// case of all dimensions being roughly the same order. +// +// # Packing cost can be significant +// +// When matrix * matrix multiplications begin to look more like matrix * vector +// multiplications, packing cost can become significant. We sometimes call these +// cases "gemv-like". +// +// Continuing the algorithmic analysis above, if we consider a case where an +// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the +// situation is different. In this case, the multiply-accumulate work is only +// quadratic, so the quadratic cost of packing can be come significant. +// +// Another way to say this is that the cost of packing an input matrix (either +// the LHS or RHS) is amortized across the non-depth dimension of the opposite +// input matrix. Thus, when the LHS has very few rows or the RHS has very few +// columns, the cost of packing the opposite input matrix can become +// significant. +// +// As a rough rule of thumb, the cost of packing starts to become significant +// when either N or M is below 32 (and other dimensions are hundreds), with very +// significant packing costs at 8 or below. This varies by data type, Path, and +// tuning, so these numbers are only rough guides. +// +// One practical use case that is affected by this is inference of +// fully connected neural network layers with a low batch size. The weight +// matrix (which is a constant for inference) is the one affected by significant +// packing cost. +// +// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack +// input matrices that are affected by significant packing costs. +// +// # Implementation notes +// +// Ruy's packing routines always operate on a range of columns and can be +// applied to either the LHS or RHS. This is possible because Ruy internally +// implements a TrMul, so the accumulation along depth is done along columns of +// both the LHS and RHS (whereas for a normal Mul the accumulation along depth +// for the LHS is along rows). 
As another example, we are always computing +// column sums for quantization (and never row sums, since the LHS is +// transposed). + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_ + +#include +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/common.h" +#include "ruy/internal_matrix.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/pack_common.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/tune.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// Note that source and zero buffers can be uint8 type, but in the packing +// function are reinterpreted as int8, and are XOR-ed with input_xor. +void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr); + +template +struct PackImpl, Scalar, + std::int8_t, std::int32_t> { + static_assert(std::is_same::value || + std::is_same::value, + ""); + using Layout = FixedKernelLayout; + static constexpr std::int8_t kInputXor = + std::is_same::value ? 0 : 0x80; + + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (SSE 4.2 8-bit)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[Layout::kCols * Layout::kRows]; + memset(zerobuf, packed_matrix->zero_point ^ kInputXor, + Layout::kCols * Layout::kRows * sizeof(Scalar)); + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. + std::int8_t* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + Pack8bitSse42(reinterpret_cast(src_ptr), kInputXor, + reinterpret_cast(zerobuf), src_stride, + remaining_src_cols, src_matrix.layout.rows, packed_ptr, + sums_ptr); + } + } +}; + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. 
+// +void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, float* packed_ptr); + +template <> +struct PackImpl, float, + float, float> { + using Layout = FixedKernelLayout; + static void Run(Tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (SSE 4.2 float)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + const float zerobuf[Layout::kCols] = { + 0.0f}; // Remainder default inits to 0.0f. + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. + float* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + PackFloatSse42(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_matrix.layout.rows, packed_ptr); + } + } +}; + +// Note that source and zero buffers can be uint8 type, but in the packing +// function are reinterpreted as int8, and are XOR-ed with input_xor. +void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, + std::int32_t* sums_ptr); + +template +struct PackImpl, Scalar, + std::int8_t, std::int32_t> { + static_assert(std::is_same::value || + std::is_same::value, + ""); + using Layout = FixedKernelLayout; + static constexpr std::int8_t kInputXor = + std::is_same::value ? 0 : 0x80; + + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX2 8-bit)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[Layout::kCols * Layout::kRows]; + memset(zerobuf, packed_matrix->zero_point ^ kInputXor, + Layout::kCols * Layout::kRows * sizeof(Scalar)); + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. 
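+ // block_col is always a multiple of Layout::kCols here (see the DCHECKs
+ // above), so masking with block_col_mask is defensive; the expression below
+ // just locates the start of this column block in the packed buffer. The same
+ // pattern recurs in the other PackImpl specializations in this file.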
+ std::int8_t* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + Pack8bitAvx2(reinterpret_cast(src_ptr), kInputXor, + reinterpret_cast(zerobuf), src_stride, + remaining_src_cols, src_matrix.layout.rows, packed_ptr, + sums_ptr); + } + } +}; + +void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, float* packed_ptr); + +template <> +struct PackImpl, float, + float, float> { + using Layout = FixedKernelLayout; + static void Run(Tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX2 float)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + const float zerobuf[Layout::kCols] = { + 0.0f}; // Remainder default inits to 0.0f. + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. + float* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + PackFloatAvx2(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_matrix.layout.rows, packed_ptr); + } + } +}; + +// Note that source and zero buffers can be uint8 type, but in the packing +// function are reinterpreted as int8, and are XOR-ed with input_xor. +void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr); + +template +struct PackImpl, + Scalar, std::int8_t, std::int32_t> { + static_assert(std::is_same::value || + std::is_same::value, + ""); + using Layout = FixedKernelLayout; + static constexpr int kHalfLayoutCols = + 8; // Half the number of cols in a block. + static constexpr std::int8_t kInputXor = + std::is_same::value ? 0 : 0x80; + + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX-512 8-bit)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[kHalfLayoutCols * Layout::kRows]; + memset(zerobuf, packed_matrix->zero_point ^ kInputXor, + kHalfLayoutCols * Layout::kRows * sizeof(Scalar)); + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. 
+ std::int8_t* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + Pack8bitAvx512(reinterpret_cast(src_ptr), kInputXor, + reinterpret_cast(zerobuf), src_stride, + remaining_src_cols, src_matrix.layout.rows, packed_ptr, + sums_ptr); + } + } +}; + +void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, float* packed_ptr); + +template <> +struct PackImpl, + float, float, float> { + static void Run(Tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX-512 float)"); + using Layout = FixedKernelLayout; + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + const float zerobuf[Layout::kCols] = { + 0.0f}; // Remainder default inits to 0.0f. + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. + float* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + PackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_matrix.layout.rows, packed_ptr); + } + } +}; + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// Note that source and zero buffers can be uint8 type, but in the packing +// function are reinterpreted as int8, and are XOR-ed with input_xor. +void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr); + +template +struct PackImpl, + Scalar, std::int8_t, std::int32_t> { + static_assert(std::is_same::value || + std::is_same::value, + ""); + using Layout = FixedKernelLayout; + static constexpr int kHalfLayoutCols = + 8; // Half the number of cols in a block. + static constexpr std::int8_t kInputXor = + std::is_same::value ? 0 : 0x80; + + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX-512 8-bit)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[kHalfLayoutCols * Layout::kRows]; + memset(zerobuf, packed_matrix->zero_point ^ kInputXor, + kHalfLayoutCols * Layout::kRows * sizeof(Scalar)); + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. 
+ std::int8_t* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + Pack8bitAvxVnni(reinterpret_cast(src_ptr), kInputXor, + reinterpret_cast(zerobuf), src_stride, + remaining_src_cols, src_matrix.layout.rows, packed_ptr, + sums_ptr); + } + } +}; + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, int src_rows, + float* packed_ptr); + +template <> +struct PackImpl, + float, float, float> { + static void Run(Tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX-512 float)"); + + using Layout = FixedKernelLayout; + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + const float zerobuf[Layout::kCols] = { + 0.0f}; // Remainder default inits to 0.0f. + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. + float* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + PackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_matrix.layout.rows, packed_ptr); + } + } +}; +#endif // RUY_PLATFORM(X86) + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_ diff --git a/ruy/path.h b/ruy/path.h new file mode 100644 index 0000000..7141b16 --- /dev/null +++ b/ruy/path.h @@ -0,0 +1,162 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_ + +#include + +#include "ruy/platform.h" +#include "ruy/size_util.h" + +namespace ruy { + +// A Path is a choice of implementation path, e.g. between reference code +// and optimized code, or between different optimized code paths using different +// instruction sets. +// +// It's important that any symbol that depends on such implementation +// details, is somehow templatized in such a Path, so that different Path values +// yield different symbols, so we never have the situation where a symbols has +// multiple inequivalent definitions based on which code paths are compiled. +// That would be a violation of the ODR (One Definition Rule) which is Undefined +// Behavior, and one of the most serious issues plaguing both Eigen and +// gemmlowp. 
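To make the ODR point concrete, here is a minimal sketch of the templatize-on-Path pattern. RunKernel is an invented name, not ruy code (ruy's real dispatch lives in dispatch.h and its kernels in kernel_*.h/cc); it uses the Path enum defined just below and the RUY_PLATFORM macro from platform.h.

#include "ruy/path.h"
#include "ruy/platform.h"

namespace ruy {

// Sketch only. Because the function is templated on Path,
// RunKernel<Path::kStandardCpp> and RunKernel<Path::kNeon> are distinct
// symbols, so translation units compiled with different instruction-set flags
// can never supply two inequivalent definitions of one symbol.
template <Path ThePath>
void RunKernel();

template <>
void RunKernel<Path::kStandardCpp>() {
  // Portable C++ implementation, always compiled.
}

#if RUY_PLATFORM(NEON)
template <>
void RunKernel<Path::kNeon>() {
  // NEON-specific implementation, only compiled where NEON is available.
}
#endif

}  // namespace ruy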
+// +// This enum is actually a bit-field: aside from kNone, all other values are +// powers of two, thus are one bit each. We define bit-wise operators below +// for this enum. Some places in Ruy accept a Path bit-field where multiple +// Paths may be selected, while some other places require a single Path (i.e. +// just one of the enum values here). Typically, user-facing parts of Ruy +// accept arbitrary bit-fields, allowing the user to compile support for +// multiple paths and to inform Ruy of all the paths that are to be enabled +// at runtime; then, typically in dispatch.h, we internally pick one +// specific path and from there on, internal Ruy code deals with only one +// path. +// +// When a user selects a set of compiled paths, Ruy internally dispatches to the +// "best" one, which typically means the newest optimized instructions for a +// given base architecture (such as ARM). Higher values of this enum correspond +// to "better" code paths within a given base architecture for which Ruy has +// optimized code paths. +// +// Values are reused across architectures. +// Rationale: Scale better to N architectures, it is good to have small values +// both for the compile-time logic to select paths, and when manually spelling +// out Path values, such as when invoking a test or benchmark. +enum class Path : std::uint8_t { + // This is a special null value, representing the absence of any path. + kNone = 0, + // Reference multiplication code. + // The main purpose of this path is to have a very simple standalone Mul + // implementation to check against. + // This path bypasses almost all of Ruy's internal implementation details. + // + // This is intended for testing/development. + kReference = 0x1, + // Standard C++ implementation of Ruy's architecture-specific parts. + // Unlike Path::kReference, this path exercises most of Ruy's internal logic. + // + // This is intended for testing/development. + kStandardCpp = 0x2, + +#if RUY_PLATFORM(ARM) + // ARM architectures. + // + // Optimized path using a widely available subset of ARM NEON instructions. + kNeon = 0x4, + // Optimized path making use of ARM NEON dot product instructions that are + // available on newer ARM cores. + kNeonDotprod = 0x8, +#endif // RUY_PLATFORM(ARM) + +#if RUY_PLATFORM(X86) + // x86 architectures. + // + // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / + // placeholder. + // Optimization is not finished. In particular the dimensions of the kernel + // blocks can be changed as desired. + // + // Optimized for SSE 4.2. + kSse42 = 0x4, + // Optimized for AVX2. + kAvx2 = 0x8, + // Optimized for AVX-512. + kAvx512 = 0x10, + // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / + // placeholder. + // Optimization is not finished. In particular the dimensions of the kernel + // blocks can be changed as desired. + // + // Optimized for AVX-VNNI. 
+ kAvxVnni = 0x20, +#endif // RUY_PLATFORM(X86) +}; + +inline constexpr Path operator|(Path p, Path q) { + return static_cast(static_cast(p) | + static_cast(q)); +} + +inline constexpr Path operator&(Path p, Path q) { + return static_cast(static_cast(p) & + static_cast(q)); +} + +inline constexpr Path operator^(Path p, Path q) { + return static_cast(static_cast(p) ^ + static_cast(q)); +} + +inline constexpr Path operator~(Path p) { + return static_cast(~static_cast(p)); +} + +inline Path GetMostSignificantPath(Path path_mask) { + return static_cast(round_down_pot(static_cast(path_mask))); +} + +// ruy::kAllPaths represents all Path's that make sense to on a given +// base architecture. +#ifdef __linux__ +#if RUY_PLATFORM(NEON_64) +constexpr Path kAllPaths = + Path::kReference | Path::kStandardCpp | Path::kNeon | Path::kNeonDotprod; +#elif RUY_PLATFORM(NEON_32) +constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon; +#elif RUY_PLATFORM(X86) +constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | + Path::kSse42 | Path::kAvx2 | Path::kAvx512 | + Path::kAvxVnni; +#else +constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp; +#endif +#else // __linux__ +// We don't know how to do runtime dotprod detection outside of linux for now. +#if RUY_PLATFORM(NEON) +constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon; +#elif RUY_PLATFORM(X86) +constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | + Path::kSse42 | Path::kAvx2 | Path::kAvx512 | + Path::kAvxVnni; +#else +constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp; +#endif +#endif // __linux__ + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_ diff --git a/ruy/platform.h b/ruy/platform.h new file mode 100644 index 0000000..d6e86e6 --- /dev/null +++ b/ruy/platform.h @@ -0,0 +1,156 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_ + +#ifdef __ANDROID_NDK__ +#include +#endif + +#define RUY_PLATFORM(X) ((RUY_DONOTUSEDIRECTLY_##X) != 0) + +// Architecture-level platform detection. +// +// Ruy requires these to be mutually exclusive. + +// Detect x86. +#if defined(__x86_64__) || defined(__i386__) || defined(__i386) || \ + defined(__x86__) || defined(__X86__) || defined(_X86_) || \ + defined(_M_IX86) || defined(_M_X64) +#define RUY_DONOTUSEDIRECTLY_X86 1 +#else +#define RUY_DONOTUSEDIRECTLY_X86 0 +#endif + +// Detect ARM 32-bit. +#ifdef __arm__ +#define RUY_DONOTUSEDIRECTLY_ARM_32 1 +#else +#define RUY_DONOTUSEDIRECTLY_ARM_32 0 +#endif + +// Detect ARM 64-bit. +#ifdef __aarch64__ +#define RUY_DONOTUSEDIRECTLY_ARM_64 1 +#else +#define RUY_DONOTUSEDIRECTLY_ARM_64 0 +#endif + +// Combined ARM. 
+#define RUY_DONOTUSEDIRECTLY_ARM \ + (RUY_DONOTUSEDIRECTLY_ARM_64 || RUY_DONOTUSEDIRECTLY_ARM_32) + +// Feature and capability platform detection. +// +// These are mostly sub-selections of architectures. + +// Detect NEON. Explicitly avoid emulation, or anything like it, on x86. +#if (defined(__ARM_NEON) || defined(__ARM_NEON__)) && !RUY_PLATFORM(X86) +#define RUY_DONOTUSEDIRECTLY_NEON 1 +#else +#define RUY_DONOTUSEDIRECTLY_NEON 0 +#endif + +// Define ARM 32-bit NEON. +#define RUY_DONOTUSEDIRECTLY_NEON_32 \ + (RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_32) + +// Define ARM 64-bit NEON. +// Note: NEON is implied by ARM64, so this define is redundant. +// It still allows some conveyance of intent. +#define RUY_DONOTUSEDIRECTLY_NEON_64 \ + (RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_64) + +// Disable X86 enhancements on __APPLE__ because b/138922878, see comment #8, we +// may only need to disable this on XCode <= 10.2. +// +// Disable when not using Clang-Linux, because too many user issues arise from +// compilation variations. +// +// NOTE: Consider guarding by !defined(__APPLE__) when removing Linux-only +// restriction. +// +// __EMSCRIPTEN__ is checked because the runtime Path resolution can use asm. +// +// The Android NDK logic excludes earlier and very broken versions of intrinsics +// headers. +#if defined(RUY_FORCE_ENABLE_X86_ENHANCEMENTS) || \ + (defined(__clang__) && (__clang_major__ >= 8) && defined(__linux__) && \ + !defined(__EMSCRIPTEN__) && \ + (!defined(__ANDROID_NDK__) || \ + (defined(__NDK_MAJOR__) && (__NDK_MAJOR__ >= 20)))) +#define RUY_DONOTUSEDIRECTLY_X86_ENHANCEMENTS 1 +#else +#define RUY_DONOTUSEDIRECTLY_X86_ENHANCEMENTS 0 +#endif + +// These CPU capabilities will all be true when Skylake, etc, are enabled during +// compilation. +#if RUY_PLATFORM(X86_ENHANCEMENTS) && RUY_PLATFORM(X86) && \ + defined(__AVX512F__) && defined(__AVX512DQ__) && defined(__AVX512CD__) && \ + defined(__AVX512BW__) && defined(__AVX512VL__) +#define RUY_DONOTUSEDIRECTLY_AVX512 1 +#else +#define RUY_DONOTUSEDIRECTLY_AVX512 0 +#endif + +#if RUY_PLATFORM(X86_ENHANCEMENTS) && RUY_PLATFORM(X86) && defined(__AVX2__) +#define RUY_DONOTUSEDIRECTLY_AVX2 1 +#else +#define RUY_DONOTUSEDIRECTLY_AVX2 0 +#endif + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// Note does not check for LZCNT or POPCNT. +#if defined(RUY_ENABLE_SSE_ENHANCEMENTS) && RUY_PLATFORM(X86_ENHANCEMENTS) && \ + RUY_PLATFORM(X86) && defined(__SSE4_2__) && defined(__FMA__) +#define RUY_DONOTUSEDIRECTLY_SSE42 1 +#else +#define RUY_DONOTUSEDIRECTLY_SSE42 0 +#endif + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// Note that defined(__AVX512VBMI2__) can be false for compilation with +// -march=cascadelake. +// TODO(b/146646451) Check if we should also gate on defined(__AVX512VBMI2__). +#if defined(RUY_ENABLE_VNNI_ENHANCEMENTS) && RUY_PLATFORM(AVX512) && \ + defined(__AVX512VNNI__) +#define RUY_DONOTUSEDIRECTLY_AVX_VNNI 1 +#else +#define RUY_DONOTUSEDIRECTLY_AVX_VNNI 0 +#endif + +// Detect APPLE. +#ifdef __APPLE__ +#define RUY_DONOTUSEDIRECTLY_APPLE 1 +#else +#define RUY_DONOTUSEDIRECTLY_APPLE 0 +#endif + +// Detect Emscripten, typically Wasm. 
+#ifdef __EMSCRIPTEN__ +#define RUY_DONOTUSEDIRECTLY_EMSCRIPTEN 1 +#else +#define RUY_DONOTUSEDIRECTLY_EMSCRIPTEN 0 +#endif + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_ diff --git a/ruy/pmu.cc b/ruy/pmu.cc new file mode 100644 index 0000000..1d87b1f --- /dev/null +++ b/ruy/pmu.cc @@ -0,0 +1,281 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/pmu.h" + +#include "ruy/check_macros.h" + +#ifdef __linux__ +#include +#include +#include +#include +#include + +#include +#endif + +#include +#include +#include +#include + +namespace ruy { + +// Linux-specific. Not ARM-specific. +#ifdef __linux__ +class PerfEvent { + public: + PerfEvent(std::uint32_t type, std::uint64_t config) { + perf_event_attr pe; + memset(&pe, 0, sizeof(pe)); + pe.size = sizeof(pe); + pe.type = type; + pe.config = config; + pe.disabled = 1; + pe.exclude_kernel = 1; + pe.exclude_hv = 1; + pe.inherit = 1; + fd_ = syscall(__NR_perf_event_open, &pe, 0, -1, -1, 0); + if (fd_ == -1) { + fprintf(stderr, "perf_event_open failed for config 0x%lx\n", + static_cast(config)); + // abort(); + } + } + + ~PerfEvent() { + RUY_CHECK(!started_); + close(fd_); + } + + void Start() { + RUY_CHECK(!started_); + started_ = true; + ioctl(fd_, PERF_EVENT_IOC_RESET, 0); + ioctl(fd_, PERF_EVENT_IOC_ENABLE, 0); + count_at_start_ = Read(); + } + + void Stop() { + RUY_CHECK(started_); + started_ = false; + ioctl(fd_, PERF_EVENT_IOC_DISABLE, 0); + count_at_stop_ = Read(); + } + + std::int64_t Count() const { + RUY_CHECK(!started_); + return count_at_stop_ - count_at_start_; + } + + private: + std::int64_t Read() const { + std::int64_t count; + RUY_CHECK_NE(read(fd_, &count, sizeof(count)), -1); + return count; + } + std::int64_t count_at_start_ = -1; + std::int64_t count_at_stop_ = -1; + bool started_ = false; + int fd_ = -1; +}; +#else +// Placeholder implementation to at least compile outside of linux. +#define PERF_TYPE_RAW 0 +class PerfEvent { + public: + PerfEvent(std::uint32_t, std::uint64_t) {} + ~PerfEvent() {} + void Start() {} + void Stop() {} + std::int64_t Count() const { return 0; } +}; +#endif + +// ARM-specific. Query ARM PMU counters as Linux perf events using +// PERF_TYPE_RAW. +namespace arm_pmuv3 { + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-const-variable" + +// These event numbers are listed in the ARMv8 architecture reference manual. 
+constexpr std::uint16_t L1I_CACHE_REFILL = 0x01; +constexpr std::uint16_t L1I_TLB_REFILL = 0x02; +constexpr std::uint16_t L1D_CACHE_REFILL = 0x03; +constexpr std::uint16_t L1D_CACHE = 0x04; +constexpr std::uint16_t L1D_TLB_REFILL = 0x05; +constexpr std::uint16_t LD_RETIRED = 0x06; +constexpr std::uint16_t ST_RETIRED = 0x07; +constexpr std::uint16_t INST_RETIRED = 0x08; +constexpr std::uint16_t EXC_TAKEN = 0x09; +constexpr std::uint16_t EXC_RETURN = 0x0A; +constexpr std::uint16_t CID_WRITE_RETIRED = 0x0B; +constexpr std::uint16_t PC_WRITE_RETIRED = 0x0C; +constexpr std::uint16_t BR_IMMED_RETIRED = 0x0D; +constexpr std::uint16_t BR_RETURN_RETIRED = 0x0E; +constexpr std::uint16_t UNALIGNED_LDST_RETIRED = 0x0F; +constexpr std::uint16_t BR_MIS_PRED = 0x10; +constexpr std::uint16_t CPU_CYCLES = 0x11; +constexpr std::uint16_t BR_PRED = 0x12; +constexpr std::uint16_t MEM_ACCESS = 0x13; +constexpr std::uint16_t L1I_CACHE = 0x14; +constexpr std::uint16_t L1D_CACHE_WB = 0x15; +constexpr std::uint16_t L2D_CACHE = 0x16; +constexpr std::uint16_t L2D_CACHE_REFILL = 0x17; +constexpr std::uint16_t L2D_CACHE_WB = 0x18; +constexpr std::uint16_t BUS_ACCESS = 0x19; +constexpr std::uint16_t MEMORY_ERROR = 0x1A; +constexpr std::uint16_t INST_SPEC = 0x1B; +constexpr std::uint16_t TTBR_WRITE_RETIRED = 0x1C; +constexpr std::uint16_t BUS_CYCLES = 0x1D; +constexpr std::uint16_t CHAIN = 0x1E; +constexpr std::uint16_t L1D_CACHE_ALLOCATE = 0x1F; +constexpr std::uint16_t L2D_CACHE_ALLOCATE = 0x20; +constexpr std::uint16_t BR_RETIRED = 0x21; +constexpr std::uint16_t BR_MIS_PRED_RETIRED = 0x22; +constexpr std::uint16_t STALL_FRONTEND = 0x23; +constexpr std::uint16_t STALL_BACKEND = 0x24; +constexpr std::uint16_t L1D_TLB = 0x25; +constexpr std::uint16_t L1I_TLB = 0x26; +constexpr std::uint16_t L2I_CACHE = 0x27; +constexpr std::uint16_t L2I_CACHE_REFILL = 0x28; +constexpr std::uint16_t L3D_CACHE_ALLOCATE = 0x29; +constexpr std::uint16_t L3D_CACHE_REFILL = 0x2A; +constexpr std::uint16_t L3D_CACHE = 0x2B; +constexpr std::uint16_t L3D_CACHE_WB = 0x2C; +constexpr std::uint16_t L2D_TLB_REFILL = 0x2D; +constexpr std::uint16_t L2I_TLB_REFILL = 0x2E; +constexpr std::uint16_t L2D_TLB = 0x2F; +constexpr std::uint16_t L2I_TLB = 0x30; +constexpr std::uint16_t LL_CACHE = 0x32; +constexpr std::uint16_t LL_CACHE_MISS = 0x33; +constexpr std::uint16_t DTLB_WALK = 0x34; +constexpr std::uint16_t LL_CACHE_RD = 0x36; +constexpr std::uint16_t LL_CACHE_MISS_RD = 0x37; + +// Additional implementation-defined events found by googling around. 
+constexpr std::uint16_t L1D_CACHE_RD = 0x40; +constexpr std::uint16_t L1D_CACHE_REFILL_RD = 0x42; +constexpr std::uint16_t L1D_TLB_REFILL_RD = 0x4C; +constexpr std::uint16_t L1D_TLB_RD = 0x4E; +constexpr std::uint16_t L2D_CACHE_RD = 0x50; +constexpr std::uint16_t L2D_CACHE_REFILL_RD = 0x52; +constexpr std::uint16_t BUS_ACCESS_RD = 0x60; +constexpr std::uint16_t MEM_ACCESS_RD = 0x66; +constexpr std::uint16_t L3D_CACHE_RD = 0xA0; +constexpr std::uint16_t L3D_CACHE_REFILL_RD = 0xA2; + +#pragma GCC diagnostic pop + +}; // namespace arm_pmuv3 + +class PmuEventsPrivate { + public: + PmuEventsPrivate() + : l1d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_CACHE_REFILL), + l2d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_CACHE_REFILL), + l3d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L3D_CACHE_REFILL), + ll_cache_miss(PERF_TYPE_RAW, arm_pmuv3::LL_CACHE_MISS), + l1d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_TLB_REFILL), + l2d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_TLB_REFILL), + stall_frontend(PERF_TYPE_RAW, arm_pmuv3::STALL_FRONTEND), + stall_backend(PERF_TYPE_RAW, arm_pmuv3::STALL_BACKEND), + br_mis_pred(PERF_TYPE_RAW, arm_pmuv3::BR_MIS_PRED) {} + + private: + friend class PmuEvents; + PerfEvent l1d_cache_refill; + PerfEvent l2d_cache_refill; + PerfEvent l3d_cache_refill; + PerfEvent ll_cache_miss; + PerfEvent l1d_tlb_refill; + PerfEvent l2d_tlb_refill; + PerfEvent stall_frontend; + PerfEvent stall_backend; + PerfEvent br_mis_pred; +}; + +PmuEvents::PmuEvents() : priv(new PmuEventsPrivate) {} +PmuEvents::~PmuEvents() { delete priv; } + +void PmuEvents::StartRecording() { + priv->l1d_cache_refill.Start(); + priv->l2d_cache_refill.Start(); + priv->l3d_cache_refill.Start(); + priv->ll_cache_miss.Start(); + priv->l1d_tlb_refill.Start(); + priv->l2d_tlb_refill.Start(); + priv->stall_frontend.Start(); + priv->stall_backend.Start(); + priv->br_mis_pred.Start(); +} + +void PmuEvents::StopRecording() { + priv->l1d_cache_refill.Stop(); + priv->l2d_cache_refill.Stop(); + priv->l3d_cache_refill.Stop(); + priv->ll_cache_miss.Stop(); + priv->l1d_tlb_refill.Stop(); + priv->l2d_tlb_refill.Stop(); + priv->stall_frontend.Stop(); + priv->stall_backend.Stop(); + priv->br_mis_pred.Stop(); +} + +float PmuEvents::BranchMispredictionCount() const { + return static_cast(priv->br_mis_pred.Count()); +} + +float PmuEvents::FrontendStallCount() const { + return static_cast(priv->stall_frontend.Count()); +} + +float PmuEvents::BackendStallCount() const { + return static_cast(priv->stall_backend.Count()); +} + +float PmuEvents::L1RefillCount() const { + return static_cast(priv->l1d_cache_refill.Count()); +} + +float PmuEvents::L2RefillCount() const { + return static_cast(priv->l2d_cache_refill.Count()); +} + +float PmuEvents::L3RefillCount() const { + // Important: this was discovered in the context of the above experiments, + // which also tested the _RD variants of these counters. So it's possible that + // it's just not needed here with the default (non _RD) counters. + // + // Some CPUs implement LL_CACHE_MISS[_RD], some implement + // L3D_CACHE_REFILL[_RD]. It seems that either one of these two counters is + // zero, or they roughly both agree with each other. Therefore, taking the max + // of them is a reasonable way to get something more portable across various + // CPUs. 
+ return static_cast( + std::max(priv->l3d_cache_refill.Count(), priv->ll_cache_miss.Count())); +} + +float PmuEvents::L1TLBRefillCount() const { + return static_cast(priv->l1d_tlb_refill.Count()); +} + +float PmuEvents::L2TLBRefillCount() const { + return static_cast(priv->l2d_tlb_refill.Count()); +} + +} // namespace ruy diff --git a/ruy/pmu.h b/ruy/pmu.h new file mode 100644 index 0000000..721c1d5 --- /dev/null +++ b/ruy/pmu.h @@ -0,0 +1,44 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_ + +namespace ruy { + +class PmuEventsPrivate; + +class PmuEvents { + public: + PmuEvents(); + ~PmuEvents(); + void StartRecording(); + void StopRecording(); + float L1RefillCount() const; + float L2RefillCount() const; + float L3RefillCount() const; + float BranchMispredictionCount() const; + float FrontendStallCount() const; + float BackendStallCount() const; + float L1TLBRefillCount() const; + float L2TLBRefillCount() const; + + private: + PmuEventsPrivate* priv = nullptr; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_ diff --git a/ruy/prepack.h b/ruy/prepack.h new file mode 100644 index 0000000..4bfc9ed --- /dev/null +++ b/ruy/prepack.h @@ -0,0 +1,108 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implementation of low-level pre-packing API. 
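Before the pre-packing implementation that follows, here is a minimal sketch of how the `PmuEvents` interface declared in `ruy/pmu.h` above might be driven around a benchmarked kernel. The names `BenchmarkWithPmu` and `kernel_to_measure` are illustrative only and are not part of ruy; on non-Linux builds the placeholder `PerfEvent` implementation makes every count read as zero.

```c++
// Usage sketch only; BenchmarkWithPmu and kernel_to_measure are illustrative
// names, not part of ruy.
#include <cstdio>

#include "ruy/pmu.h"

void BenchmarkWithPmu(void (*kernel_to_measure)()) {
  ruy::PmuEvents pmu_events;
  pmu_events.StartRecording();
  kernel_to_measure();
  pmu_events.StopRecording();
  // Each counter is cumulative over the recorded region.
  std::printf("L1/L2/L3 refills: %g / %g / %g\n", pmu_events.L1RefillCount(),
              pmu_events.L2RefillCount(), pmu_events.L3RefillCount());
  std::printf("frontend/backend stalls: %g / %g, branch mispredictions: %g\n",
              pmu_events.FrontendStallCount(), pmu_events.BackendStallCount(),
              pmu_events.BranchMispredictionCount());
}
```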
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/context.h" +#include "ruy/dispatch.h" +#include "ruy/internal_matrix.h" +#include "ruy/matrix.h" +#include "ruy/path.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/side_pair.h" +#include "ruy/spec.h" +#include "ruy/trmul.h" +#include "ruy/trmul_params.h" +#include "ruy/tune.h" + +namespace ruy { + +template +void PrePackForMulInternal(const Matrix& lhs, + const Matrix& rhs, const Spec& spec, + Context* context, Matrix* dst, + SidePair prepacked, + std::function alloc_fn) { + profiler::ScopeLabel label("PrePackForMul"); + Path the_path = context->GetPathToTake(); + RUY_CHECK_NE(the_path, Path::kReference); + constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; + Matrix transposed_lhs(lhs); + Transpose(&transposed_lhs); + TrMulParams params; + CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, + the_path, ¶ms); + + const SidePair origin{0, 0}; + const SidePair rounded_dims{params.packed[Side::kLhs].layout.cols, + params.packed[Side::kRhs].layout.cols}; + + Tuning tuning = context->GetMainThreadTuning(); + for (Side side : {Side::kLhs, Side::kRhs}) { + if (prepacked[side]) { + prepacked[side]->data_size = DataSize(params.packed[side]); + prepacked[side]->sums_size = SumsSize(params.packed[side]); + prepacked[side]->data = alloc_fn(prepacked[side]->data_size); + prepacked[side]->sums = alloc_fn(prepacked[side]->sums_size); + params.packed[side].data = prepacked[side]->data; + params.packed[side].sums = prepacked[side]->sums; + params.RunPack(side, tuning, origin[side], rounded_dims[side]); + } + } +} + +template +void MulWithPrepackedInternal(const Matrix& lhs, + const Matrix& rhs, const Spec& spec, + Context* context, Matrix* dst, + SidePair prepacked) { + profiler::ScopeLabel label("MulWithPrepacked"); + + EnforceLayoutSupport(lhs.layout, rhs.layout, dst->layout); + EnforceZeroPointSupport(lhs.zero_point, rhs.zero_point, + dst->zero_point); + + Path the_path = context->GetPathToTake(); + RUY_CHECK_NE(the_path, Path::kReference); + constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; + Matrix transposed_lhs(lhs); + Transpose(&transposed_lhs); + TrMulParams params; + CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, + the_path, ¶ms); + + for (Side side : {Side::kLhs, Side::kRhs}) { + if (prepacked[side]) { + params.packed[side].data = prepacked[side]->data; + params.packed[side].sums = prepacked[side]->sums; + params.is_prepacked[side] = true; + } + } + + TrMul(¶ms, context); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ diff --git a/ruy/prepacked_cache.cc b/ruy/prepacked_cache.cc new file mode 100644 index 0000000..020fdf7 --- /dev/null +++ b/ruy/prepacked_cache.cc @@ -0,0 +1,82 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/prepacked_cache.h" + +#include "ruy/matrix.h" +#include "ruy/profiler/instrumentation.h" + +namespace ruy { + +using CacheIterator = PrepackedCache::CacheIterator; + +// Looks for an entry with `key`. If found, update its time stamp. +CacheIterator PrepackedCache::FindAndUpdate(const CacheKey &key) { + auto itr = cache_.find(key); + // If found, update with new access time for this entry. + if (itr != cache_.end()) { + const TimePoint time = CacheNow(); + itr->second.second = time; + } + return itr; +} + +void PrepackedCache::Insert(const CacheKey &key, + const PrepackedMatrix &matrix) { + // Calculate size of this new item. + const size_t size_bytes = matrix.data_size + matrix.sums_size; + + // While we are above the threshold of ejection, eject the LRU entry. + while (!cache_.empty() && + ((TotalSize() + size_bytes) > ejection_threshold_)) { + EjectOne(); + } + DoInsert(key, matrix); + cache_size_ += matrix.data_size + matrix.sums_size; +} + +void PrepackedCache::EjectOne() { + TimePoint oldest_time = CacheNow(); + auto oldest = cache_.begin(); + { + profiler::ScopeLabel label("PepackedCacheEjection"); + for (auto itr = cache_.begin(); itr != cache_.end(); ++itr) { + if (itr->second.second < oldest_time) { + oldest_time = itr->second.second; + oldest = itr; + } + } + } + PrepackedMatrix &pmatrix = oldest->second.first; + cache_size_ -= pmatrix.data_size; + cache_size_ -= pmatrix.sums_size; + allocator_.Free(pmatrix.data); + allocator_.Free(pmatrix.sums); + cache_.erase(oldest); +} + +void PrepackedCache::AllocatePrepackedMatrix(PrepackedMatrix *pmatrix) { + pmatrix->data = allocator_.Alloc(pmatrix->data_size); + pmatrix->sums = allocator_.Alloc(pmatrix->sums_size); +} + +void PrepackedCache::DoInsert(const CacheKey &key, + const PrepackedMatrix &matrix) { + const TimePoint t = CacheNow(); + const MatrixWithTimeStamp mts({matrix, t}); + cache_.insert({key, mts}); +} + +} // namespace ruy diff --git a/ruy/prepacked_cache.h b/ruy/prepacked_cache.h new file mode 100644 index 0000000..eedd7e4 --- /dev/null +++ b/ruy/prepacked_cache.h @@ -0,0 +1,130 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_ + +#include +#include +#include +#include +#include + +#include "ruy/allocator.h" +#include "ruy/matrix.h" +#include "ruy/time.h" + +namespace ruy { + +namespace detail { + +// Tracks a set of blocks allocated from the underlying system allocator. 
+class SystemBlockAllocator { + public: + void *Alloc(std::ptrdiff_t num_bytes) { + void *p = detail::SystemAlignedAlloc(num_bytes); + blocks_.push_back(p); + return p; + } + + void Free(void *block) { + for (auto it = blocks_.begin(); it != blocks_.end(); ++it) { + if (*it == block) { + detail::SystemAlignedFree(block); + blocks_.erase(it); + return; + } + } + RUY_DCHECK(false); // Trying to free pointer we did not allocate. + } + + ~SystemBlockAllocator() { + for (void *block : blocks_) { + detail::SystemAlignedFree(block); + } + } + + private: + std::vector blocks_; +}; + +} // namespace detail + +enum CachePolicy { kNoCache, kCacheLHSOnNarrowMul }; + +// "Low effort" Least Recently Used Cache for Prepacked Matrices +// A cache mechanism for prepacked matrices that ejects oldest entries. +// The implementation is "low effort" in the following ways: +// - we just linearly search for the oldest entry when doing an ejection +// - the ejection policy is very simple: if the new size would be above the +// . threshold, we will eject entries until the size is below the threshold. +// Current use cases (RNNs with GEMV operations) indicate that ejection is rare +// and memory constraints are tight, so we devote no additional storage to the +// LRU mechanism and accept O(n) search to eject oldest entry. In practice, +// the number of total entries has not been shown to be large. +// This class is not thread safe. In Ruy, memory allocation for packed matrices +// is done in a single threaded context and the actual packing activity may +// be done in a multi-threaded context. +class PrepackedCache { + public: + static constexpr int kDefaultEjectionThresholdBytes = 1 << 28; + + using CacheKey = std::pair; + + using MatrixWithTimeStamp = std::pair; + + using CacheIterator = std::map::const_iterator; + + using AlignedAllocator = detail::AlignedAllocator; + + explicit PrepackedCache( + int32_t ejection_threshold = kDefaultEjectionThresholdBytes) + : ejection_threshold_(ejection_threshold), cache_size_(0) {} + + // Looks for an entry with `key`. If found, update its time stamp. + CacheIterator FindAndUpdate(const CacheKey &key); + + // Returns end iterator for internal cache. The iterator type is appropriate + // to use with `FindAndUpdate`. + CacheIterator cend() const { return cache_.end(); } + + // Returns the total size (in bytes) of data held in this cache. + int TotalSize() const { return cache_size_; } + + // All calls to get current TimePoints go through here. + // TODO(b/145625614) Profile timestamps on relevant models to see if + // this level of granularity is sufficient. CoarseNow is cheap so + // it would be nice to keep it. + TimePoint CacheNow() const { return CoarseNow(); } + + // Performs the memory allocation for the `data` and `sums` members of a + // PrepackedMatrix. + void AllocatePrepackedMatrix(PrepackedMatrix *pmatrix); + + // Adds the PrepackedMatrix to the cache, possibly ejecting other values. 
+ void Insert(const CacheKey &key, const PrepackedMatrix &matrix); + + private: + void EjectOne(); + void DoInsert(const CacheKey &key, const PrepackedMatrix &matrix); + detail::SystemBlockAllocator allocator_; + std::map cache_; + const int32_t ejection_threshold_; + size_t cache_size_; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_ diff --git a/ruy/prepacked_cache_test.cc b/ruy/prepacked_cache_test.cc new file mode 100644 index 0000000..a65841e --- /dev/null +++ b/ruy/prepacked_cache_test.cc @@ -0,0 +1,210 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/prepacked_cache.h" + +#include // NOLINT(build/c++11) + +#include "testing/base/public/gunit.h" +#include "ruy/ruy.h" +#include "ruy/time.h" + +namespace ruy { +namespace { + +TEST(PrepackedCacheTest, TestCacheEjection) { + // Create the cache. + PrepackedCache prepacked_cache(32); + // Allocate the prepacked matrix. + PrepackedMatrix mat1; + mat1.data_size = 16; + mat1.sums_size = 8; + prepacked_cache.AllocatePrepackedMatrix(&mat1); + auto cache_key1 = std::make_pair(nullptr, mat1.data); + prepacked_cache.Insert(cache_key1, mat1); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + // Get a time point after the insertion into the cache. + TimePoint current = CoarseNow(); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + PrepackedCache::CacheIterator itr = prepacked_cache.FindAndUpdate(cache_key1); + EXPECT_NE(itr, prepacked_cache.cend()); + // By finding mat1, we updated its timestamp. Verify that `current` is older + // than the time stamp now associated with mat1. + EXPECT_LT(current, itr->second.second); + PrepackedMatrix mat2; + mat2.data_size = 8; + mat2.sums_size = 4; + prepacked_cache.AllocatePrepackedMatrix(&mat2); + + auto cache_key2 = std::make_pair(nullptr, mat2.data); + prepacked_cache.Insert(cache_key2, mat2); + // The cache size was exceeded by inserting mat2. Ensure that mat1 was + // ejected. + EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); +} + +TEST(PrepackedCacheTest, TestCacheBasic) { + // Create the cache. + PrepackedCache prepacked_cache(48); + // Allocate the prepacked matrix. + PrepackedMatrix mat1; + mat1.data_size = 16; + mat1.sums_size = 8; + prepacked_cache.AllocatePrepackedMatrix(&mat1); + + auto cache_key1 = std::make_pair(nullptr, mat1.data); + prepacked_cache.Insert(cache_key1, mat1); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); + + PrepackedMatrix mat2; + mat2.data_size = 8; + mat2.sums_size = 4; + prepacked_cache.AllocatePrepackedMatrix(&mat2); + + auto cache_key2 = std::make_pair(nullptr, mat2.data); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + prepacked_cache.Insert(cache_key2, mat2); + // The cache size was not exceeded by inserting mat2. 
Ensure that mat1 was not + // ejected. + EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); +} + +TEST(PrepackedCacheTest, TestCacheEjection2) { + // Create the cache. + PrepackedCache prepacked_cache(73); + // Allocate the prepacked matrix 1. + PrepackedMatrix mat1; + mat1.data_size = 16; + mat1.sums_size = 8; + prepacked_cache.AllocatePrepackedMatrix(&mat1); + auto cache_key1 = std::make_pair(nullptr, mat1.data); + prepacked_cache.Insert(cache_key1, mat1); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + // Allocate the prepacked matrix 2. + PrepackedMatrix mat2; + mat2.data_size = 16; + mat2.sums_size = 8; + prepacked_cache.AllocatePrepackedMatrix(&mat2); + auto cache_key2 = std::make_pair(nullptr, mat2.data); + prepacked_cache.Insert(cache_key2, mat2); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + // Allocate the prepacked matrix 3. + PrepackedMatrix mat31; + mat31.data_size = 16; + mat31.sums_size = 8; + prepacked_cache.AllocatePrepackedMatrix(&mat31); + auto cache_key3 = std::make_pair(nullptr, mat31.data); + prepacked_cache.Insert(cache_key3, mat31); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + // The next insertion will cause the cache size to go over the ejection + // threshold. Touch matrix 1 and matrix 3 to make matrix 2 the oldest + EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); + EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend()); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + // Allocate the prepacked matrix 4. + PrepackedMatrix mat4; + mat4.data_size = 16; + mat4.sums_size = 8; + prepacked_cache.AllocatePrepackedMatrix(&mat4); + auto cache_key4 = std::make_pair(nullptr, mat4.data); + prepacked_cache.Insert(cache_key4, mat4); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + // Ensure that mat2 was ejected, but mat1, mat3, and mat4 were not. + EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key2), prepacked_cache.cend()); + EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend()); + EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); + EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key4), prepacked_cache.cend()); +} + +TEST(PrepackedCacheTest, TestCacheOnCacheable) { + // Create context and set the cache policy + ruy::Context context; + context.cache_policy = ruy::kCacheLHSOnNarrowMul; + PrepackedCache* cache = context.GetPrepackedCache(); + EXPECT_EQ(cache->TotalSize(), 0); + + const float lhs_data[] = {1, 2, 3, 4}; + const float rhs_data[] = {1, 2}; + float dst_data[4]; + + ruy::Matrix lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); + lhs.data = lhs_data; + ruy::Matrix rhs; + ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout); + rhs.data = rhs_data; + ruy::Matrix dst; + ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout); + dst.data = dst_data; + + ruy::BasicSpec spec; + // Perform the multiplication and confirm no caching occurred. + ruy::Mul(lhs, rhs, spec, &context, &dst); + EXPECT_EQ(cache->TotalSize(), 0); + + // Set cacheable for the LHS, repeat the multiplication, and see + // that caching did occur. 
+ lhs.cacheable = true; + ruy::Mul(lhs, rhs, spec, &context, &dst); + EXPECT_NE(cache->TotalSize(), 0); +} + +TEST(PrepackedCacheTest, TestClearCache) { + // Create context and set the cache policy + ruy::Context context; + context.cache_policy = ruy::kCacheLHSOnNarrowMul; + PrepackedCache* cache = context.GetPrepackedCache(); + EXPECT_EQ(cache->TotalSize(), 0); + + const float lhs_data[] = {1, 2, 3, 4}; + const float rhs_data[] = {1, 2}; + float dst_data[4]; + + ruy::Matrix lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); + lhs.data = lhs_data; + ruy::Matrix rhs; + ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout); + rhs.data = rhs_data; + ruy::Matrix dst; + ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout); + dst.data = dst_data; + + ruy::BasicSpec spec; + // Set cacheable for the LHS and see that caching occurs. + lhs.cacheable = true; + ruy::Mul(lhs, rhs, spec, &context, &dst); + EXPECT_NE(cache->TotalSize(), 0); + + // Clear the cache via the Context. + context.ClearPrepackedCache(); + // Verify that the cache is now empty. + cache = context.GetPrepackedCache(); + EXPECT_EQ(cache->TotalSize(), 0); +} + +} // namespace +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/profiler/BUILD b/ruy/profiler/BUILD new file mode 100644 index 0000000..b0af802 --- /dev/null +++ b/ruy/profiler/BUILD @@ -0,0 +1,52 @@ +# A minimalistic profiler sampling pseudo-stacks + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +config_setting( + name = "ruy_profiler", + define_values = {"ruy_profiler": "true"}, +) + +cc_library( + name = "instrumentation", + srcs = ["instrumentation.cc"], + hdrs = ["instrumentation.h"], + defines = select({ + ":ruy_profiler": ["RUY_PROFILER"], + "//conditions:default": [], + }), +) + +cc_library( + name = "profiler", + srcs = [ + "profiler.cc", + "treeview.cc", + ], + hdrs = [ + "profiler.h", + "treeview.h", + ], + deps = [":instrumentation"], +) + +cc_library( + name = "test_instrumented_library", + testonly = True, + srcs = ["test_instrumented_library.cc"], + hdrs = ["test_instrumented_library.h"], + deps = [":instrumentation"], +) + +cc_test( + name = "test", + srcs = ["test.cc"], + deps = [ + ":profiler", + ":test_instrumented_library", + "@com_google_googletest//:gtest", + ], +) diff --git a/ruy/profiler/README.md b/ruy/profiler/README.md new file mode 100644 index 0000000..8d79025 --- /dev/null +++ b/ruy/profiler/README.md @@ -0,0 +1,149 @@ +# A minimalistic profiler sampling pseudo-stacks + +## Overview + +The present directory is the "ruy profiler". As a time profiler, it allows to +measure where code is spending time. + +Contrary to most typical profilers, what it samples is not real call stacks, but +"pseudo-stacks" which are just simple data structures constructed from within +the program being profiled. Using this profiler requires manually instrumenting +code to construct such pseudo-stack information. + +Another unusual characteristic of this profiler is that it uses only the C++11 +standard library. It does not use any non-portable feature, in particular it +does not rely on signal handlers. The sampling is performed by a thread, the +"profiler thread". + +A discussion of pros/cons of this approach is appended below. + +## How to use this profiler + +### How to instrument code + +An example of instrumented code is given in `test_instrumented_library.cc`. 
+
+Code is instrumented by constructing `ScopeLabel` objects. These are RAII
+helpers, ensuring that the thread pseudo-stack contains the label during their
+lifetime. In the most common use case, one would construct such an object at
+the start of a function, so that its scope is the function scope, making it
+possible to measure how much time is spent in that function.
+
+```c++
+#include "ruy/profiler/instrumentation.h"
+
+...
+
+void SomeFunction() {
+  ruy::profiler::ScopeLabel function_label("SomeFunction");
+  ... do something ...
+}
+```
+
+A `ScopeLabel` may however have any scope, for instance:
+
+```c++
+if (some_case) {
+  ruy::profiler::ScopeLabel extra_work_label("Some more work");
+  ... do some more work ...
+}
+```
+
+The string passed to the `ScopeLabel` constructor must be just a pointer to a
+literal string (a `char*` pointer). The profiler will assume that these
+pointers stay valid until the profile is finalized.
+
+However, that literal string may be a `printf` format string, and labels may
+have up to 4 parameters, of type `int`. For example:
+
+```c++
+void SomeFunction(int size) {
+  ruy::profiler::ScopeLabel function_label("SomeFunction (size=%d)", size);
+  ...
+}
+```
+
+### How to run the profiler
+
+Profiling instrumentation is a no-op unless the preprocessor token
+`RUY_PROFILER` is defined, so defining it is the first step when actually
+profiling. When building with Bazel, the preferred way to enable that is to
+pass this flag on the Bazel command line:
+
+```
+--define=ruy_profiler=true
+```
+
+To actually profile a code scope, it is enough to construct a `ScopeProfile`
+object, also a RAII helper. It will start the profiler on construction, and on
+destruction it will terminate the profiler and report the profile treeview on
+standard output by default. Example:
+
+```c++
+void SomeProfiledBenchmark() {
+  ruy::profiler::ScopeProfile profile;
+
+  CallSomeInstrumentedCode();
+}
+```
+
+An example is provided by the `:test` target in the present directory. Run it
+with `--define=ruy_profiler=true` as explained above:
+
+```
+bazel run -c opt \
+  --define=ruy_profiler=true \
+  //tensorflow/lite/experimental/ruy/profiler:test
+```
+
+The default behavior of dumping the treeview on standard output may be
+overridden by passing a pointer to a `TreeView` object to the `ScopeProfile`
+constructor. This causes the treeview to be stored in that `TreeView` object,
+where it may be accessed and manipulated using the functions declared in
+`treeview.h`. The aforementioned `:test` provides examples of doing so.
+
+## Advantages and disadvantages
+
+Compared to a traditional profiler, e.g. Linux's "perf", this kind of profiler
+has the following disadvantages:
+
+* Requires manual instrumentation of the code being profiled.
+* Substantial overhead, modifying the performance characteristics of the code
+  being measured.
+* Questionable accuracy.
+
+But it also has the following advantages:
+
+* Profiling can be driven from within a benchmark program, allowing the entire
+  profiling procedure to be a single command line.
+* Not relying on symbol information removes exposure to toolchain details and
+  means less hassle in some build environments, especially embedded/mobile
+  (single command line to run and profile, no symbol files required).
+* Fully portable (all of this is standard C++11).
+* Fully testable (see `:test`). Profiling becomes just another feature of the
+  code like any other.
+* Customized instrumentation can result in easier-to-read treeviews (only
+  relevant functions, and custom labels may be more readable than function
+  names).
+* Parametrized/formatted labels make it possible to do things that aren't
+  possible with call-stack-sampling profilers. For example, a profile where
+  much time is spent in matrix multiplications can be broken down by the
+  matrix multiplication shapes involved.
+
+The philosophy underlying this profiler is that software performance depends
+on software engineers profiling often, and a key factor limiting that in
+practice is how difficult and cumbersome it is to profile with more serious
+profilers such as Linux's "perf", especially in embedded/mobile development:
+multiple command lines are involved to copy symbol files to devices, retrieve
+profile data from the device, etc. In that context, it is useful to make
+profiling as easy as benchmarking, even on embedded targets, even if the price
+to pay for that is lower accuracy, higher overhead, and some intrusive
+instrumentation requirement.
+
+Another key aspect determining which profiling approach is suitable for a
+given context is whether one already has a priori knowledge of where much of
+the time is likely being spent. When one has such a priori knowledge, it is
+feasible to instrument the known possibly-critical code as per the present
+approach. On the other hand, in situations where one doesn't have such
+a priori knowledge, a real profiler such as Linux's "perf" can immediately
+produce a profile of real stacks from just the symbol information generated by
+the toolchain.
diff --git a/ruy/profiler/instrumentation.cc b/ruy/profiler/instrumentation.cc
new file mode 100644
index 0000000..f03f667
--- /dev/null
+++ b/ruy/profiler/instrumentation.cc
@@ -0,0 +1,130 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "ruy/profiler/instrumentation.h" + +#ifdef RUY_PROFILER + +namespace ruy { +namespace profiler { + +void Label::operator=(const Label& other) { + format_ = other.format_; + args_count_ = other.args_count_; + for (int i = 0; i < args_count_; i++) { + args_[i] = other.args_[i]; + } +} + +bool Label::operator==(const Label& other) const { + if (std::string(format_) != std::string(other.format_)) { + return false; + } + if (args_count_ != other.args_count_) { + return false; + } + for (int i = 0; i < args_count_; i++) { + if (args_[i] != other.args_[i]) { + return false; + } + } + return true; +} + +std::string Label::Formatted() const { + static constexpr int kBufSize = 256; + char buf[kBufSize]; + if (args_count_ == 0) { + return format_; + } + if (args_count_ == 1) { + snprintf(buf, kBufSize, format_, args_[0]); + } else if (args_count_ == 2) { + snprintf(buf, kBufSize, format_, args_[0], args_[1]); + } else if (args_count_ == 3) { + snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2]); + } else if (args_count_ == 4) { + snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2], args_[3]); + } else { + abort(); + } + return buf; +} + +namespace detail { + +std::mutex* GlobalsMutex() { + static std::mutex mutex; + return &mutex; +} + +bool& GlobalIsProfilerRunning() { + static bool b; + return b; +} + +std::vector* GlobalAllThreadStacks() { + static std::vector all_stacks; + return &all_stacks; +} + +ThreadStack* ThreadLocalThreadStack() { + thread_local static ThreadStack thread_stack; + return &thread_stack; +} + +ThreadStack::ThreadStack() { + std::lock_guard lock(*GlobalsMutex()); + static std::uint32_t global_next_thread_stack_id = 0; + stack_.id = global_next_thread_stack_id++; + GlobalAllThreadStacks()->push_back(this); +} + +ThreadStack::~ThreadStack() { + std::lock_guard lock(*GlobalsMutex()); + std::vector* all_stacks = GlobalAllThreadStacks(); + for (auto it = all_stacks->begin(); it != all_stacks->end(); ++it) { + if (*it == this) { + all_stacks->erase(it); + return; + } + } +} +int GetBufferSize(const Stack& stack) { + return sizeof(stack.id) + sizeof(stack.size) + + stack.size * sizeof(stack.labels[0]); +} + +void CopyToBuffer(const Stack& stack, char* dst) { + memcpy(dst, &stack.id, sizeof(stack.id)); + dst += sizeof(stack.id); + memcpy(dst, &stack.size, sizeof(stack.size)); + dst += sizeof(stack.size); + memcpy(dst, stack.labels, stack.size * sizeof(stack.labels[0])); +} + +void ReadFromBuffer(const char* src, Stack* stack) { + memcpy(&stack->id, src, sizeof(stack->id)); + src += sizeof(stack->id); + memcpy(&stack->size, src, sizeof(stack->size)); + src += sizeof(stack->size); + memcpy(stack->labels, src, stack->size * sizeof(stack->labels[0])); +} + +} // namespace detail +} // namespace profiler +} // namespace ruy + +#endif diff --git a/ruy/profiler/instrumentation.h b/ruy/profiler/instrumentation.h new file mode 100644 index 0000000..a9046d4 --- /dev/null +++ b/ruy/profiler/instrumentation.h @@ -0,0 +1,203 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_ + +#ifdef RUY_PROFILER +#include +#include +#include +#endif + +namespace ruy { +namespace profiler { + +#ifdef RUY_PROFILER + +// A label is how a code scope is annotated to appear in profiles. +// The stacks that are sampled by the profiler are stacks of such labels. +// A label consists of a literal string, plus optional integer arguments. +class Label { + public: + Label() {} + template + explicit Label(Args... args) { + Set(args...); + } + void Set(const char* format) { + format_ = format; + args_count_ = 0; + } + template + void Set(const char* format, Args... args) { + format_ = format; + args_count_ = sizeof...(args); + SetArgs(0, args...); + } + + void operator=(const Label& other); + + bool operator==(const Label& other) const; + + std::string Formatted() const; + const char* format() const { return format_; } + + private: + void SetArgs(int position, int arg0) { args_[position] = arg0; } + + template + void SetArgs(int position, int arg0, Args... args) { + SetArgs(position, arg0); + SetArgs(position + 1, args...); + } + + static constexpr int kMaxArgs = 4; + const char* format_ = nullptr; + int args_count_ = 0; + int args_[kMaxArgs]; +}; + +namespace detail { + +// Forward-declaration, see class ThreadStack below. +class ThreadStack; + +bool& GlobalIsProfilerRunning(); + +// Returns the global vector of pointers to all stacks, there being one stack +// per thread executing instrumented code. +std::vector* GlobalAllThreadStacks(); + +// Returns the mutex to be locked around any access to GlobalAllThreadStacks(). +std::mutex* GlobalsMutex(); + +// Returns the thread-local stack, specific to the current thread. +ThreadStack* ThreadLocalThreadStack(); + +// This 'stack' is what may be more appropriately called a 'pseudostack': +// It contains Label entries that are 'manually' entered by instrumentation +// code. It's unrelated to real call stacks. +struct Stack { + std::uint32_t id = 0; + static constexpr int kMaxSize = 64; + int size = 0; + Label labels[kMaxSize]; +}; + +// Returns the buffer byte size required by CopyToSample. +int GetBufferSize(const Stack& stack); + +// Copies this Stack into a byte buffer, called a 'sample'. +void CopyToBuffer(const Stack& stack, char* dst); + +// Populates this Stack from an existing sample buffer, typically +// produced by CopyToSample. +void ReadFromBuffer(const char* src, Stack* stack); + +// ThreadStack is meant to be used as a thread-local singleton, assigning to +// each thread a Stack object holding its pseudo-stack of profile labels, +// plus a mutex allowing to synchronize accesses to this pseudo-stack between +// this thread and a possible profiler thread sampling it. +class ThreadStack { + public: + ThreadStack(); + ~ThreadStack(); + + const Stack& stack() const { return stack_; } + + // Returns the mutex to lock around any access to this stack. 
Each stack is + // accessed by potentially two threads: the thread that it belongs to + // (which calls Push and Pop) and the profiler thread during profiling + // (which calls CopyToSample). + std::mutex& Mutex() const { return mutex_; } + + // Pushes a new label on the top of this Stack. + template + void Push(Args... args) { + // This mutex locking is needed to guard against race conditions as both + // the current thread and the profiler thread may be concurrently accessing + // this stack. In addition to that, this mutex locking also serves the other + // purpose of acting as a barrier (of compiler code reordering, of runtime + // CPU instruction reordering, and of memory access reordering), which + // gives a measure of correctness to this profiler. The downside is some + // latency. As this lock will be uncontended most of the times, the cost + // should be roughly that of an sequentially-consistent atomic access, + // comparable to an access to the level of CPU data cache that is shared + // among all cores, typically 60 cycles on current ARM CPUs, plus side + // effects from barrier instructions. + std::lock_guard lock(mutex_); + // Avoid overrunning the stack, even in 'release' builds. This profiling + // instrumentation code should not ship in release builds anyway, the + // overhead of this check is negligible, and overrunning a stack array would + // be bad. + if (stack_.size >= Stack::kMaxSize) { + abort(); + } + stack_.labels[stack_.size++].Set(args...); + } + + // Pops the top-most label from this Stack. + void Pop() { + // See the comment in Push about this lock. While it would be tempting to + // try to remove this lock and just atomically decrement size_ with a + // store-release, that would not necessarily be a substitute for all of the + // purposes that this lock serves, or if it was done carefully to serve all + // of the same purposes, then that wouldn't be faster than this (mostly + // uncontended) lock. + std::lock_guard lock(mutex_); + stack_.size--; + } + + private: + mutable std::mutex mutex_; + Stack stack_; +}; + +} // namespace detail + +// RAII user-facing way to construct Labels associated with their life scope +// and get them pushed to / popped from the current thread stack. +class ScopeLabel { + public: + template + ScopeLabel(Args... args) : thread_stack_(detail::ThreadLocalThreadStack()) { + thread_stack_->Push(args...); + } + + ~ScopeLabel() { thread_stack_->Pop(); } + + private: + detail::ThreadStack* thread_stack_; +}; + +#else // no RUY_PROFILER + +class ScopeLabel { + public: + template + explicit ScopeLabel(Args...) {} + + // This destructor is needed to consistently silence clang's -Wunused-variable + // which seems to trigger semi-randomly. + ~ScopeLabel() {} +}; + +#endif + +} // namespace profiler +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_ diff --git a/ruy/profiler/profiler.cc b/ruy/profiler/profiler.cc new file mode 100644 index 0000000..ae3a2e2 --- /dev/null +++ b/ruy/profiler/profiler.cc @@ -0,0 +1,109 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/profiler/profiler.h" + +#ifdef RUY_PROFILER +#include +#include // NOLINT +#include +#include +#include // NOLINT +#include +#endif + +#include "ruy/profiler/instrumentation.h" +#include "ruy/profiler/treeview.h" + +namespace ruy { +namespace profiler { + +#ifdef RUY_PROFILER + +ScopeProfile::ScopeProfile() { Start(); } +ScopeProfile::ScopeProfile(bool enable) { + if (enable) { + Start(); + } +} +ScopeProfile::~ScopeProfile() { + if (!thread_) { + return; + } + finishing_.store(true); + thread_->join(); + Finish(); +} + +void ScopeProfile::Start() { + { + std::lock_guard lock(*detail::GlobalsMutex()); + if (detail::GlobalIsProfilerRunning()) { + fprintf(stderr, "FATAL: profiler already running!\n"); + abort(); + } + detail::GlobalIsProfilerRunning() = true; + } + finishing_ = false; + thread_.reset(new std::thread(&ScopeProfile::ThreadFunc, this)); +} + +void ScopeProfile::ThreadFunc() { + while (!finishing_.load()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + std::lock_guard lock(*detail::GlobalsMutex()); + auto* thread_stacks = detail::GlobalAllThreadStacks(); + for (detail::ThreadStack* thread_stack : *thread_stacks) { + Sample(*thread_stack); + } + } +} + +void ScopeProfile::Sample(const detail::ThreadStack& thread_stack) { + std::lock_guard lock(thread_stack.Mutex()); + // Drop empty stacks. + // This ensures that profiles aren't polluted by uninteresting threads. + if (thread_stack.stack().size == 0) { + return; + } + int sample_size = detail::GetBufferSize(thread_stack.stack()); + int old_buf_size = samples_buf_.size(); + samples_buf_.resize(old_buf_size + sample_size); + detail::CopyToBuffer(thread_stack.stack(), + samples_buf_.data() + old_buf_size); +} + +void ScopeProfile::Finish() { + { + std::lock_guard lock(*detail::GlobalsMutex()); + if (!detail::GlobalIsProfilerRunning()) { + fprintf(stderr, "FATAL: profiler is not running!\n"); + abort(); + } + detail::GlobalIsProfilerRunning() = false; + } + if (user_treeview_) { + user_treeview_->Populate(samples_buf_); + } else { + TreeView treeview; + treeview.Populate(samples_buf_); + Print(treeview); + } +} + +#endif // RUY_PROFILER + +} // namespace profiler +} // namespace ruy diff --git a/ruy/profiler/profiler.h b/ruy/profiler/profiler.h new file mode 100644 index 0000000..b68ca90 --- /dev/null +++ b/ruy/profiler/profiler.h @@ -0,0 +1,106 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_ + +#include + +#ifdef RUY_PROFILER +#include +#include +#include +#include +#endif + +#include "ruy/profiler/instrumentation.h" +#include "ruy/profiler/treeview.h" + +namespace ruy { +namespace profiler { + +#ifdef RUY_PROFILER + +// RAII user-facing way to create a profiler and let it profile a code scope, +// and print out an ASCII/MarkDown treeview upon leaving the scope. +class ScopeProfile { + public: + // Default constructor, unconditionally profiling. + ScopeProfile(); + + // Constructor allowing to choose at runtime whether to profile. + explicit ScopeProfile(bool enable); + + // Destructor. It's where the profile is reported. + ~ScopeProfile(); + + // See treeview_ member. + void SetUserTreeView(TreeView* treeview) { user_treeview_ = treeview; } + + private: + void Start(); + + // Thread entry point function for the profiler thread. This thread is + // created on construction. + void ThreadFunc(); + + // Record a stack as a sample. + void Sample(const detail::ThreadStack& stack); + + // Finalize the profile. Called on destruction. + // If user_treeview_ is non-null, it will receive the treeview. + // Otherwise the treeview will just be printed. + void Finish(); + + // Buffer where samples are recorded during profiling. + std::vector samples_buf_; + + // Used to synchronize thread termination. + std::atomic finishing_; + + // Underlying profiler thread, which will perform the sampling. + // This profiler approach relies on a thread rather than on signals. + std::unique_ptr thread_; + + // TreeView to populate upon destruction. If left null (the default), + // a temporary treeview will be used and dumped on stdout. The user + // may override that by passing their own TreeView object for other + // output options or to directly inspect the TreeView. + TreeView* user_treeview_ = nullptr; +}; + +#else // no RUY_PROFILER + +struct ScopeProfile { + ScopeProfile() { +#ifdef GEMMLOWP_PROFILING + fprintf( + stderr, + "\n\n\n**********\n\nWARNING:\n\nLooks like you defined " + "GEMMLOWP_PROFILING, but this code has been ported to the new ruy " + "profiler replacing the old gemmlowp profiler. You should now be " + "defining RUY_PROFILER and not GEMMLOWP_PROFILING. When building using " + "Bazel, just pass --define=ruy_profiler=true.\n\n**********\n\n\n"); +#endif + } + explicit ScopeProfile(bool) {} +}; + +#endif + +} // namespace profiler +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_ diff --git a/ruy/profiler/test.cc b/ruy/profiler/test.cc new file mode 100644 index 0000000..e94840b --- /dev/null +++ b/ruy/profiler/test.cc @@ -0,0 +1,167 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "testing/base/public/gunit.h" +#include "ruy/profiler/profiler.h" +#include "ruy/profiler/test_instrumented_library.h" +#include "ruy/profiler/treeview.h" + +namespace ruy { +namespace profiler { +namespace { + +void DoSomeMergeSort(int size) { + std::vector data(size); + + std::default_random_engine engine; + for (auto& val : data) { + val = engine(); + } + + MergeSort(size, data.data()); +} + +// The purpose of this basic test is to cover the basic path that will be taken +// by a majority of users, not inspecting treeviews but just implicitly printing +// them on stdout, and to have this test enabled even when RUY_PROFILER is not +// defined, so that we have coverage for the non-RUY_PROFILER case. +TEST(ProfilerTest, MergeSortSingleThreadBasicTestEvenWithoutProfiler) { + { + ScopeProfile profile; + DoSomeMergeSort(1 << 20); + } +} + +#ifdef RUY_PROFILER + +TEST(ProfilerTest, MergeSortSingleThread) { + TreeView treeview; + { + ScopeProfile profile; + profile.SetUserTreeView(&treeview); + DoSomeMergeSort(1 << 20); + } + Print(treeview); + EXPECT_EQ(treeview.thread_roots().size(), 1); + const auto& thread_root = *treeview.thread_roots().begin()->second; + EXPECT_EQ(DepthOfTreeBelow(thread_root), 22); + EXPECT_GE( + WeightBelowNodeMatchingUnformatted(thread_root, "Merging sorted halves"), + 0.1 * thread_root.weight); + EXPECT_GE(WeightBelowNodeMatchingFormatted( + thread_root, "MergeSortRecurse (level=20, size=1)"), + 0.01 * thread_root.weight); + + TreeView treeview_collapsed; + CollapseNodesMatchingUnformatted(treeview, 5, "MergeSort (size=%d)", + &treeview_collapsed); + Print(treeview_collapsed); + const auto& collapsed_thread_root = + *treeview_collapsed.thread_roots().begin()->second; + EXPECT_EQ(DepthOfTreeBelow(collapsed_thread_root), 6); + EXPECT_EQ( + WeightBelowNodeMatchingUnformatted(thread_root, "MergeSort (size=%d)"), + WeightBelowNodeMatchingUnformatted(collapsed_thread_root, + "MergeSort (size=%d)")); +} + +TEST(ProfilerTest, MemcpyFourThreads) { + TreeView treeview; + { + ScopeProfile profile; + profile.SetUserTreeView(&treeview); + std::vector> threads; + for (int i = 0; i < 4; i++) { + threads.emplace_back(new std::thread([i]() { + ScopeLabel thread_label("worker thread #%d", i); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + ScopeLabel some_more_work_label("some more work"); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + })); + } + for (int i = 0; i < 4; i++) { + threads[i]->join(); + } + } + Print(treeview); + // Since we cleared GlobalAllThreadStacks and the current thread hasn't + // created any ScopeLabel, only the 4 worker threads should be recorded. + EXPECT_EQ(treeview.thread_roots().size(), 4); + for (const auto& thread_root : treeview.thread_roots()) { + const TreeView::Node& root_node = *thread_root.second; + // The root node may have 1 or 2 children depending on whether there is + // an "[other]" child. + EXPECT_GE(root_node.children.size(), 1); + EXPECT_LE(root_node.children.size(), 2); + const TreeView::Node& child_node = *root_node.children[0]; + EXPECT_EQ(child_node.label.format(), "worker thread #%d"); + // There must be 2 children, since roughly half the time will be in + // "some more work" leaving the other half in "[other]". 
+ EXPECT_EQ(child_node.children.size(), 2); + const TreeView::Node& child_child_node = *child_node.children[0]; + // Since we sample every millisecond and the threads run for >= 2000 + // milliseconds, the "thread func" label should get roughly 2000 samples. + // Not very rigorous, as we're depending on the profiler thread getting + // scheduled, so to avoid this test being flaky, we use a much more + // conservative value of 500, one quarter of that normal value 2000. + EXPECT_GE(child_node.weight, 500); + // Likewise, allow up to four times more than the normal value 2000. + EXPECT_LE(child_node.weight, 8000); + // Roughly half of time should be spent under the "some more work" label. + float some_more_work_percentage = + 100.f * child_child_node.weight / child_node.weight; + EXPECT_GE(some_more_work_percentage, 40.0f); + EXPECT_LE(some_more_work_percentage, 60.0f); + } +} + +TEST(ProfilerTest, OneThreadAfterAnother) { + TreeView treeview; + { + ScopeProfile profile; + profile.SetUserTreeView(&treeview); + { + std::thread thread([]() { + ScopeLabel thread_label("thread 0"); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + }); + thread.join(); + } + { + std::thread thread([]() { + ScopeLabel thread_label("thread 1"); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + }); + thread.join(); + } + } + Print(treeview); + EXPECT_EQ(treeview.thread_roots().size(), 2); +} + +#endif // RUY_PROFILER + +} // namespace +} // namespace profiler +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/profiler/test_instrumented_library.cc b/ruy/profiler/test_instrumented_library.cc new file mode 100644 index 0000000..b017ea9 --- /dev/null +++ b/ruy/profiler/test_instrumented_library.cc @@ -0,0 +1,59 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "ruy/profiler/instrumentation.h" + +namespace { + +void MergeSortRecurse(int level, int size, int* data, int* workspace) { + ruy::profiler::ScopeLabel function_label( + "MergeSortRecurse (level=%d, size=%d)", level, size); + if (size <= 1) { + return; + } + int half_size = size / 2; + MergeSortRecurse(level + 1, half_size, data, workspace); + MergeSortRecurse(level + 1, size - half_size, data + half_size, + workspace + half_size); + + ruy::profiler::ScopeLabel merging_sorted_halves_label( + "Merging sorted halves"); + int dst_index = 0; + int left_index = 0; + int right_index = half_size; + while (dst_index < size) { + int val; + if (left_index < half_size && + ((right_index >= size) || data[left_index] < data[right_index])) { + val = data[left_index++]; + } else { + val = data[right_index++]; + } + workspace[dst_index++] = val; + } + for (int i = 0; i < size; i++) { + data[i] = workspace[i]; + } +} + +} // namespace + +void MergeSort(int size, int* data) { + ruy::profiler::ScopeLabel function_label("MergeSort (size=%d)", size); + std::vector workspace(size); + MergeSortRecurse(0, size, data, workspace.data()); +} diff --git a/ruy/profiler/test_instrumented_library.h b/ruy/profiler/test_instrumented_library.h new file mode 100644 index 0000000..53d204e --- /dev/null +++ b/ruy/profiler/test_instrumented_library.h @@ -0,0 +1,23 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ + +#include "ruy/profiler/instrumentation.h" + +void MergeSort(int size, int* data); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ diff --git a/ruy/profiler/treeview.cc b/ruy/profiler/treeview.cc new file mode 100644 index 0000000..48d922a --- /dev/null +++ b/ruy/profiler/treeview.cc @@ -0,0 +1,248 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifdef RUY_PROFILER + +#include "ruy/profiler/treeview.h" + +#include +#include +#include +#include +#include + +namespace ruy { +namespace profiler { + +namespace { + +void SortNode(TreeView::Node* node) { + using NodePtr = std::unique_ptr; + std::sort(node->children.begin(), node->children.end(), + [](const NodePtr& n1, const NodePtr& n2) { + return n1->weight > n2->weight; + }); + for (const auto& child : node->children) { + SortNode(child.get()); + } +} + +// Records a stack i.e. a sample in a treeview, by incrementing the weights +// of matching existing nodes and/or by creating new nodes as needed, +// recursively, below the given node. +void AddStack(const detail::Stack& stack, TreeView::Node* node, int level) { + node->weight++; + if (stack.size == level) { + return; + } + TreeView::Node* child_to_add_to = nullptr; + for (const auto& child : node->children) { + if (child->label == stack.labels[level]) { + child_to_add_to = child.get(); + break; + } + } + if (!child_to_add_to) { + child_to_add_to = node->children.emplace_back(new TreeView::Node).get(); + child_to_add_to->label = stack.labels[level]; + } + AddStack(stack, child_to_add_to, level + 1); +} + +// Recursively populates the treeview below the given node with 'other' +// entries documenting for each node the difference between its weight and the +// sum of its children's weight. +void AddOther(TreeView::Node* node) { + int top_level_children_weight = 0; + for (const auto& child : node->children) { + AddOther(child.get()); + top_level_children_weight += child->weight; + } + if (top_level_children_weight != 0 && + top_level_children_weight != node->weight) { + const auto& new_child = node->children.emplace_back(new TreeView::Node); + new_child->label = Label("[other]"); + new_child->weight = node->weight - top_level_children_weight; + } +} + +} // namespace + +void TreeView::Populate(const std::vector& samples_buf_) { + thread_roots_.clear(); + // Populate the treeview with regular nodes coming from samples. + const char* buf_ptr = samples_buf_.data(); + const char* const buf_ptr_end = buf_ptr + samples_buf_.size(); + while (buf_ptr < buf_ptr_end) { + detail::Stack stack; + detail::ReadFromBuffer(buf_ptr, &stack); + // Empty stacks should have been dropped during sampling. + assert(stack.size > 0); + buf_ptr += GetBufferSize(stack); + const int id = stack.id; + if (!thread_roots_[id]) { + thread_roots_[id].reset(new Node); + } + AddStack(stack, thread_roots_[id].get(), 0); + } + // Populate the treeview with additional 'other' nodes, sort, and set + // root labels. + for (const auto& thread_root : thread_roots_) { + std::uint32_t id = thread_root.first; + Node* root = thread_root.second.get(); + AddOther(root); + SortNode(root); + root->label.Set("Thread %x (%d samples)", id, root->weight); + } +} + +// Recursively prints the treeview below the given node. The 'root' node +// argument is only needed to compute weights ratios, with the root ratio +// as denominator. 
+void PrintTreeBelow(const TreeView::Node& node, const TreeView::Node& root, + int level) { + if (&node == &root) { + printf("%s\n\n", node.label.Formatted().c_str()); + } else { + for (int i = 1; i < level; i++) { + printf(" "); + } + printf("* %.2f%% %s\n", 100.0f * node.weight / root.weight, + node.label.Formatted().c_str()); + } + for (const auto& child : node.children) { + PrintTreeBelow(*child, root, level + 1); + } +} + +void Print(const TreeView& treeview) { + printf("\n"); + printf("Profile (%d threads):\n\n", + static_cast(treeview.thread_roots().size())); + for (const auto& thread_root : treeview.thread_roots()) { + const TreeView::Node& root = *thread_root.second; + PrintTreeBelow(root, root, 0); + printf("\n"); + } +} + +int DepthOfTreeBelow(const TreeView::Node& node) { + if (node.children.empty()) { + return 0; + } else { + int max_child_depth = 0; + for (const auto& child : node.children) { + max_child_depth = std::max(max_child_depth, DepthOfTreeBelow(*child)); + } + return 1 + max_child_depth; + } +} + +int WeightBelowNodeMatchingFunction( + const TreeView::Node& node, + const std::function& match) { + int weight = 0; + if (match(node.label)) { + weight += node.weight; + } + for (const auto& child : node.children) { + weight += WeightBelowNodeMatchingFunction(*child, match); + } + return weight; +} + +int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node, + const std::string& format) { + return WeightBelowNodeMatchingFunction( + node, [&format](const Label& label) { return label.format() == format; }); +} + +int WeightBelowNodeMatchingFormatted(const TreeView::Node& node, + const std::string& formatted) { + return WeightBelowNodeMatchingFunction( + node, [&formatted](const Label& label) { + return label.Formatted() == formatted; + }); +} + +void CollapseNode(const TreeView::Node& node_in, int depth, + TreeView::Node* node_out) { + node_out->label = node_in.label; + node_out->weight = node_in.weight; + node_out->children.clear(); + if (depth > 0) { + for (const auto& child_in : node_in.children) { + auto* child_out = new TreeView::Node; + node_out->children.emplace_back(child_out); + CollapseNode(*child_in, depth - 1, child_out); + } + } +} + +void CollapseSubnodesMatchingFunction( + const TreeView::Node& node_in, int depth, + const std::function& match, TreeView::Node* node_out) { + if (match(node_in.label)) { + CollapseNode(node_in, depth, node_out); + } else { + node_out->label = node_in.label; + node_out->weight = node_in.weight; + node_out->children.clear(); + + for (const auto& child_in : node_in.children) { + auto* child_out = new TreeView::Node; + node_out->children.emplace_back(child_out); + CollapseSubnodesMatchingFunction(*child_in, depth, match, child_out); + } + } +} + +void CollapseNodesMatchingFunction( + const TreeView& treeview_in, int depth, + const std::function& match, TreeView* treeview_out) { + treeview_out->mutable_thread_roots()->clear(); + for (const auto& thread_root_in : treeview_in.thread_roots()) { + std::uint32_t id = thread_root_in.first; + const auto& root_in = *thread_root_in.second; + auto* root_out = new TreeView::Node; + treeview_out->mutable_thread_roots()->emplace(id, root_out); + CollapseSubnodesMatchingFunction(root_in, depth, match, root_out); + } +} + +void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth, + const std::string& format, + TreeView* treeview_out) { + CollapseNodesMatchingFunction( + treeview_in, depth, + [&format](const Label& label) { return label.format() == format; }, + 
treeview_out); +} + +void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth, + const std::string& formatted, + TreeView* treeview_out) { + CollapseNodesMatchingFunction( + treeview_in, depth, + [&formatted](const Label& label) { + return label.Formatted() == formatted; + }, + treeview_out); +} + +} // namespace profiler +} // namespace ruy + +#endif // RUY_PROFILER diff --git a/ruy/profiler/treeview.h b/ruy/profiler/treeview.h new file mode 100644 index 0000000..e34b4f9 --- /dev/null +++ b/ruy/profiler/treeview.h @@ -0,0 +1,130 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_ + +#ifdef RUY_PROFILER + +#include +#include +#include +#include + +#include "ruy/profiler/instrumentation.h" + +namespace ruy { +namespace profiler { + +// A tree view of a profile. +class TreeView { + public: + struct Node { + std::vector> children; + Label label; + int weight = 0; + }; + + void Populate(const std::vector& samples_buf_); + + // Intentionally an *ordered* map so that threads are enumerated + // in an order that's consistent and typically putting the 'main thread' + // first. + using ThreadRootsMap = std::map>; + + const ThreadRootsMap& thread_roots() const { return thread_roots_; } + ThreadRootsMap* mutable_thread_roots() { return &thread_roots_; } + + private: + ThreadRootsMap thread_roots_; +}; + +/* Below are API functions for manipulating and printing treeviews. */ + +// Prints the treeview to stdout. +void Print(const TreeView& treeview); + +// Prints the treeview below the given node on stdout. +void PrintTreeBelow(const TreeView::Node& node); + +// Returns the tree depth below the given node. +int DepthOfTreeBelow(const TreeView::Node& node); + +// Returns the sum of weights of nodes below the given node and filtered by +// the `match` predicate. +int WeightBelowNodeMatchingFunction( + const TreeView::Node& node, const std::function& match); + +// Returns the sum of weights of nodes below the given node and whose +// unformatted label (i.e. raw format string) matches the given `format` string. +// +// This allows to aggregate nodes whose labels differ only by parameter values. +int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node, + const std::string& format); + +// Returns the sum of weights of nodes below the given node and whose formatted +// label matches the `formatted` string. +// +// In the case of nodes with parametrized labels, this allows to count only +// nodes with specific parameter values. For that purpose, one may also instead +// use WeightBelowNodeMatchingFunction directly, with a `match` predicate +// comparing raw integer parameter values directly, instead of going through +// formatted strings. 
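As a usage sketch for the weight-query helpers in this header (the ones above and the one declared just below), assuming a treeview that was populated by a ScopeProfile around the instrumented MergeSort from test_instrumented_library.cc:

#include <cstdio>

#include "ruy/profiler/treeview.h"

void QueryTreeviewSketch(const ruy::profiler::TreeView& treeview) {
  for (const auto& thread_root : treeview.thread_roots()) {
    const ruy::profiler::TreeView::Node& root = *thread_root.second;
    // Aggregates every "MergeSortRecurse (level=..., size=...)" node by
    // matching the raw, unformatted label string.
    const int recurse_weight =
        ruy::profiler::WeightBelowNodeMatchingUnformatted(
            root, "MergeSortRecurse (level=%d, size=%d)");
    // Counts only nodes whose fully formatted label matches exactly.
    const int merge_weight = ruy::profiler::WeightBelowNodeMatchingFormatted(
        root, "Merging sorted halves");
    std::printf("thread root: %d samples, recurse: %d, merging: %d\n",
                root.weight, recurse_weight, merge_weight);
  }
}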
+int WeightBelowNodeMatchingFormatted(const TreeView::Node& node, + const std::string& formatted); + +// Produces a `node_out` that is a copy of `node_in` but with tree depth below +// it clamped at `depth`, with further subtrees aggregated into single leaf +// nodes. +void CollapseNode(const TreeView::Node& node_in, int depth, + TreeView::Node* node_out); + +// Calls CollapseNode with the given `depth` on every subnode filtered by the +// `match` predicate. Note that this does NOT limit the tree depth below +// `node_out` to `depth`, since each collapsed node below `node_out` may be +// arbitrarily far below it and `depth` is only used as the collapsing depth +// at that point. +void CollapseSubnodesMatchingFunction( + const TreeView::Node& node_in, int depth, + const std::function& match, TreeView::Node* node_out); + +// Calls CollapseNode with the given `depth` on every node filtered by the +// `match` predicate. Note that this does NOT limit the tree depth below +// `node_out` to `depth`, since each collapsed node below `node_out` may be +// arbitrarily far below it and `depth` is only used as the collapsing depth +// at that point. +void CollapseNodesMatchingFunction( + const TreeView& treeview_in, int depth, + const std::function& match, TreeView* treeview_out); + +// Special case of CollapseNodesMatchingFunction matching unformatted labels, +// i.e. raw format strings. +// See the comment on WeightBelowNodeMatchingUnformatted. +void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth, + const std::string& format, + TreeView* treeview_out); + +// Special case of CollapseNodesMatchingFunction matching formatted labels. +// See the comment on WeightBelowNodeMatchingFormatted. +void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth, + const std::string& formatted, + TreeView* treeview_out); + +} // namespace profiler +} // namespace ruy + +#endif // RUY_PROFILER + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_ diff --git a/ruy/ruy.h b/ruy/ruy.h new file mode 100644 index 0000000..9cafe14 --- /dev/null +++ b/ruy/ruy.h @@ -0,0 +1,42 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the only Ruy header that users should #include. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_ + +#include "ruy/context.h" +#include "ruy/dispatch.h" +#include "ruy/matrix.h" +#include "ruy/path.h" +#include "ruy/spec.h" + +namespace ruy { + +// Performs a multiplication of matrices. This is Ruy's only API entry point. +// Should be self-explanatory given the above documentation for each of Matrix, +// Spec and Context. 
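The Mul entry point defined just below can be used as in this minimal float sketch, modeled on ruy/example.cc; MakeSimpleLayout and kAllPaths are assumed from ruy/matrix.h and ruy/path.h:

#include "ruy/ruy.h"

void ExampleMulFloatSketch(ruy::Context* context) {
  // 2x2 row-major LHS times 2x2 column-major RHS into a column-major dst.
  const float lhs_data[] = {1, 2, 3, 4};
  const float rhs_data[] = {1, 2, 3, 4};
  float dst_data[4];

  ruy::Matrix<float> lhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
  lhs.data = lhs_data;
  ruy::Matrix<float> rhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout);
  rhs.data = rhs_data;
  ruy::Matrix<float> dst;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout);
  dst.data = dst_data;

  // Default BasicSpec<float, float>: no bias, no clamping.
  ruy::BasicSpec<float, float> spec;
  ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, context, &dst);
}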
+template +void Mul(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, Context* context, Matrix* dst) { + DispatchMul( + lhs, rhs, spec, context, dst); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_ diff --git a/ruy/ruy_advanced.h b/ruy/ruy_advanced.h new file mode 100644 index 0000000..124ddd2 --- /dev/null +++ b/ruy/ruy_advanced.h @@ -0,0 +1,69 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_ + +#include +#include + +#include "ruy/context.h" +#include "ruy/matrix.h" +#include "ruy/path.h" +#include "ruy/prepack.h" +#include "ruy/side_pair.h" + +namespace ruy { + +// Low-level, explicit pre-packing API. +// +// The cost of packing an input matrix (either the LHS or RHS) is amortized +// across the non-depth dimension of the opposite input matrix. Thus, when the +// LHS has very few rows or the RHS has very few columns, the cost of packing +// the opposite input matrix can become significant. See pack.h for further +// information on packing. +// +// This file provides an API allowing a user to explicitly pack a matrix and +// reuse the pre-packed matrix, avoiding that cost. +// +// See example_prepack.cc for example usage. + +template +void PrePackForMul(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, Context* context, Matrix* dst, + PrepackedMatrix* prepacked_lhs, + PrepackedMatrix* prepacked_rhs, + std::function alloc_fn) { + SidePair prepacked(prepacked_lhs, prepacked_rhs); + PrePackForMulInternal(lhs, rhs, spec, context, dst, prepacked, + alloc_fn); +} + +template +void MulWithPrepacked(const Matrix& lhs, + const Matrix& rhs, const Spec& spec, + Context* context, Matrix* dst, + PrepackedMatrix* prepacked_lhs, + PrepackedMatrix* prepacked_rhs) { + SidePair prepacked(prepacked_lhs, prepacked_rhs); + MulWithPrepackedInternal(lhs, rhs, spec, context, dst, + prepacked); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_ diff --git a/ruy/ruy_test.bzl b/ruy/ruy_test.bzl new file mode 100644 index 0000000..ef7e8b1 --- /dev/null +++ b/ruy/ruy_test.bzl @@ -0,0 +1,34 @@ +# Provides the ruy_test macro for type-parametrized tests. 
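Returning to the pre-packing API of ruy_advanced.h above, a rough sketch of its intended use, modeled on ruy/example_advanced.cc; the heap-owning alloc_fn below is only an illustration of the allocation callback contract:

#include <cstddef>
#include <memory>
#include <vector>

#include "ruy/ruy.h"
#include "ruy/ruy_advanced.h"

// Pre-packs the RHS once, then reuses it across multiplications with the same
// RHS. The allocation callback owns the prepacked buffers.
void ExamplePrepackedMulSketch(const ruy::Matrix<float>& lhs,
                               const ruy::Matrix<float>& rhs,
                               ruy::Context* context, ruy::Matrix<float>* dst) {
  ruy::BasicSpec<float, float> spec;
  std::vector<std::unique_ptr<char[]>> buffers;  // keeps allocations alive
  auto alloc_fn = [&buffers](std::size_t num_bytes) -> void* {
    buffers.emplace_back(new char[num_bytes]);
    return buffers.back().get();
  };
  ruy::PrepackedMatrix prepacked_rhs;
  // Packs only the RHS (prepacked_lhs == nullptr) and performs one Mul.
  ruy::PrePackForMul<ruy::kAllPaths>(lhs, rhs, spec, context, dst,
                                     /*prepacked_lhs=*/nullptr, &prepacked_rhs,
                                     alloc_fn);
  // Subsequent multiplications reuse the packed RHS.
  ruy::MulWithPrepacked<ruy::kAllPaths>(lhs, rhs, spec, context, dst,
                                        /*prepacked_lhs=*/nullptr,
                                        &prepacked_rhs);
}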
+"""ruy_test is a macro for building a test with multiple paths corresponding to tuples of types for LHS, RHS, accumulator and destination.""" + +def ruy_test(name, srcs, lhs_rhs_accum_dst, copts, tags = [], deps = None): + for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst: + native.cc_test( + name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst), + srcs = srcs, + copts = copts + [ + "-DRUY_TEST_LHSSCALAR=%s" % lhs, + "-DRUY_TEST_RHSSCALAR=%s" % rhs, + "-DRUY_TEST_ACCUMSCALAR=%s" % accum, + "-DRUY_TEST_DSTSCALAR=%s" % dst, + ], + deps = deps, + tags = tags, + ) + +def ruy_benchmark(name, srcs, lhs_rhs_accum_dst, copts, deps = None): + tags = ["req_dep=//third_party/gemmlowp:profiler"] + for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst: + native.cc_binary( + name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst), + testonly = True, + srcs = srcs, + copts = copts + [ + "-DRUY_TEST_LHSSCALAR=%s" % lhs, + "-DRUY_TEST_RHSSCALAR=%s" % rhs, + "-DRUY_TEST_ACCUMSCALAR=%s" % accum, + "-DRUY_TEST_DSTSCALAR=%s" % dst, + ], + deps = deps, + tags = tags, + ) diff --git a/ruy/ruy_test_ext.bzl b/ruy/ruy_test_ext.bzl new file mode 100644 index 0000000..263121f --- /dev/null +++ b/ruy/ruy_test_ext.bzl @@ -0,0 +1,19 @@ +"""Allows to specialize the ruy BUILD to availability of external libraries""" + +def ruy_test_ext_defines(): + return select({ + "//tools/cc_target_os:windows": [], + "//tools/cc_target_os:wasm": [], + "//tools/cc_target_os:chromiumos": ["RUY_TESTING_ON_CHROMIUMOS"], + "//conditions:default": ["RUY_TEST_EXTERNAL_PATHS"], + }) + +def ruy_test_ext_deps(): + return select({ + "//tools/cc_target_os:windows": [], + "//conditions:default": [ + "//third_party/eigen3", + "//third_party/gemmlowp", + "//third_party/lapack:blas", + ], + }) diff --git a/ruy/ruy_test_ext.bzl.opensource b/ruy/ruy_test_ext.bzl.opensource new file mode 100644 index 0000000..5701fff --- /dev/null +++ b/ruy/ruy_test_ext.bzl.opensource @@ -0,0 +1,7 @@ +"""Allows to specialize the ruy BUILD to availability of external libraries""" + +def ruy_test_ext_defines(): + return [] + +def ruy_test_ext_deps(): + return [] diff --git a/ruy/side_pair.h b/ruy/side_pair.h new file mode 100644 index 0000000..e62968b --- /dev/null +++ b/ruy/side_pair.h @@ -0,0 +1,64 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_ + +#include "ruy/check_macros.h" + +namespace ruy { + +// Enumeration of the sides, i.e. the operands 'slots', in a matrix +// multiplication. The numerical values of these enumeration constants matter +// because these will be used as indices into the array underlying a SidePair. +enum class Side { + // Left-hand side + kLhs = 0, + // Right-hand side + kRhs = 1 +}; + +// SidePair is a pair container where the two elements are indexed by a Side +// enum. 
+template +class SidePair final { + public: + SidePair() {} + SidePair(const T& a, const T& b) : elem_{a, b} {} + const T& operator[](Side side) const { + const int index = static_cast(side); + // Technically this check is vacuous, since other values would be + // out-of-range for enum Side. + RUY_DCHECK(index == 0 || index == 1); + return elem_[index]; + } + + T& operator[](Side side) { + const int index = static_cast(side); + // Technically this check is vacuous, since other values would be + // out-of-range for enum Side. + RUY_DCHECK(index == 0 || index == 1); + return elem_[index]; + } + + private: + static_assert(static_cast(Side::kLhs) == 0, ""); + static_assert(static_cast(Side::kRhs) == 1, ""); + T elem_[2]; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_ diff --git a/ruy/size_util.h b/ruy/size_util.h new file mode 100644 index 0000000..2a4bdb9 --- /dev/null +++ b/ruy/size_util.h @@ -0,0 +1,93 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_ + +#include + +#include "ruy/check_macros.h" + +#ifdef _WIN32 +#include +#endif + +namespace ruy { + +template +inline Integer floor_log2(Integer n) { + static_assert(std::is_integral::value, ""); + static_assert(std::is_signed::value, ""); + static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, ""); + + RUY_DCHECK_GE(n, 1); +#ifdef _WIN32 + unsigned long result; // NOLINT[runtime/int] + if (sizeof(Integer) == 4) { + _BitScanReverse(&result, n); + } else { + _BitScanReverse64(&result, n); + } + return result; +#else + if (sizeof(Integer) == 4) { + return 31 - __builtin_clz(n); + } else { + return 63 - __builtin_clzll(n); + } +#endif +} + +template +Integer ceil_log2(Integer n) { + RUY_DCHECK_GE(n, 1); + return n == 1 ? 0 : floor_log2(n - 1) + 1; +} + +template +bool is_pot(Integer value) { + return (value > 0) && ((value & (value - 1)) == 0); +} + +template +Integer pot_log2(Integer n) { + RUY_DCHECK(is_pot(n)); + return floor_log2(n); +} + +template +Integer round_down_pot(Integer value) { + return static_cast(1) << floor_log2(value); +} + +template +Integer round_up_pot(Integer value) { + return static_cast(1) << ceil_log2(value); +} + +template +Integer round_down_pot(Integer value, Modulo modulo) { + RUY_DCHECK_EQ(modulo & (modulo - 1), 0); + return value & ~(modulo - 1); +} + +template +Integer round_up_pot(Integer value, Modulo modulo) { + return round_down_pot(value + modulo - 1, modulo); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_ diff --git a/ruy/size_util_test.cc b/ruy/size_util_test.cc new file mode 100644 index 0000000..54f0c11 --- /dev/null +++ b/ruy/size_util_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. 
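For reference, concrete values produced by the size utilities in ruy/size_util.h above (the same properties that the test below checks exhaustively):

#include "ruy/size_util.h"

int SizeUtilSketch() {
  const int floor12 = ruy::floor_log2(12);     // 3, since 8 <= 12 < 16
  const int ceil12 = ruy::ceil_log2(12);       // 4
  const bool pot12 = ruy::is_pot(12);          // false
  const int down12 = ruy::round_down_pot(12);  // 8
  const int up12 = ruy::round_up_pot(12);      // 16
  // The two-argument overloads round to a multiple of a power-of-two modulo.
  const int down13mod8 = ruy::round_down_pot(13, 8);  // 8
  const int up13mod8 = ruy::round_up_pot(13, 8);      // 16
  return floor12 + ceil12 + pot12 + down12 + up12 + down13mod8 + up13mod8;
}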
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/size_util.h" + +#include +#include +#include + +#include "testing/base/public/gunit.h" + +namespace ruy { +namespace { + +template +void SizeUtilTestValue(Integer value) { + if (value == 0) { + return; + } + + EXPECT_LE(0, floor_log2(value)); + EXPECT_LE(floor_log2(value), ceil_log2(value)); + EXPECT_LE(ceil_log2(value), 8 * sizeof(Integer)); + + if (is_pot(value)) { + EXPECT_EQ(floor_log2(value), ceil_log2(value)); + EXPECT_EQ(floor_log2(value), pot_log2(value)); + } else { + EXPECT_EQ(floor_log2(value) + 1, ceil_log2(value)); + } + EXPECT_EQ(value >> floor_log2(value), 1); + EXPECT_EQ(round_down_pot(value), static_cast(1) + << floor_log2(value)); + EXPECT_LE(round_down_pot(value), value); + EXPECT_GE(round_down_pot(value), value >> 1); + EXPECT_TRUE(is_pot(round_down_pot(value))); + + if (ceil_log2(value) < 8 * sizeof(Integer) - 1) { + EXPECT_EQ(value >> ceil_log2(value), is_pot(value) ? 1 : 0); + EXPECT_EQ(round_up_pot(value), static_cast(1) << ceil_log2(value)); + EXPECT_GE(round_up_pot(value), value); + EXPECT_LE(round_up_pot(value) >> 1, value); + EXPECT_TRUE(is_pot(round_up_pot(value))); + } + + for (std::uint8_t modulo : {1, 2, 8, 32, 128}) { + EXPECT_GE(value, round_down_pot(value, modulo)); + EXPECT_EQ(round_down_pot(value, modulo) % modulo, 0); + + if (value <= std::numeric_limits::max() - modulo) { + EXPECT_LE(value, round_up_pot(value, modulo)); + EXPECT_EQ(round_up_pot(value, modulo) % modulo, 0); + } + } +} + +template +void SizeUtilTest() { + for (int exponent = 0; exponent < 8 * sizeof(Integer) - 1; exponent++) { + const Integer pot = static_cast(1) << exponent; + SizeUtilTestValue(pot - 1); + SizeUtilTestValue(pot); + SizeUtilTestValue(pot + 1); + SizeUtilTestValue(pot + 12); + SizeUtilTestValue(pot + 123); + } + SizeUtilTestValue(std::numeric_limits::max() - 1); + SizeUtilTestValue(std::numeric_limits::max()); +} + +TEST(SizeUtilTest, Int) { SizeUtilTest(); } + +TEST(SizeUtilTest, Long) { SizeUtilTest(); } // NOLINT + +TEST(SizeUtilTest, LongLong) { SizeUtilTest(); } // NOLINT + +TEST(SizeUtilTest, Int32) { SizeUtilTest(); } + +TEST(SizeUtilTest, Int64) { SizeUtilTest(); } + +TEST(SizeUtilTest, Ptrdiff) { SizeUtilTest(); } + +} // namespace +} // namespace ruy + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/spec.h b/ruy/spec.h new file mode 100644 index 0000000..d96b6a9 --- /dev/null +++ b/ruy/spec.h @@ -0,0 +1,118 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_ + +#include +#include + +#include "ruy/cpu_cache_size.h" +#include "ruy/matrix.h" + +namespace ruy { + +// Our 'general' loop structure (the default) involves multi-threading and +// complicated loops aiming to optimize cache-friendliness. One may opt out of +// this and pick the 'simple' loop structure instead, which only performs well +// for small matrix sizes and only allows using one thread, in exchange for +// smaller code size. +enum class LoopStructure { kGeneral, kSimple, kAuto }; + +// In general we allow zero_point's to have any Scalar value. This is called +// 'asymmetric' quantization. We do take advantage of the optimization +// opportunities when zero_points happen at runtime to be 'symmetric' (e.g. the +// int8 value 0 or the uint8 value 128), but we still generate code to handle +// the general asymmetric case. By choosing kSymmetric here, one opts out of +// this and supports only the symmetric case, in exchange for smaller code size. +enum class ZeroPointSupport { kGeneral, kSymmetric }; + +// In general we allow all Layout's, even if we may use slow paths for some +// kinds of layouts. By choosing kRCC, one may opt out of this and +// only keep support for the simplest and most efficient combination of +// Layout's, in exchange for smaller code size. The case covered by +// kRCC is where the storage orders are exactly the following: +// - LHS is RowMajor +// - RHS is ColMajor +// - Destination is ColMajor +enum class LayoutSupport { kGeneral, kRCC }; + +// A Spec describes all about a matrix multiplication operation that isn't +// encoded in the LHS, RHS and destination matrices. Some of that information +// is encoded as compile-time constants and types (for instance, the choice +// of accumulator type, AccumScalar). Some of that information is encoded as +// runtime values (for instance, the optional bias vector). +template +struct BasicSpec { + // Accumulator type. The type of accumulators used to compute the dot-products + // before being ultimately casted to the destination type. + using AccumScalar = tAccumScalar; + // The destination scalar type. + using DstScalar = tDstScalar; + // The bias vector data, if not null. + const AccumScalar* bias = nullptr; + // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa) + // of the multiplier by which accumulators are multiplied before being casted + // to the destination type. + AccumScalar multiplier_fixedpoint = 0; + // Only for non-floating-point cases. The exponent part of the aforementioned + // multiplier. + int multiplier_exponent = 0; + // Per-channel variant of multiplier_fixedpoint. If not nullptr, this must + // point to a buffer of as many values as there are rows in the destination + // matrix. Each row of the destination matrix will use the corresponding + // buffer element instead of multiplier_fixedpoint. 
+ const AccumScalar* multiplier_fixedpoint_perchannel = nullptr; + // Per-channel variant of multiplier_exponent. If not nullptr, this must + // point to a buffer of as many values as there are rows in the destination + // matrix. Each row of the destination matrix will use the corresponding + // buffer element instead of multiplier_exponent. + // + // Either none or both of multiplier_exponent_perchannel and + // multiplier_fixedpoint_perchannel must be nullptr. + const int* multiplier_exponent_perchannel = nullptr; + // min clamp bound of destination values. + DstScalar clamp_min = std::is_floating_point::value + ? -std::numeric_limits::infinity() + : std::numeric_limits::lowest(); + // max clamp bound of destination values. + DstScalar clamp_max = std::is_floating_point::value + ? std::numeric_limits::infinity() + : std::numeric_limits::max(); + // See above enum LoopStructure + static constexpr LoopStructure kLoopStructure = LoopStructure::kAuto; + // See above enum LayoutSupport + static constexpr LayoutSupport kLayoutSupport = LayoutSupport::kGeneral; + // See above enum ZeroPointSupport + static constexpr ZeroPointSupport kZeroPointSupport = + ZeroPointSupport::kGeneral; + // Testing-only, not meant to be used by actual users: + // Used for testing of various kernel layouts. + using StandardCppKernelLhsLayout = FixedKernelLayout; + using StandardCppKernelRhsLayout = FixedKernelLayout; + // Returns (a reasonable estimate of) the local CPU cache size. + // See ruy::LocalDataCacheSize() which returns some coarse, sane default for + // each CPU architecture. + // This may be overridden, either to provide more accurate/runtime values, + // or to test with other values to let testcases have more coverage. + static int local_data_cache_size() { return LocalDataCacheSize(); } + // Same as local_data_cache_size but for the total data cache size accessible + // to each CPU core. See ruy::SharedDataCacheSize(). + static int shared_data_cache_size() { return SharedDataCacheSize(); } +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_ diff --git a/ruy/test.h b/ruy/test.h new file mode 100644 index 0000000..649a0d9 --- /dev/null +++ b/ruy/test.h @@ -0,0 +1,2125 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
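A quantized usage sketch for BasicSpec above, in the spirit of ruy/example.cc; the bias and multiplier values are arbitrary placeholders standing in for real quantization parameters:

#include <cstdint>

#include "ruy/ruy.h"

void ExampleMulInt8Sketch(ruy::Context* context) {
  const std::int8_t lhs_data[] = {1, 2, 3, 4};
  const std::int8_t rhs_data[] = {1, 2, 3, 4};
  const std::int32_t bias_data[] = {1, 1};
  std::int8_t dst_data[4];

  ruy::Matrix<std::int8_t> lhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
  lhs.data = lhs_data;
  ruy::Matrix<std::int8_t> rhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout);
  rhs.data = rhs_data;
  ruy::Matrix<std::int8_t> dst;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout);
  dst.data = dst_data;

  ruy::BasicSpec<std::int32_t, std::int8_t> spec;
  spec.bias = bias_data;
  // Accumulators are scaled by multiplier_fixedpoint * 2^multiplier_exponent
  // before being clamped and cast to the destination type.
  spec.multiplier_fixedpoint = 1 << 30;  // placeholder value
  spec.multiplier_exponent = -1;         // placeholder value
  ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, context, &dst);
}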
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "testing/base/public/gunit.h" // IWYU pragma: export +#include "ruy/matrix.h" // IWYU pragma: export +#include "ruy/platform.h" +#include "ruy/pmu.h" +#include "ruy/ruy.h" +#include "ruy/ruy_advanced.h" +#include "ruy/spec.h" // IWYU pragma: export +#include "ruy/time.h" + +#ifdef RUY_TEST_EXTERNAL_PATHS +#define EIGEN_USE_THREADS +#define EIGEN_USE_CUSTOM_THREAD_POOL +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "third_party/gemmlowp/public/gemmlowp.h" +#include "third_party/lapack/blas.h" +#endif + +#ifdef RUY_PROFILER +#include "ruy/profiler/profiler.h" +#endif + +namespace ruy { + +const float kClampRatio = 0.1f; + +enum class ExternalPath { kNone, kGemmlowp, kEigen, kEigenTensor, kOpenBlas }; + +inline std::vector* CoveredPaths() { + static std::vector covered_paths; + return &covered_paths; +} + +inline const char* PathName(Path path) { +#define RUY_PATHNAME_CASE(NAME) \ + case Path::NAME: \ + return #NAME; + switch (path) { + RUY_PATHNAME_CASE(kReference) + RUY_PATHNAME_CASE(kStandardCpp) +#if RUY_PLATFORM(NEON) + RUY_PATHNAME_CASE(kNeon) + RUY_PATHNAME_CASE(kNeonDotprod) +#elif RUY_PLATFORM(X86) + RUY_PATHNAME_CASE(kSse42) + RUY_PATHNAME_CASE(kAvx2) + RUY_PATHNAME_CASE(kAvx512) + RUY_PATHNAME_CASE(kAvxVnni) +#endif + default: + RUY_CHECK(false); + return nullptr; + } +#undef RUY_PATHNAME_CASE +} + +inline const char* TuningName(Tuning tuning) { +#define RUY_SUBPATHNAME_CASE(NAME) \ + case Tuning::NAME: \ + return #NAME; + switch (tuning) { + RUY_SUBPATHNAME_CASE(kInOrder) + RUY_SUBPATHNAME_CASE(kOutOfOrder) + default: + RUY_CHECK(false); + return nullptr; + } +#undef RUY_SUBPATHNAME_CASE +} + +inline const char* PathName(ExternalPath path) { +#define RUY_PATHNAME_CASE(NAME) \ + case ExternalPath::NAME: \ + return #NAME; + switch (path) { + RUY_PATHNAME_CASE(kGemmlowp) + RUY_PATHNAME_CASE(kEigen) + RUY_PATHNAME_CASE(kEigenTensor) + RUY_PATHNAME_CASE(kOpenBlas) + default: + RUY_CHECK(false); + return nullptr; + } +#undef RUY_PATHNAME_CASE +} + +inline std::ostream& operator<<(std::ostream& stream, Path path) { + return stream << PathName(path); +} + +inline std::ostream& operator<<(std::ostream& stream, + ExternalPath external_path) { + return stream << PathName(external_path); +} + +template +std::string Join(const ContainerType& container) { + if (container.empty()) { + return ""; + } + std::ostringstream stream; + auto it = container.begin(); + stream << *it++; + for (; it != container.end(); ++it) { + stream << ", "; + stream << *it; + } + return stream.str(); +} + +struct LogCoveredPathsOnDestruction final { + ~LogCoveredPathsOnDestruction() { + std::cerr << "Covered paths: " << Join(*CoveredPaths()) << std::endl; + + // When testing on ARM64 ChromiumOS emulator, make sure that we covered + // the dotprod path. We're getting such coverage at the moment thanks to + // using a sufficiently recent emulator, and we don't want to regress that. 
+#if RUY_PLATFORM(ARM_64) && defined RUY_TESTING_ON_CHROMIUMOS + bool found_dotprod = false; + for (const std::string& covered_path : *CoveredPaths()) { + if (covered_path == "kNeonDotprod") { + found_dotprod = true; + } + } + if (!found_dotprod) { + std::cerr + << "Error: we haven't tested the kNeonDotprod path as we should " + "have. At the moment, this is required on ChromiumOS as this is " + "what we run emulator tests in, that currently supports " + "dot-product " + "instructions, and we care very much about not regressing that. " + "If this test was run in an emulator, please upgrade to a newer " + "emulator version. If this test was run on an actual device, and " + "you need to be able to run ruy tests on devices not supporting " + "dot-product instructions, get in touch with us.\n" + << std::endl; + abort(); + } +#endif + } + static void Singleton() { static LogCoveredPathsOnDestruction singleton; } +}; + +enum class RandomRange { + kGeneral, + kAvoidMinValue, + kOffCenterAvoidMinValue, + kReasonableSrcZeroPoint, + kReasonableDstZeroPoint, + kBias +}; + +template ::value> +struct RandomRangeBounds {}; + +template +struct RandomRangeBounds { + static Scalar GetMinBound(RandomRange range) { + switch (range) { + case RandomRange::kGeneral: + return -1; + case RandomRange::kAvoidMinValue: + return -1; + case RandomRange::kOffCenterAvoidMinValue: + return -1; + case RandomRange::kReasonableSrcZeroPoint: + return 0; + case RandomRange::kReasonableDstZeroPoint: + return 0; + case RandomRange::kBias: + return -1; + default: + RUY_CHECK(false); + return 0; + } + } + static Scalar GetMaxBound(RandomRange range) { + switch (range) { + case RandomRange::kGeneral: + return 1; + case RandomRange::kAvoidMinValue: + return 1; + case RandomRange::kOffCenterAvoidMinValue: + return 1; + case RandomRange::kReasonableSrcZeroPoint: + return 0; + case RandomRange::kReasonableDstZeroPoint: + return 0; + case RandomRange::kBias: + return 1; + default: + RUY_CHECK(false); + return 0; + } + } +}; + +template +Scalar WeightedSum(Scalar s1, float weight1, Scalar s2, float weight2) { + float sum = s1 * weight1 + s2 * weight2; + float clamped = std::min( + std::numeric_limits::max(), + std::max(std::numeric_limits::lowest(), sum)); + return static_cast(clamped); +} + +template +Scalar Parametrized(float param) { + return WeightedSum(std::numeric_limits::max(), param, + std::numeric_limits::lowest(), 1 - param); +} + +template +struct RandomRangeBounds { + static Scalar GetMinBound(RandomRange range) { + static constexpr double offcenteredness = + 0.02; // Shift lower limit by about 5 for range of 255. + switch (range) { + case RandomRange::kGeneral: + return std::numeric_limits::lowest(); + case RandomRange::kAvoidMinValue: + return 1 + std::numeric_limits::lowest(); + case RandomRange::kOffCenterAvoidMinValue: + return 1 + std::numeric_limits::lowest() + + static_cast( + offcenteredness * std::numeric_limits::max() - + offcenteredness * + (std::numeric_limits::lowest() + 1)); + case RandomRange::kReasonableSrcZeroPoint: + return std::numeric_limits::lowest(); + case RandomRange::kReasonableDstZeroPoint: + return Parametrized(0.4); + case RandomRange::kBias: + return std::is_same::value + ? 
static_cast(-10000) + : 0; + default: + RUY_CHECK(false); + return 0; + } + } + static Scalar GetMaxBound(RandomRange range) { + switch (range) { + case RandomRange::kGeneral: + return std::numeric_limits::max(); + case RandomRange::kAvoidMinValue: + return std::numeric_limits::max(); + case RandomRange::kOffCenterAvoidMinValue: + return std::numeric_limits::max(); + case RandomRange::kReasonableSrcZeroPoint: + return std::numeric_limits::max(); + case RandomRange::kReasonableDstZeroPoint: + return Parametrized(0.6); + case RandomRange::kBias: + return std::is_same::value + ? static_cast(10000) + : 0; + default: + RUY_CHECK(false); + return 0; + } + } +}; + +inline std::default_random_engine& global_random_engine() { + static std::default_random_engine engine; + return engine; +} + +template +struct UniformRandomDistribution { + UniformRandomDistribution(RandomRange range) + : dist(RandomRangeBounds::GetMinBound(range), + RandomRangeBounds::GetMaxBound(range)) {} + Scalar Get() { return dist(global_random_engine()); } + // std::uniform_int_distribution is specified not to support char types, + // only short and wider types. MSVC actually generates an error on + // std::uniform_int_distribution. + using StdDistType = typename std::conditional< + std::is_floating_point::value, + std::uniform_real_distribution, + std::uniform_int_distribution>::type; + StdDistType dist; +}; + +template +void MakeRandomScalar(UniformRandomDistribution* uniform_dist, + Scalar* dst) { + *dst = uniform_dist->Get(); +} + +template +void MakeRandomVector(UniformRandomDistribution* uniform_dist, int size, + std::vector* dst) { + dst->resize(size); + for (auto& x : *dst) { + MakeRandomScalar(uniform_dist, &x); + } +} + +template +void MakeRandomScalar(RandomRange range, Scalar* dst) { + UniformRandomDistribution dist(range); + *dst = dist.Get(); + if (range == RandomRange::kReasonableDstZeroPoint || + range == RandomRange::kReasonableSrcZeroPoint) { + if (global_random_engine()() & 1) { + *dst = SymmetricZeroPoint(); + } + } +} + +template +void MakeRandomVector(RandomRange range, int size, std::vector* dst) { + UniformRandomDistribution dist(range); + dst->resize(size); + for (auto& x : *dst) { + MakeRandomScalar(&dist, &x); + } +} + +enum class LayoutStyle { kPackedLinear, kLinear }; + +inline void MakeLayout(int rows, int cols, Order order, + LayoutStyle layout_style, Layout* layout) { + layout->rows = rows; + layout->cols = cols; + layout->order = order; + + const int packed_stride = order == Order::kColMajor ? 
rows : cols; + + RUY_CHECK(layout_style == LayoutStyle::kPackedLinear || + layout_style == LayoutStyle::kLinear); + if (layout_style == LayoutStyle::kPackedLinear) { + layout->stride = packed_stride; + } else { + layout->stride = packed_stride + 1; + } +} + +template +struct StorageMatrix { + StorageMatrix() = default; + StorageMatrix(const StorageMatrix&) = delete; + void operator=(const StorageMatrix&) = delete; + std::vector data; + Matrix matrix; +}; + +template +void VerifyConsistentFields(const StorageMatrix& storage_matrix) { + if (storage_matrix.data.empty()) { + RUY_CHECK_EQ(storage_matrix.matrix.data.get(), nullptr); + RUY_CHECK_EQ(storage_matrix.matrix.layout.rows, 0); + RUY_CHECK_EQ(storage_matrix.matrix.layout.cols, 0); + } else { + RUY_CHECK_EQ(storage_matrix.matrix.data.get(), storage_matrix.data.data()); + RUY_CHECK_EQ(FlatSize(storage_matrix.matrix.layout), + storage_matrix.data.size()); + } +} + +template +void MakeRandom(int rows, int cols, Order order, Scalar zero_point, + LayoutStyle layout_style, RandomRange range, + StorageMatrix* storage_matrix) { + MakeLayout(rows, cols, order, layout_style, &storage_matrix->matrix.layout); + storage_matrix->matrix.zero_point = zero_point; + UniformRandomDistribution data_dist(range); + MakeRandomVector(&data_dist, FlatSize(storage_matrix->matrix.layout), + &storage_matrix->data); + storage_matrix->matrix.data = storage_matrix->data.data(); + VerifyConsistentFields(*storage_matrix); +} + +template +struct TestResult { + void operator=(const TestResult&) = delete; + void operator=(const TestResult&&) = delete; + StorageMatrix storage_matrix; + Path path = Path::kNone; + Tuning tuning = Tuning::kAuto; + ExternalPath external_path = ExternalPath::kNone; + float latency; + float l1_refill_rate; + float l2_refill_rate; + float l3_refill_rate; + float l1tlb_refill_rate; + float l2tlb_refill_rate; + float mispred_rate; + float frontend_stall_rate; + float backend_stall_rate; + + // Per-path data for pre-packing. + // This is not used by external paths or by Path::kReference. 
+ Allocator allocator; + PrepackedMatrix prepacked_lhs; + PrepackedMatrix prepacked_rhs; + bool use_prepacked_lhs = false; + bool use_prepacked_rhs = false; +}; + +template +std::string PathName(const TestResult& result) { + std::string pathname; + if (result.path != Path::kNone) { + pathname.assign(PathName(result.path)); + } else if (result.external_path != ExternalPath::kNone) { + pathname.assign(PathName(result.external_path)); + } else { + RUY_CHECK(false); + } + if (result.tuning != Tuning::kAuto) { + pathname.append("/"); + pathname.append(TuningName(result.tuning)); + } + return pathname; +} + +enum class ExpectedOutcome { kSuccess, kDeath }; + +template +struct TestSet final { + using LhsScalar = tLhsScalar; + using RhsScalar = tRhsScalar; + using AccumScalar = typename SpecType::AccumScalar; + using DstScalar = typename SpecType::DstScalar; + using Spec = SpecType; + using TestResultType = TestResult; + + void Run() { + MakeZeroPoints(); + MakeLhsRhs(); + MakeSpec(); + MakeOtherParams(); + MakeResultPaths(); + MakePrepackedMatrices(); + Eval(); + Verify(); + } + + private: + void MakeZeroPoints(); + void MakeLhsRhs(); + void MakeSpec(); + void MakeResultPaths(); + void MakePrepackedMatrices(); + void MakeOtherParams(); + void EvalAndVerify(); + void Eval(); + void Verify(); + + void EvalResult(TestResultType* result); + void EvalRuy(TestResultType* result); + void DoMul(TestResultType* result); + void Benchmark(TestResultType* result); + void VerifyTestResults() const; + + public: + enum class LifeStage { + kInitial, + kHasZeroPoints, + kHasLhsRhs, + kHasSpec, + kHasOtherParams, + kHasResultPaths, + kHasPrepackedMatrices, + kEvaluated, + kFinal + }; + + ~TestSet() { + RUY_CHECK_EQ(life_stage, LifeStage::kFinal); + LogCoveredPathsOnDestruction::Singleton(); + } + + LifeStage life_stage = LifeStage::kInitial; + + int rows = 0; + int cols = 0; + int depth = 0; + Order lhs_order = Order::kRowMajor; + Order rhs_order = Order::kColMajor; + Order dst_order = Order::kColMajor; + LayoutStyle layout_style = LayoutStyle::kPackedLinear; + ExpectedOutcome expected_outcome = ExpectedOutcome::kSuccess; + + bool use_specified_zero_points = false; + LhsScalar lhs_zero_point = 0; + RhsScalar rhs_zero_point = 0; + DstScalar dst_zero_point = 0; + + std::vector per_channel_multiplier_fixedpoint; + std::vector per_channel_multiplier_exponent; + + StorageMatrix lhs; + StorageMatrix rhs; + Spec spec; + std::vector bias_data; + std::vector> results; + + std::vector paths; + std::vector external_paths; + + bool benchmark = false; + bool perchannel = false; + int max_num_threads = 0; + bool benchmark_prepack_lhs = false; + bool benchmark_prepack_rhs = false; +}; + +inline PmuEvents& GlobalPmuEvents() { + static PmuEvents pmu; + return pmu; +} + +inline Context& GlobalContext() { + // Ensure that GlobalPmuEvents is constructed before we create any context. + // This ensures that pmu counters are opened before we create any worker + // thread, which is necessary to count events from worker threads. 
+ GlobalPmuEvents(); + + static Context context; + return context; +} + +#if defined(__has_feature) +#if __has_feature(thread_sanitizer) +#define RUY_TSAN +#endif +#if __has_feature(address_sanitizer) +#define RUY_ASAN +#endif +#endif // defined(__has_feature) + +template +void TestSet::DoMul(TestResultType* result) { + Context* context = &GlobalContext(); + + if (!result->use_prepacked_lhs && !result->use_prepacked_rhs) { + Mul(lhs.matrix, rhs.matrix, spec, context, + &result->storage_matrix.matrix); + return; + } + + // If we prepacked an input matrix, null out its data pointer to check + // that we don't access any data through it. + Matrix null_data_lhs = lhs.matrix; + Matrix null_data_rhs = rhs.matrix; + if (result->use_prepacked_lhs) { + null_data_lhs.data = nullptr; + } + if (result->use_prepacked_rhs) { + null_data_rhs.data = nullptr; + } + + // Do the multiplication with pre-packed matrices. + PrepackedMatrix* prepacked_lhs_ptr = + result->use_prepacked_lhs ? &result->prepacked_lhs : nullptr; + PrepackedMatrix* prepacked_rhs_ptr = + result->use_prepacked_rhs ? &result->prepacked_rhs : nullptr; + MulWithPrepacked(null_data_lhs, null_data_rhs, spec, context, + &result->storage_matrix.matrix, prepacked_lhs_ptr, + prepacked_rhs_ptr); +} + +// When building for WAsm, ASSERT_DEATH is not defined. +#ifdef ASSERT_DEATH +#define RUY_ASSERT_DEATH(CONDITION, MESSAGE) ASSERT_DEATH(CONDITION, MESSAGE) +#else +#define RUY_ASSERT_DEATH(CONDITION, MESSAGE) +#endif + +template +void TestSet::EvalRuy(TestResultType* result) { + GlobalContext().explicit_tuning = result->tuning; + if (max_num_threads) { + GlobalContext().max_num_threads = max_num_threads; + } else if (benchmark) { + GlobalContext().max_num_threads = 1; + } else { + GlobalContext().max_num_threads = 1 + global_random_engine()() % 8; + } + GlobalContext().SetRuntimeEnabledPaths(result->path); + if (expected_outcome == ExpectedOutcome::kSuccess) { + DoMul(result); + RUY_CHECK_EQ(GlobalContext().last_taken_path, result->path); + } else if (expected_outcome == ExpectedOutcome::kDeath) { + // TODO(benoitjacob) TSan and ASan seem to be breaking ASSERT_DEATH. + // Report a bug? +#if (!defined NDEBUG) && (!defined RUY_ASAN) && (!defined RUY_TSAN) + RUY_ASSERT_DEATH(DoMul(result), ""); +#endif + } else { + RUY_CHECK(false); + } + GlobalContext().explicit_tuning = Tuning::kAuto; + GlobalContext().max_num_threads = 1; +} + +#ifdef RUY_TEST_EXTERNAL_PATHS + +template +void WrapGemmlowp(const Matrix& src, + gemmlowp::MatrixMap* dst) { + RUY_CHECK(src.layout.order == (tOrder == gemmlowp::MapOrder::ColMajor + ? Order::kColMajor + : Order::kRowMajor)); + *dst = gemmlowp::MatrixMap( + src.data.get(), src.layout.rows, src.layout.cols, src.layout.stride); +} + +template +void WrapGemmlowpMutable(Matrix* src, + gemmlowp::MatrixMap* dst) { + RUY_CHECK(src->layout.order == (tOrder == gemmlowp::MapOrder::ColMajor + ? 
Order::kColMajor + : Order::kRowMajor)); + *dst = gemmlowp::MatrixMap( + src->data.get(), src->layout.rows, src->layout.cols, src->layout.stride); +} + +template +struct GemmlowpOrder {}; + +template <> +struct GemmlowpOrder { + static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::ColMajor; +}; + +template <> +struct GemmlowpOrder { + static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::RowMajor; +}; + +inline gemmlowp::GemmContext& GlobalGemmlowpContext() { + static gemmlowp::GemmContext context; + return context; +} + +template +void EvalGemmlowp(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, int max_num_threads, + Matrix* dst) { + static constexpr gemmlowp::MapOrder kGemmlowpLhsOrder = + GemmlowpOrder::kValue; + static constexpr gemmlowp::MapOrder kGemmlowpRhsOrder = + GemmlowpOrder::kValue; + static constexpr gemmlowp::MapOrder kGemmlowpDstOrder = + GemmlowpOrder::kValue; + gemmlowp::MatrixMap gemmlowp_lhs; + gemmlowp::MatrixMap gemmlowp_rhs; + gemmlowp::MatrixMap gemmlowp_dst; + WrapGemmlowp(lhs, &gemmlowp_lhs); + WrapGemmlowp(rhs, &gemmlowp_rhs); + WrapGemmlowpMutable(dst, &gemmlowp_dst); + + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage; + quantize_down_stage.result_offset_after_shift = dst->zero_point; + quantize_down_stage.result_fixedpoint_multiplier = spec.multiplier_fixedpoint; + quantize_down_stage.result_exponent = spec.multiplier_exponent; + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC< + gemmlowp::VectorShape::Col> + quantize_down_stage_pc; + quantize_down_stage_pc.result_offset_after_shift = dst->zero_point; + using ColVectorMap = + gemmlowp::VectorMap; + quantize_down_stage_pc.result_fixedpoint_multiplier = + ColVectorMap(spec.multiplier_fixedpoint_perchannel, lhs.layout.rows); + quantize_down_stage_pc.result_exponent = + ColVectorMap(spec.multiplier_exponent_perchannel, lhs.layout.rows); + + gemmlowp::OutputStageClamp clamp_stage; + clamp_stage.min = spec.clamp_min; + clamp_stage.max = spec.clamp_max; + using OutputStageSaturatingCast = typename std::conditional< + std::is_same::value, + gemmlowp::OutputStageSaturatingCastToUint8, + gemmlowp::OutputStageSaturatingCastToInt16>::type; + OutputStageSaturatingCast saturating_cast_stage; + + GlobalGemmlowpContext().set_max_num_threads(max_num_threads ? 
max_num_threads + : 1); + if (spec.bias) { + using ColVectorMap = + gemmlowp::VectorMap; + gemmlowp::OutputStageBiasAddition bias_add_stage; + bias_add_stage.bias_vector = ColVectorMap(spec.bias, dst->layout.rows); +#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE + if (spec.multiplier_exponent_perchannel) { + const auto& output_pipeline = + std::make_tuple(bias_add_stage, quantize_down_stage_pc, clamp_stage, + saturating_cast_stage); + gemmlowp::GemmWithOutputPipeline< + LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, + -lhs.zero_point, -rhs.zero_point, output_pipeline); + } else // NOLINT[readability/braces] +#endif + { + const auto& output_pipeline = + std::make_tuple(bias_add_stage, quantize_down_stage, clamp_stage, + saturating_cast_stage); + gemmlowp::GemmWithOutputPipeline< + LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, + -lhs.zero_point, -rhs.zero_point, output_pipeline); + } + } else { +#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE + if (spec.multiplier_exponent_perchannel) { + const auto& output_pipeline = std::make_tuple( + quantize_down_stage_pc, clamp_stage, saturating_cast_stage); + gemmlowp::GemmWithOutputPipeline< + LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, + -lhs.zero_point, -rhs.zero_point, output_pipeline); + } else // NOLINT[readability/braces] +#endif + { + const auto& output_pipeline = std::make_tuple( + quantize_down_stage, clamp_stage, saturating_cast_stage); + gemmlowp::GemmWithOutputPipeline< + LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, + -lhs.zero_point, -rhs.zero_point, output_pipeline); + } + } +} + +inline constexpr int Mash(Order LhsOrder, Order RhsOrder, Order DstOrder) { + return (LhsOrder == Order::kRowMajor ? 4 : 0) + + (RhsOrder == Order::kRowMajor ? 2 : 0) + + (DstOrder == Order::kRowMajor ? 
1 : 0); +} + +template +void EvalGemmlowp(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, int max_num_threads, + Matrix* dst) { + int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order); + switch (index) { +#define EVALGEMMLOWP_CASE3(LHS, RHS, DST) \ + case Mash(LHS, RHS, DST): \ + return EvalGemmlowp(lhs, rhs, spec, max_num_threads, dst); +#define EVALGEMMLOWP_CASE2(LHS, RHS) \ + EVALGEMMLOWP_CASE3(LHS, RHS, Order::kColMajor) \ + EVALGEMMLOWP_CASE3(LHS, RHS, Order::kRowMajor) +#define EVALGEMMLOWP_CASE1(LHS) \ + EVALGEMMLOWP_CASE2(LHS, Order::kColMajor) \ + EVALGEMMLOWP_CASE2(LHS, Order::kRowMajor) + + EVALGEMMLOWP_CASE1(Order::kColMajor) + EVALGEMMLOWP_CASE1(Order::kRowMajor) + +#undef EVALGEMMLOWP_CASE1 +#undef EVALGEMMLOWP_CASE2 +#undef EVALGEMMLOWP_CASE3 + + default: + RUY_CHECK(false); + } +} + +template +struct EigenOrder {}; + +template <> +struct EigenOrder { + static constexpr int kValue = Eigen::ColMajor; +}; + +template <> +struct EigenOrder { + static constexpr int kValue = Eigen::RowMajor; +}; + +template +void EvalEigen(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, int max_num_threads, Matrix* dst) { + RUY_CHECK_EQ(lhs.zero_point, 0); + RUY_CHECK_EQ(rhs.zero_point, 0); + RUY_CHECK_EQ(dst->zero_point, 0); + RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0); + RUY_CHECK_EQ(spec.multiplier_exponent, 0); + + static constexpr int kEigenLhsOrder = EigenOrder::kValue; + static constexpr int kEigenRhsOrder = EigenOrder::kValue; + static constexpr int kEigenDstOrder = EigenOrder::kValue; + + using EigenLhsType = typename Eigen::Matrix:: + template StridedConstMapType>::type; + using EigenRhsType = typename Eigen::Matrix:: + template StridedConstMapType>::type; + using EigenDstType = typename Eigen::Matrix:: + template StridedMapType>::type; + using EigenBiasType = + typename Eigen::Matrix::ConstMapType; + + EigenLhsType eigen_lhs(lhs.data.get(), lhs.layout.rows, lhs.layout.cols, + Eigen::OuterStride(lhs.layout.stride)); + EigenRhsType eigen_rhs(rhs.data.get(), rhs.layout.rows, rhs.layout.cols, + Eigen::OuterStride(rhs.layout.stride)); + EigenDstType eigen_dst( + dst->data.get(), dst->layout.rows, dst->layout.cols, + Eigen::OuterStride(dst->layout.stride)); + Eigen::setNbThreads(max_num_threads ? 
max_num_threads : 1); + + if (spec.bias) { + EigenBiasType eigen_bias(spec.bias, dst->layout.rows); + if (spec.clamp_max == std::numeric_limits::infinity() && + spec.clamp_min == -std::numeric_limits::infinity()) { + eigen_dst.noalias() = (eigen_lhs * eigen_rhs).colwise() + eigen_bias; + } else { + eigen_dst.noalias() = ((eigen_lhs * eigen_rhs).colwise() + eigen_bias) + .cwiseMin(spec.clamp_max) + .cwiseMax(spec.clamp_min); + } + } else { + if (spec.clamp_max == std::numeric_limits::infinity() && + spec.clamp_min == -std::numeric_limits::infinity()) { + eigen_dst.noalias() = eigen_lhs * eigen_rhs; + } else { + eigen_dst.noalias() = (eigen_lhs * eigen_rhs) + .cwiseMin(spec.clamp_max) + .cwiseMax(spec.clamp_min); + } + } +} + +template +void EvalEigen(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, int max_num_threads, Matrix* dst) { + int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order); + switch (index) { +#define EVALEIGEN_CASE3(LHS, RHS, DST) \ + case Mash(LHS, RHS, DST): \ + return EvalEigen(lhs, rhs, spec, max_num_threads, dst); +#define EVALEIGEN_CASE2(LHS, RHS) \ + EVALEIGEN_CASE3(LHS, RHS, Order::kColMajor) \ + EVALEIGEN_CASE3(LHS, RHS, Order::kRowMajor) +#define EVALEIGEN_CASE1(LHS) \ + EVALEIGEN_CASE2(LHS, Order::kColMajor) \ + EVALEIGEN_CASE2(LHS, Order::kRowMajor) + + EVALEIGEN_CASE1(Order::kColMajor) + EVALEIGEN_CASE1(Order::kRowMajor) + +#undef EVALEIGEN_CASE1 +#undef EVALEIGEN_CASE2 +#undef EVALEIGEN_CASE3 + + default: + RUY_CHECK(false); + } +} + +template +void EvalEigenTensor(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, int max_num_threads, + Matrix* dst) { + RUY_CHECK_EQ(lhs.zero_point, 0); + RUY_CHECK_EQ(rhs.zero_point, 0); + RUY_CHECK_EQ(dst->zero_point, 0); + RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0); + RUY_CHECK_EQ(spec.multiplier_exponent, 0); + + // Eigen::TensorMap only supports packed layouts + RUY_CHECK(IsPacked(lhs.layout)); + RUY_CHECK(IsPacked(rhs.layout)); + RUY_CHECK(IsPacked(dst->layout)); + + using TensorLhsType = + Eigen::TensorMap>; + using TensorRhsType = + Eigen::TensorMap>; + using TensorDstType = + Eigen::TensorMap>; + using TensorBiasType = + Eigen::TensorMap>; + + const bool tr = DstOrder == Order::kRowMajor; + const auto& contract_lhs = tr ? rhs : lhs; + const auto& contract_rhs = tr ? lhs : rhs; + + TensorLhsType tensor_lhs( + contract_lhs.data.get(), + LhsOrder == Order::kColMajor ? contract_lhs.layout.rows + : contract_lhs.layout.cols, + LhsOrder == Order::kColMajor ? contract_lhs.layout.cols + : contract_lhs.layout.rows); + TensorRhsType tensor_rhs( + contract_rhs.data.get(), + RhsOrder == Order::kColMajor ? contract_rhs.layout.rows + : contract_rhs.layout.cols, + RhsOrder == Order::kColMajor ? contract_rhs.layout.cols + : contract_rhs.layout.rows); + TensorDstType tensor_dst( + dst->data.get(), + DstOrder == Order::kColMajor ? dst->layout.rows : dst->layout.cols, + DstOrder == Order::kColMajor ? dst->layout.cols : dst->layout.rows); + using DimPair = + typename Eigen::Tensor::DimensionPair; + Eigen::array contract_dims( + {DimPair((LhsOrder == Order::kColMajor) ? 1 : 0, + (RhsOrder == Order::kColMajor) ? 0 : 1)}); + Eigen::array shuffle(DstOrder == Order::kColMajor ? 0 : 1, + DstOrder == Order::kColMajor ? 1 : 0); + static Eigen::ThreadPool pool(max_num_threads ? max_num_threads : 1); + static Eigen::ThreadPoolDevice device(&pool, pool.NumThreads()); + if (spec.bias) { + TensorBiasType tensor_bias(spec.bias, dst->layout.rows); + Eigen::array bias_2d_shape(tr ? 1 : dst->layout.rows, + tr ? 
dst->layout.rows : 1); + Eigen::array bcast(tr ? dst->layout.cols : 1, + tr ? 1 : dst->layout.cols); + if (spec.clamp_max == std::numeric_limits::infinity() && + spec.clamp_min == -std::numeric_limits::infinity()) { + tensor_dst.device(device) = + tensor_lhs.contract(tensor_rhs, contract_dims); + } else { + tensor_dst.device(device) = + (tensor_lhs.contract(tensor_rhs, contract_dims) + + tensor_bias.reshape(bias_2d_shape).broadcast(bcast)) + .cwiseMin(spec.clamp_max) + .cwiseMax(spec.clamp_min); + } + } else { + if (spec.clamp_max == std::numeric_limits::infinity() && + spec.clamp_min == -std::numeric_limits::infinity()) { + tensor_dst.device(device) = + tensor_lhs.contract(tensor_rhs, contract_dims); + } else { + tensor_dst.device(device) = tensor_lhs.contract(tensor_rhs, contract_dims) + .cwiseMin(spec.clamp_max) + .cwiseMax(spec.clamp_min); + } + } +} + +template +void EvalEigenTensor(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, int max_num_threads, + Matrix* dst) { + int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order); + switch (index) { +#define EVALEIGENTENSOR_CASE3(LHS, RHS, DST) \ + case Mash(LHS, RHS, DST): \ + return EvalEigenTensor(lhs, rhs, spec, max_num_threads, dst); +#define EVALEIGENTENSOR_CASE2(LHS, RHS) \ + EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kColMajor) \ + EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kRowMajor) +#define EVALEIGENTENSOR_CASE1(LHS) \ + EVALEIGENTENSOR_CASE2(LHS, Order::kColMajor) \ + EVALEIGENTENSOR_CASE2(LHS, Order::kRowMajor) + + EVALEIGENTENSOR_CASE1(Order::kColMajor) + EVALEIGENTENSOR_CASE1(Order::kRowMajor) + +#undef EVALEIGENTENSOR_CASE1 +#undef EVALEIGENTENSOR_CASE2 +#undef EVALEIGENTENSOR_CASE3 + + default: + RUY_CHECK(false); + } +} + +template +struct GenericBlasGemm {}; + +template <> +struct GenericBlasGemm { + static void Run(char* transa, char* transb, lapack::integer* m, + lapack::integer* n, lapack::integer* k, + lapack::doublereal* alpha, lapack::doublereal* a, + lapack::integer* lda, lapack::doublereal* b, + lapack::integer* ldb, lapack::doublereal* beta, + lapack::doublereal* c, lapack::integer* ldc) { + dgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } +}; + +template <> +struct GenericBlasGemm { + static void Run(char* transa, char* transb, lapack::integer* m, + lapack::integer* n, lapack::integer* k, lapack::real* alpha, + lapack::real* a, lapack::integer* lda, lapack::real* b, + lapack::integer* ldb, lapack::real* beta, lapack::real* c, + lapack::integer* ldc) { + sgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } +}; + +template +void EvalOpenBlas(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, int max_num_threads, Matrix* dst) { + RUY_CHECK_EQ(lhs.zero_point, 0); + RUY_CHECK_EQ(rhs.zero_point, 0); + RUY_CHECK_EQ(dst->zero_point, 0); + RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0); + RUY_CHECK_EQ(spec.multiplier_exponent, 0); + + Matrix gemm_lhs; + Matrix gemm_rhs; + Matrix gemm_dst; + gemm_dst = *dst; + + // Use Transpose to reduce to the all-column-major case. + // Notice that ruy::Matrix merely holds a pointer, does not own data, + // so Transpose is cheap -- no actual matrix data is being transposed here. 
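+  // In other words, the reduction below relies on the identity
+  //   (lhs * rhs)^T = rhs^T * lhs^T
+  // together with the fact that reinterpreting a row-major MxN buffer as
+  // column-major yields its NxM transpose without moving any data. So for a
+  // row-major dst we compute dst^T = rhs^T * lhs^T, and each Transpose()
+  // call only swaps layout metadata (order, rows/cols) on the ruy::Matrix
+  // view.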
+ if (dst->layout.order == Order::kColMajor) { + gemm_lhs = lhs; + gemm_rhs = rhs; + } else { + gemm_lhs = rhs; + gemm_rhs = lhs; + Transpose(&gemm_lhs); + Transpose(&gemm_rhs); + Transpose(&gemm_dst); + } + bool transposed_lhs = false; + bool transposed_rhs = false; + + if (gemm_lhs.layout.order == Order::kRowMajor) { + Transpose(&gemm_lhs); + transposed_lhs = true; + } + if (gemm_rhs.layout.order == Order::kRowMajor) { + Transpose(&gemm_rhs); + transposed_rhs = true; + } + + RUY_CHECK_EQ(gemm_lhs.layout.order, Order::kColMajor); + RUY_CHECK_EQ(gemm_rhs.layout.order, Order::kColMajor); + RUY_CHECK_EQ(gemm_dst.layout.order, Order::kColMajor); + + char transa = transposed_lhs ? 'T' : 'N'; + char transb = transposed_rhs ? 'T' : 'N'; + int m = gemm_lhs.layout.rows; + int n = gemm_rhs.layout.cols; + int k = gemm_lhs.layout.cols; + float alpha = 1; + Scalar* a = gemm_lhs.data.get(); + int lda = gemm_lhs.layout.stride; + Scalar* b = gemm_rhs.data.get(); + int ldb = gemm_rhs.layout.stride; + float beta = 0; + Scalar* c = gemm_dst.data.get(); + int ldc = gemm_dst.layout.stride; + GenericBlasGemm::Run(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, + &ldb, &beta, c, &ldc); + + // BLAS does not allow us to express the bias-addition and clamping, so + // we use Eigen for that. + + using EigenDstType = + typename Eigen::Matrix:: + template StridedMapType>::type; + using EigenBiasType = + typename Eigen::Matrix::ConstMapType; + + EigenDstType eigen_dst( + gemm_dst.data.get(), gemm_dst.layout.rows, gemm_dst.layout.cols, + Eigen::OuterStride(gemm_dst.layout.stride)); + Eigen::setNbThreads(max_num_threads ? max_num_threads : 1); + + if (spec.bias) { + EigenBiasType eigen_bias(spec.bias, dst->layout.rows); + if (spec.clamp_max == std::numeric_limits::infinity() && + spec.clamp_min == -std::numeric_limits::infinity()) { + eigen_dst.noalias() = eigen_dst.colwise() + eigen_bias; + } else { + eigen_dst.noalias() = (eigen_dst.colwise() + eigen_bias) + .cwiseMin(spec.clamp_max) + .cwiseMax(spec.clamp_min); + } + } else { + if (spec.clamp_max == std::numeric_limits::infinity() && + spec.clamp_min == -std::numeric_limits::infinity()) { + } else { + eigen_dst.noalias() = + eigen_dst.cwiseMin(spec.clamp_max).cwiseMax(spec.clamp_min); + } + } +} + +template +struct SupportsGemmlowp { + static constexpr bool kValue = + std::is_same::value && + std::is_same::value; +}; + +template +struct UsesSingleScalarType { + static constexpr bool kValue = + std::is_same::value && + std::is_same::value && + std::is_same::value; +}; + +template ::value, + bool EnableGemmlowp = SupportsGemmlowp::kValue, + bool SingleScalarType = UsesSingleScalarType::kValue> +struct EvalExternalPathImpl { + using DstScalar = typename TestSetType::DstScalar; + static void Run(TestSetType*, TestResult*) { RUY_CHECK(false); } +}; + +template +struct EvalExternalPathImpl { + using DstScalar = typename TestSetType::DstScalar; + static void Run(TestSetType* test_set, TestResult* test_result) { + if (test_result->external_path == ExternalPath::kEigen) { + EvalEigen(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec, + test_set->max_num_threads, &test_result->storage_matrix.matrix); + } else if (test_result->external_path == ExternalPath::kEigenTensor) { + EvalEigenTensor(test_set->lhs.matrix, test_set->rhs.matrix, + test_set->spec, test_set->max_num_threads, + &test_result->storage_matrix.matrix); + } else if (test_result->external_path == ExternalPath::kOpenBlas) { + EvalOpenBlas(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec, + 
test_set->max_num_threads, + &test_result->storage_matrix.matrix); + } else { + RUY_CHECK(false); + } + } +}; + +template +struct EvalExternalPathImpl { + using DstScalar = typename TestSetType::DstScalar; + static void Run(TestSetType* test_set, TestResult* test_result) { + if (test_result->external_path == ExternalPath::kGemmlowp) { + EvalGemmlowp(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec, + test_set->max_num_threads, + &test_result->storage_matrix.matrix); + } else { + RUY_CHECK(false); + } + } +}; + +template +void EvalExternalPath( + TestSetType* test_set, + TestResult* test_result) { + EvalExternalPathImpl::Run(test_set, test_result); +} + +#endif // RUY_TEST_EXTERNAL_PATHS + +template +bool Agree(const Matrix& matrix1, const Matrix& matrix2, + int depth) { + RUY_CHECK_EQ(matrix1.layout.rows, matrix2.layout.rows); + RUY_CHECK_EQ(matrix1.layout.cols, matrix2.layout.cols); + RUY_CHECK_EQ(matrix1.zero_point, matrix2.zero_point); + const int size = matrix1.layout.rows * matrix1.layout.cols; + double tolerated_max_diff = 0; + double tolerated_mean_diff = 0; + if (std::is_floating_point::value) { + // TODO: replace hardcoded 100 by something more sensible, probably + // roughly sqrt(depth) based on central limit theorem. + double max_abs_val = 0; + for (int row = 0; row < matrix1.layout.rows; row++) { + for (int col = 0; col < matrix1.layout.cols; col++) { + max_abs_val = + std::max(max_abs_val, + std::abs(static_cast(Element(matrix1, row, col)))); + max_abs_val = + std::max(max_abs_val, + std::abs(static_cast(Element(matrix2, row, col)))); + } + } + tolerated_max_diff = max_abs_val * std::numeric_limits::epsilon() * + 64 * std::sqrt(static_cast(depth)); + tolerated_mean_diff = tolerated_max_diff / std::sqrt(size); + } else if (RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)) { + tolerated_max_diff = 1; + // totally empirical + tolerated_mean_diff = std::min(1.0, 2.0 * std::pow(size, -0.2)); + } + double sum_diff = 0; + for (int row = 0; row < matrix1.layout.rows; row++) { + for (int col = 0; col < matrix1.layout.cols; col++) { + double elem1 = Element(matrix1, row, col); + double elem2 = Element(matrix2, row, col); + double diff = elem1 - elem2; + + sum_diff += diff; + // Test (std::abs(diff) > tolerated_max_diff), but also true if diff is + // NaN. 
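+      // All ordered comparisons involving NaN are false, so the direct test
+      // std::abs(diff) > tolerated_max_diff would be false for a NaN diff and
+      // the error would go unnoticed; the negated form used here is true for
+      // NaN. E.g. with double nan = std::numeric_limits<double>::quiet_NaN():
+      //   (std::abs(nan) > 1.0)   -> false
+      //   !(std::abs(nan) <= 1.0) -> true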
+ if (!(std::abs(diff) <= tolerated_max_diff)) { + return false; + } + } + } + double mean_diff = sum_diff / size; + if (std::abs(mean_diff) > tolerated_mean_diff) { + return false; + } + return true; +} + +template +bool Agree(const StorageMatrix& storage_matrix1, + const StorageMatrix& storage_matrix2, int depth) { + VerifyConsistentFields(storage_matrix1); + VerifyConsistentFields(storage_matrix2); + return Agree(storage_matrix1.matrix, storage_matrix2.matrix, depth); +} + +template +bool Agree(const TestResult& result1, const TestResult& result2, + int depth) { + return Agree(result1.storage_matrix, result2.storage_matrix, depth); +} + +struct Stats { + double median; + double mean; + double min; + double max; +}; + +inline std::string StatsAsString(const Stats& stats) { + char buf[256]; + snprintf(buf, sizeof(buf), "(median = %g, mean = %g, min = %g, max = %g)", + stats.median, stats.mean, stats.min, stats.max); + return std::string(buf); +} + +template +void GetMatrixStats(const Matrix& matrix, Stats* stats) { + double min = std::numeric_limits::infinity(); + double max = -std::numeric_limits::infinity(); + double sum = 0; + std::vector allvals; + for (int row = 0; row < matrix.layout.rows; row++) { + for (int col = 0; col < matrix.layout.cols; col++) { + double val = Element(matrix, row, col); + min = std::min(min, val); + max = std::max(max, val); + sum += val; + allvals.push_back(val); + } + } + std::sort(allvals.begin(), allvals.end()); + stats->min = min; + stats->max = max; + stats->mean = sum / allvals.size(); + stats->median = allvals[allvals.size() / 2]; +} + +struct ErrorAnalysis { + Stats stats_good; + Stats stats_bad; + // The below is to help document departure from bit exactness. It's probably + // not going to be relevant to floating-point. 
+ std::set error_rows; + std::set error_cols; + int row_of_first_error = 0; + int col_of_first_error = 0; + double first_error_good_value = 0; + double first_error_bad_value = 0; +}; + +template +void AnalyzeTestError(const TestSetType& test_set, int first_bad_result_index, + ErrorAnalysis* error_analysis) { + const auto& good_matrix = test_set.results[0]->storage_matrix.matrix; + const auto& bad_matrix = + test_set.results[first_bad_result_index]->storage_matrix.matrix; + GetMatrixStats(good_matrix, &error_analysis->stats_good); + GetMatrixStats(bad_matrix, &error_analysis->stats_bad); + bool found_first_error = false; + for (int row = 0; row < good_matrix.layout.rows; row++) { + for (int col = 0; col < good_matrix.layout.cols; col++) { + if (Element(good_matrix, row, col) != Element(bad_matrix, row, col)) { + if (!found_first_error) { + found_first_error = true; + error_analysis->row_of_first_error = row; + error_analysis->col_of_first_error = col; + error_analysis->first_error_good_value = + Element(good_matrix, row, col); + error_analysis->first_error_bad_value = Element(bad_matrix, row, col); + } + error_analysis->error_rows.insert(row); + error_analysis->error_cols.insert(col); + } + } + } +} + +template +void ComputeReasonableMultiplier( + const Matrix& lhs, + const Matrix& rhs, double* multiplier) { + using LhsScalar = typename TestSetType::LhsScalar; + using RhsScalar = typename TestSetType::RhsScalar; + using DstScalar = typename TestSetType::DstScalar; + if (std::is_floating_point::value || + std::is_same::value) { + *multiplier = 0; + return; + } + *multiplier = static_cast(std::numeric_limits::max()) / + (static_cast(lhs.layout.cols) * + std::numeric_limits::max() * + std::numeric_limits::max()); +} + +inline void QuantizeMultiplier(double multiplier_double, + std::int32_t* multiplier_fixedpoint, + int* multiplier_exponent) { + RUY_CHECK_GT(multiplier_double, 0); + if (multiplier_double == 0.) { + *multiplier_fixedpoint = 0; + *multiplier_exponent = 0; + return; + } + const double q = std::frexp(multiplier_double, multiplier_exponent); + auto q_fixed = static_cast(std::round(q * (1ll << 31))); + RUY_CHECK_LE(q_fixed, (1ll << 31)); + if (q_fixed == (1ll << 31)) { + q_fixed /= 2; + ++*multiplier_exponent; + } + RUY_CHECK_LE(q_fixed, std::numeric_limits::max()); + *multiplier_fixedpoint = static_cast(q_fixed); +} + +template +void SwitchMultiplierToPerChannel(TestSetType* test_set) { + test_set->per_channel_multiplier_fixedpoint.resize(test_set->rows); + test_set->per_channel_multiplier_exponent.resize(test_set->rows); + for (int i = 0; i < test_set->rows; i++) { + // multipliers typically range in [2^30 ; 2^31 - 1]. + // Values in [0, 2^30 - 1] are normally unused, but harmless. + // Thus a good way to randomize multipliers is to subtract from them + // a random value smaller than 2^30 but still significant compared to it. 
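+    // Concretely (illustrative numbers): QuantizeMultiplier(0.0015) produces
+    // multiplier_exponent = -9 and multiplier_fixedpoint =
+    // round(0.768 * 2^31) = 1649267442, since 0.0015 = 0.768 * 2^-9.
+    // The random value subtracted below is < 2^26, i.e. at most ~6% of a
+    // fixedpoint value >= 2^30, so the nudged multiplier stays positive and
+    // in a realistic range.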
+ std::int32_t nudged_multiplier = test_set->spec.multiplier_fixedpoint - + (global_random_engine()() % (1 << 26)); + int nudged_exponent = + test_set->spec.multiplier_exponent - 1 + (global_random_engine()() % 4); + test_set->per_channel_multiplier_fixedpoint[i] = nudged_multiplier; + test_set->per_channel_multiplier_exponent[i] = nudged_exponent; + } + test_set->spec.multiplier_fixedpoint_perchannel = + test_set->per_channel_multiplier_fixedpoint.data(); + test_set->spec.multiplier_exponent_perchannel = + test_set->per_channel_multiplier_exponent.data(); + test_set->spec.multiplier_fixedpoint = 0; + test_set->spec.multiplier_exponent = 0; +} + +template < + typename TestSetType, + bool IsApplicable = + std::is_same::value && + !std::is_same::value> +struct MakeSpecMultiplierFieldsImpl {}; + +template +struct MakeSpecMultiplierFieldsImpl { + static void Run(TestSetType* test_set) { + double multiplier; + ComputeReasonableMultiplier(test_set->lhs.matrix, + test_set->rhs.matrix, &multiplier); + QuantizeMultiplier(multiplier, &test_set->spec.multiplier_fixedpoint, + &test_set->spec.multiplier_exponent); + if (!test_set->benchmark) { + test_set->perchannel = global_random_engine()() & 1; + } + if (test_set->perchannel) { + SwitchMultiplierToPerChannel(test_set); + } + } +}; + +template +struct MakeSpecMultiplierFieldsImpl { + static void Run(TestSetType* test_set) { + test_set->spec.multiplier_fixedpoint = 0; + test_set->spec.multiplier_exponent = 0; + } +}; + +template +void MakeSpecClampFields(Spec* spec) { + using AccumScalar = typename Spec::AccumScalar; + using DstScalar = typename Spec::DstScalar; + + if (std::is_same::value) { + // Returning raw accumulators, clamping is not supported. + spec->clamp_min = std::numeric_limits::lowest(); + spec->clamp_max = std::numeric_limits::max(); + return; + } + + if (getenv("BENCHMARK_ONLY_MATMUL")) { + if (std::is_floating_point::value) { + spec->clamp_min = -std::numeric_limits::infinity(); + spec->clamp_max = std::numeric_limits::infinity(); + } else { + spec->clamp_min = std::numeric_limits::lowest(); + spec->clamp_max = std::numeric_limits::max(); + } + return; + } + + spec->clamp_min = std::numeric_limits::lowest() + 1; + spec->clamp_max = std::numeric_limits::max() - 1; +} + +template +void TestSet::MakeZeroPoints() { + RUY_CHECK_EQ(life_stage, LifeStage::kInitial); + if (!benchmark && !use_specified_zero_points) { + MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &lhs_zero_point); + MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &rhs_zero_point); + // If destination is std::int32_t, no dst_zero_point is necessary. 
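+    // (Raw std::int32_t accumulators are returned without requantization or
+    // clamping, so a destination zero_point would never be applied.)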
+ if (std::is_same::value) { + dst_zero_point = 0; + } else { + MakeRandomScalar(RandomRange::kReasonableDstZeroPoint, &dst_zero_point); + } + } + life_stage = LifeStage::kHasZeroPoints; +} + +template +void TestSet::MakeLhsRhs() { + RUY_CHECK_EQ(life_stage, LifeStage::kHasZeroPoints); + MakeRandom(rows, depth, lhs_order, lhs_zero_point, layout_style, + RandomRange::kOffCenterAvoidMinValue, &lhs); + MakeRandom(depth, cols, rhs_order, rhs_zero_point, layout_style, + RandomRange::kGeneral, &rhs); + life_stage = LifeStage::kHasLhsRhs; +} + +template +void TestSet::MakeSpec() { + RUY_CHECK_EQ(life_stage, LifeStage::kHasLhsRhs); + + if (!getenv("BENCHMARK_ONLY_MATMUL") && + (benchmark || (global_random_engine()() & 1))) { + MakeRandomVector(RandomRange::kBias, rows, &bias_data); + spec.bias = bias_data.data(); + } + if (lhs.matrix.zero_point == std::numeric_limits::lowest() && + rhs.matrix.zero_point == std::numeric_limits::lowest()) { + lhs.matrix.zero_point += 1; + } + MakeSpecMultiplierFieldsImpl::Run(this); + MakeSpecClampFields(&spec); + life_stage = LifeStage::kHasSpec; +} + +inline int GetIntEnvVarOrZero(const char* name) { + const char* val = getenv(name); + if (!val) { + return 0; + } + return std::stoi(val); +} + +inline float GetFloatEnvVarOrZero(const char* name) { + const char* val = getenv(name); + if (!val) { + return 0; + } + return std::stof(val); +} + +inline int GetHexIntEnvVarOrZero(const char* name) { + const char* val = getenv(name); + if (!val) { + return 0; + } + return std::stoi(val, nullptr, 16); +} + +inline bool GetBoolEnvVarOrFalse(const char* name) { + return static_cast(GetIntEnvVarOrZero(name)); +} + +template +void TestSet::MakeOtherParams() { + RUY_CHECK_EQ(life_stage, LifeStage::kHasSpec); + if (max_num_threads == 0) { + max_num_threads = GetIntEnvVarOrZero("THREADS"); + } + life_stage = LifeStage::kHasOtherParams; +} + +inline std::vector PathsBitfieldAsVector(Path paths_bitfield) { + std::vector result; + std::uint32_t remaining_paths = static_cast(paths_bitfield); + std::uint32_t test_bit = 1; + while (remaining_paths) { + if (remaining_paths & test_bit) { + result.push_back(static_cast(test_bit)); + } + remaining_paths &= ~test_bit; + test_bit <<= 1; + } + return result; +} + +inline std::vector EnumerateTuningsForPath(Path path, bool benchmark) { + if (benchmark) { + return {Tuning::kAuto}; + } +#if RUY_PLATFORM(ARM) + if (path == Path::kNeon || path == Path::kNeonDotprod) { + return {Tuning::kInOrder, Tuning::kOutOfOrder, Tuning::kAuto}; + } +#endif + return {Tuning::kAuto}; +} + +template +void TestSet::MakePrepackedMatrices() { + RUY_CHECK_EQ(life_stage, LifeStage::kHasResultPaths); + + // Prepacked matrices are Path-dependent, so create them for each test result. + for (auto& result : results) { + // If this result uses an external path, then skip this entirely. + if (result->path == Path::kNone) { + continue; + } + // Pre-packing doesn't make sense for Path::kReference. + // TODO(silvasean): Make Path::kReference an ExternalPath? + if (result->path == Path::kReference) { + continue; + } + + // Determine whether we should create/use prepacked matrices. + if (benchmark) { + // For benchmarking, do as requested. + result->use_prepacked_lhs = benchmark_prepack_lhs; + result->use_prepacked_rhs = benchmark_prepack_rhs; + } else { + // When testing, randomly pre-pack sometimes. But don't do it too often. 
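+      // (global_random_engine()() & 7) == 0 holds with probability 1/8, so
+      // roughly one in eight tested results exercises the pre-packed path.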
+ result->use_prepacked_lhs = (global_random_engine()() & 7) == 0; + result->use_prepacked_rhs = (global_random_engine()() & 7) == 0; + } + + // Create the pre-packed matrices. + PrepackedMatrix* prepacked_lhs_ptr = + result->use_prepacked_lhs ? &result->prepacked_lhs : nullptr; + PrepackedMatrix* prepacked_rhs_ptr = + result->use_prepacked_rhs ? &result->prepacked_rhs : nullptr; + auto alloc_fn = [&result](std::size_t num_bytes) { + return result->allocator.AllocateBytes(num_bytes); + }; + // Use a dst with a null data pointer to check that the pre-packing + // invocation doesn't write into it. + Matrix null_data_dst = result->storage_matrix.matrix; + null_data_dst.data = nullptr; + GlobalContext().SetRuntimeEnabledPaths(result->path); + PrePackForMul(lhs.matrix, rhs.matrix, spec, &GlobalContext(), + &null_data_dst, prepacked_lhs_ptr, + prepacked_rhs_ptr, alloc_fn); + RUY_CHECK_EQ(GlobalContext().last_taken_path, result->path); + } + + life_stage = LifeStage::kHasPrepackedMatrices; +} + +template +void TestSet::MakeResultPaths() { + RUY_CHECK_EQ(life_stage, LifeStage::kHasOtherParams); + + Path paths_bitfield = static_cast(GetHexIntEnvVarOrZero("PATHS")); + + if (paths_bitfield == Path::kNone) { + // Use a dummy Context just to perform the resolution of specific runtime + // enabled paths. + Context context; + paths_bitfield = context.GetRuntimeEnabledPaths(); + } + + // Trim bits that don't correspond to a compiled path, + // to allow specifying e.g. ffff to mean 'all paths' regardless of whether all + // those bits exist as actual paths. + paths_bitfield = paths_bitfield & kAllPaths; + RUY_CHECK_NE(paths_bitfield, Path::kNone); + paths = PathsBitfieldAsVector(paths_bitfield); + +#ifdef RUY_TEST_EXTERNAL_PATHS + + using TestSetType = TestSet; + + if (!GetBoolEnvVarOrFalse("NOEXT")) { + if (SupportsGemmlowp::kValue) { +#ifdef GEMMLOWP_SSE4 + const bool gemmlowp_supported = !spec.multiplier_fixedpoint_perchannel; +#else + const bool gemmlowp_supported = true; +#endif + if (gemmlowp_supported) { + external_paths.push_back(ExternalPath::kGemmlowp); + } + } + if (UsesSingleScalarType::kValue && + std::is_floating_point::value) { + external_paths.push_back(ExternalPath::kEigen); + if (layout_style == LayoutStyle::kPackedLinear) { + external_paths.push_back(ExternalPath::kEigenTensor); + } +// We link against a generic BLAS target that only maps to OpenBLAS on specific +// architectures. +#if RUY_PLATFORM(ARM_32) || RUY_PLATFORM(ARM_64) + // OpenBLAS multi-threading is disabled, so avoid mixing single-threaded + // and multi-threaded benchmark results. 
+ if (max_num_threads == 1 && !getenv("NO_OPENBLAS")) { + external_paths.push_back(ExternalPath::kOpenBlas); + } +#endif + } + } + +#endif // RUY_TEST_EXTERNAL_PATHS + + for (Path path : paths) { + for (Tuning tuning : EnumerateTuningsForPath(path, benchmark)) { + results.emplace_back(new TestResultType); + TestResultType& result = *results.back(); + result.path = path; + result.tuning = tuning; + MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style, + RandomRange::kGeneral, &result.storage_matrix); + } + } + + for (ExternalPath external_path : external_paths) { + results.emplace_back(new TestResultType); + TestResultType& result = *results.back(); + result.external_path = external_path; + MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style, + RandomRange::kGeneral, &result.storage_matrix); + } + + life_stage = LifeStage::kHasResultPaths; +} + +template +void TestSet::EvalResult( + TestResult* result) { + RUY_CHECK(result->path != Path::kNone || + result->external_path != ExternalPath::kNone); + if (result->path != Path::kNone) { + EvalRuy(result); + } else { +#ifdef RUY_TEST_EXTERNAL_PATHS + using TestSetType = TestSet; + EvalExternalPath(this, result); +#endif + } + const std::string& pathname = PathName(*result); + if (std::find(CoveredPaths()->begin(), CoveredPaths()->end(), pathname) == + CoveredPaths()->end()) { + CoveredPaths()->push_back(pathname); + } +} + +using f32 = float; +using f64 = double; +using u8 = std::uint8_t; +using i8 = std::int8_t; +using u16 = std::uint16_t; +using i16 = std::int16_t; +using u32 = std::uint32_t; +using i32 = std::int32_t; +using u64 = std::uint64_t; +using i64 = std::int64_t; + +template +const char* TypeName() { + return nullptr; +} + +#define RUY_TYPENAME(TYPE) \ + template <> \ + const char* TypeName() { \ + return #TYPE; \ + } + +RUY_TYPENAME(f32) +RUY_TYPENAME(f64) +RUY_TYPENAME(u8) +RUY_TYPENAME(i8) +RUY_TYPENAME(u16) +RUY_TYPENAME(i16) +RUY_TYPENAME(u32) +RUY_TYPENAME(i32) +RUY_TYPENAME(u64) +RUY_TYPENAME(i64) + +#undef RUY_TYPENAME + +template +const char* SymmetryName(const Matrix& matrix) { + if (matrix.zero_point == SymmetricZeroPoint()) { + return "symm"; + } else { + return "asymm"; + } +} + +template +int StorageSize(const Matrix& matrix) { + return sizeof(Scalar) * FlatSize(matrix.layout); +} + +// Helper that replicates a buffer and gives out pointers to the replicas. +// This is useful when one wants to traverse data so that it is cold in cache. +// By having a sufficiently large value of num_repeats, one can ensure that the +// working set covered by the replicas is greater than the cache size. 
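+// A hypothetical usage sketch (illustration only; Consume and num_iterations
+// are placeholders, not ruy APIs):
+//   RepeatedBuffer<float> cold_src;
+//   cold_src.Init(src.data(), src.size(), /*num_repeats=*/64);
+//   for (int i = 0; i < num_iterations; i++) {
+//     const float* data = cold_src.Next();  // each call returns a different
+//     Consume(data, src.size());            // replica, likely cold in cache
+//   }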
+template +class RepeatedBuffer { + public: + RepeatedBuffer() = default; + void Init(const T* elems, std::size_t num_elems, int num_repeats) { + buffers_.clear(); + allocator_.FreeAll(); + for (int i = 0; i < num_repeats; i++) { + T* p; + allocator_.Allocate(num_elems, &p); + memcpy(p, elems, num_elems * sizeof(T)); + buffers_.push_back(p); + } + } + T* Next() { + T* ret = buffers_[current_]; + current_ = (current_ + 1) % buffers_.size(); + return ret; + } + + private: + Allocator allocator_; + std::vector buffers_; + int current_ = 0; +}; + +template +void TestSet::Benchmark( + TestResult* result) { + using DstScalar = typename SpecType::DstScalar; + + const bool cold = getenv("RUY_BENCHMARK_COLD"); + LhsScalar* orig_lhs_data = lhs.matrix.data.get(); + RhsScalar* orig_rhs_data = rhs.matrix.data.get(); + DstScalar* orig_dst_data = result->storage_matrix.matrix.data.get(); + void* orig_prepacked_lhs_data = result->prepacked_lhs.data; + void* orig_prepacked_rhs_data = result->prepacked_rhs.data; + + int num_matmul_sets = 0; + + RepeatedBuffer cold_lhs; + RepeatedBuffer cold_rhs; + RepeatedBuffer cold_dst; + RepeatedBuffer cold_prepacked_lhs; + RepeatedBuffer cold_prepacked_rhs; + + if (cold) { + const int kWorkingSetSize = 100 << 20; + const int each_matmul_set_size = StorageSize(lhs.matrix) + + StorageSize(rhs.matrix) + + StorageSize(result->storage_matrix.matrix); + num_matmul_sets = + (kWorkingSetSize + each_matmul_set_size - 1) / each_matmul_set_size; + + cold_lhs.Init(lhs.matrix.data.get(), FlatSize(lhs.matrix.layout), + num_matmul_sets); + cold_rhs.Init(rhs.matrix.data.get(), FlatSize(rhs.matrix.layout), + num_matmul_sets); + cold_dst.Init(result->storage_matrix.matrix.data.get(), + FlatSize(result->storage_matrix.matrix.layout), + num_matmul_sets); + if (benchmark_prepack_lhs) { + cold_prepacked_lhs.Init(static_cast(result->prepacked_lhs.data), + result->prepacked_lhs.data_size, num_matmul_sets); + } + if (benchmark_prepack_rhs) { + cold_prepacked_rhs.Init(static_cast(result->prepacked_rhs.data), + result->prepacked_rhs.data_size, num_matmul_sets); + } + } + const bool record_pmu = GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU"); + int repeats = GetIntEnvVarOrZero("RUY_BENCHMARK_REPEATS"); + if (!repeats) { + repeats = 4; + } + float benchmark_min_secs = GetFloatEnvVarOrZero("RUY_BENCHMARK_MIN_SECS"); + if (!benchmark_min_secs) { + benchmark_min_secs = 0.5; + } +#ifdef RUY_PROFILER + { + const char* lhstype = TypeName(); + const char* lhssymm = SymmetryName(lhs.matrix); + const char* rhstype = TypeName(); + const char* rhssymm = SymmetryName(rhs.matrix); + + printf("Profiling path=%s shape=(%dx%dx%d) lhs=(%s,%s) rhs=(%s,%s)\n", + PathName(*result).c_str(), rows, depth, cols, lhstype, lhssymm, + rhstype, rhssymm); + ruy::profiler::ScopeProfile profile; +#endif + + float latency = std::numeric_limits::infinity(); + float l1_refill_rate = std::numeric_limits::infinity(); + float l2_refill_rate = std::numeric_limits::infinity(); + float l3_refill_rate = std::numeric_limits::infinity(); + float l1tlb_refill_rate = std::numeric_limits::infinity(); + float l2tlb_refill_rate = std::numeric_limits::infinity(); + float mispred_rate = std::numeric_limits::infinity(); + float frontend_stall_rate = std::numeric_limits::infinity(); + float backend_stall_rate = std::numeric_limits::infinity(); + + for (int repeat = 0; repeat < repeats; repeat++) { + auto& pmu_events = GlobalPmuEvents(); + if (record_pmu) { + pmu_events.StartRecording(); + } + TimePoint time_start = Now(); + TimePoint t = time_start; + 
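+    // Timing strategy: run batches of matmuls, doubling the batch size each
+    // time, until at least benchmark_min_secs has elapsed; the latency for
+    // this repeat is elapsed_time / iters. Taking the minimum across the
+    // `repeats` outer iterations reduces the impact of transient
+    // interference.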
int iters = 0; + int iters_at_a_time = 1; + while (ToFloatSeconds(t - time_start) < benchmark_min_secs) { + for (int i = 0; i < iters_at_a_time; i++) { + if (cold) { + lhs.matrix.data = cold_lhs.Next(); + rhs.matrix.data = cold_rhs.Next(); + result->storage_matrix.matrix.data = cold_dst.Next(); + if (benchmark_prepack_lhs) { + result->prepacked_lhs.data = cold_prepacked_lhs.Next(); + } + if (benchmark_prepack_rhs) { + result->prepacked_rhs.data = cold_prepacked_rhs.Next(); + } + } + EvalResult(result); + iters++; + } + iters_at_a_time *= 2; + t = Now(); + } + latency = std::min( + latency, static_cast(ToFloatSeconds(t - time_start) / iters)); + if (record_pmu) { + pmu_events.StopRecording(); + const float normalization_factor = + 1.0f / (static_cast(iters) * rows * cols * depth); + l1_refill_rate = std::min( + l1_refill_rate, pmu_events.L1RefillCount() * normalization_factor); + l2_refill_rate = std::min( + l2_refill_rate, pmu_events.L2RefillCount() * normalization_factor); + l3_refill_rate = std::min( + l3_refill_rate, pmu_events.L3RefillCount() * normalization_factor); + l1tlb_refill_rate = + std::min(l1tlb_refill_rate, + pmu_events.L1TLBRefillCount() * normalization_factor); + l2tlb_refill_rate = + std::min(l2tlb_refill_rate, + pmu_events.L2TLBRefillCount() * normalization_factor); + mispred_rate = + std::min(mispred_rate, pmu_events.BranchMispredictionCount() * + normalization_factor); + frontend_stall_rate = + std::min(frontend_stall_rate, + pmu_events.FrontendStallCount() * normalization_factor); + backend_stall_rate = + std::min(backend_stall_rate, + pmu_events.BackendStallCount() * normalization_factor); + } + } + result->latency = latency; + if (record_pmu) { + result->l1_refill_rate = l1_refill_rate; + result->l2_refill_rate = l2_refill_rate; + result->l3_refill_rate = l3_refill_rate; + result->l1tlb_refill_rate = l1tlb_refill_rate; + result->l2tlb_refill_rate = l2tlb_refill_rate; + result->mispred_rate = mispred_rate; + result->frontend_stall_rate = frontend_stall_rate; + result->backend_stall_rate = backend_stall_rate; + } + +#ifdef RUY_PROFILER + } + fflush(stdout); +#endif + + if (cold) { + lhs.matrix.data = orig_lhs_data; + rhs.matrix.data = orig_rhs_data; + memcpy(orig_dst_data, result->storage_matrix.matrix.data.get(), + StorageSize(result->storage_matrix.matrix)); + result->storage_matrix.matrix.data = orig_dst_data; + result->prepacked_lhs.data = orig_prepacked_lhs_data; + result->prepacked_rhs.data = orig_prepacked_rhs_data; + } +} + +template +void TestSet::Eval() { + RUY_CHECK_EQ(life_stage, LifeStage::kHasPrepackedMatrices); + for (auto& result : results) { + if (benchmark) { + Benchmark(result.get()); + } else { + EvalResult(result.get()); + } + } + life_stage = LifeStage::kEvaluated; +} + +template +std::string DumpRegion(const Matrix& matrix, int center_row, + int center_col) { + static constexpr int kRadius = 20; + int first_row = std::max(0, center_row - kRadius); + int last_row = std::min(matrix.layout.rows - 1, center_row + kRadius); + int first_col = std::max(0, center_col - kRadius); + int last_col = std::min(matrix.layout.cols - 1, center_col + kRadius); + std::ostringstream stream; + for (int row = first_row; row <= last_row; row++) { + for (int col = first_col; col <= last_col; col++) { + stream << static_cast(Element(matrix, row, col)) << " "; + } + stream << "\n"; + } + return stream.str(); +} + +template +void TestSet::VerifyTestResults() const { + const int depth = lhs.matrix.layout.cols; + for (int i = 0; i < results.size() - 1; i++) { + if 
(!Agree(*results[i], *results[i + 1], depth)) { + std::string paths_in_agreement; + paths_in_agreement.append(PathName(*results[0])); + for (int j = 1; j <= i; j++) { + paths_in_agreement.append(", "); + paths_in_agreement.append(PathName(*results[j])); + } + ErrorAnalysis error_analysis; + AnalyzeTestError(*this, i + 1, &error_analysis); + std::cerr << "Error: path (" << PathName(*results[i + 1]) + << ") disagrees with the other paths (" << paths_in_agreement + << "), which agree with each other." << std::endl; + std::cerr << "Shape: rows = " << rows << ", cols = " << cols + << ", depth = " << depth << std::endl; + std::cerr << "Stats of the good result matrix: " + << StatsAsString(error_analysis.stats_good) << std::endl; + std::cerr << "Stats of the bad result matrix: " + << StatsAsString(error_analysis.stats_bad) << std::endl; + if (error_analysis.error_rows.size() < rows) { + std::cerr << "Rows containing errors: " + << Join(error_analysis.error_rows) << std::endl; + } else { + std::cerr << "Errors found in ALL rows." << std::endl; + } + if (error_analysis.error_cols.size() < cols) { + std::cerr << "Cols containing errors: " + << Join(error_analysis.error_cols) << std::endl; + } else { + std::cerr << "Errors found in ALL cols." << std::endl; + } + std::cerr << "The first error occurs at row " + << error_analysis.row_of_first_error << ", col " + << error_analysis.col_of_first_error << std::endl; + std::cerr << "Good value: " << error_analysis.first_error_good_value + << std::endl; + std::cerr << "Bad value : " << error_analysis.first_error_bad_value + << std::endl; + std::cerr << "Region of Good result matrix around first error:\n\n" + << DumpRegion(results[0]->storage_matrix.matrix, + error_analysis.row_of_first_error, + error_analysis.col_of_first_error) + << std::endl; + std::cerr << "Region of Bad result matrix around first error:\n\n" + << DumpRegion(results[i + 1]->storage_matrix.matrix, + error_analysis.row_of_first_error, + error_analysis.col_of_first_error) + << std::endl; + RUY_CHECK(false); + } + } +} + +template +void TestSet::Verify() { + RUY_CHECK_EQ(life_stage, LifeStage::kEvaluated); + if (expected_outcome == ExpectedOutcome::kSuccess) { + VerifyTestResults(); + } + life_stage = LifeStage::kFinal; +} + +template +void TestRCC(int rows, int depth, int cols, ExpectedOutcome expected_outcome) { + TestSetType test_set; + test_set.rows = rows; + test_set.depth = depth; + test_set.cols = cols; + test_set.lhs_order = Order::kRowMajor; + test_set.rhs_order = Order::kColMajor; + test_set.dst_order = Order::kColMajor; + test_set.layout_style = LayoutStyle::kPackedLinear; + test_set.expected_outcome = expected_outcome; + test_set.Run(); +} + +template +void TestRCC(int rows, int depth, int cols) { + TestRCC(rows, depth, cols, ExpectedOutcome::kSuccess); +} + +template +void TestNonRCC(int rows, int depth, int cols, + ExpectedOutcome expected_outcome) { + TestSetType test_set; + test_set.rows = rows; + test_set.depth = depth; + test_set.cols = cols; + test_set.lhs_order = Order::kColMajor; + test_set.rhs_order = Order::kColMajor; + test_set.dst_order = Order::kColMajor; + test_set.layout_style = LayoutStyle::kPackedLinear; + test_set.expected_outcome = expected_outcome; + test_set.Run(); +} + +template +void TestLinearAllOrders(int rows, int depth, int cols, + ExpectedOutcome expected_outcome) { + const std::vector orders{Order::kColMajor, Order::kRowMajor}; + + for (Order lhs_order : orders) { + for (Order rhs_order : orders) { + for (Order dst_order : orders) { + TestSetType 
test_set; + test_set.rows = rows; + test_set.depth = depth; + test_set.cols = cols; + test_set.lhs_order = lhs_order; + test_set.rhs_order = rhs_order; + test_set.dst_order = dst_order; + test_set.layout_style = LayoutStyle::kLinear; + test_set.expected_outcome = expected_outcome; + test_set.Run(); + } + } + } +} + +template +void TestLinearAllOrders(int rows, int depth, int cols) { + TestLinearAllOrders(rows, depth, cols, + ExpectedOutcome::kSuccess); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_ diff --git a/ruy/test_fast.cc b/ruy/test_fast.cc new file mode 100644 index 0000000..d1c1308 --- /dev/null +++ b/ruy/test_fast.cc @@ -0,0 +1,110 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This test contains cheap test cases, completes in a few seconds. + +#include + +#include "ruy/test.h" + +namespace ruy { + +using LhsScalar = RUY_TEST_LHSSCALAR; +using RhsScalar = RUY_TEST_RHSSCALAR; +using AccumScalar = RUY_TEST_ACCUMSCALAR; +using DstScalar = RUY_TEST_DSTSCALAR; + +using TestSetType = + TestSet>; + +TEST(RuyTest, TestSquareMuls) { + const std::vector sizes{ + // small sizes + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + // multiplies of 16 + 16, + 32, + 48, + 64, + // pot-minus-1 sizes + 15, + 31, + 63, + // pot-plus-1 sizes + 17, + 33, + 65, + }; + + for (int size : sizes) { + TestRCC(size, size, size); + TestLinearAllOrders(size, size, size); + } +} + +TEST(RuyTest, TestMiscMuls) { + const int shapes[][3] = { + {2, 3, 4}, {7, 6, 5}, {12, 23, 6}, {19, 3, 11}, {3, 10, 17}, + {30, 21, 43}, {7, 57, 9}, {49, 69, 71}, {38, 111, 29}, {87, 98, 76}, + {16, 96, 16}, {16, 88, 16}, {16, 84, 16}, {16, 92, 16}, {16, 82, 16}, + {16, 81, 16}, {16, 95, 16}, {3, 128, 5}}; + for (const auto& shape : shapes) { + TestLinearAllOrders(shape[0], shape[1], shape[2]); + } +} + +TEST(RuyTest, TestDeepMuls) { + // TODO(b/137649322): clarify what's the max allowed matrix size. + TestRCC(1, 32767, 1); + TestLinearAllOrders(5, 5001, 4); + TestLinearAllOrders(9, 1025, 10); +} + +TEST(RuyTest, TestShallowMuls) { + TestLinearAllOrders(101, 1, 103); + TestLinearAllOrders(71, 2, 53); + TestLinearAllOrders(51, 3, 73); + TestLinearAllOrders(51, 4, 43); +} + +TEST(RuyTest, TestNarrowMuls) { + for (int width : {1, 2, 3, 4, 5, 8}) { + TestLinearAllOrders(width, 12, 13); + TestLinearAllOrders(15, 19, width); + TestLinearAllOrders(width, 123, 137); + TestLinearAllOrders(158, 119, width); + } +} + +TEST(RuyTest, TestGEMV) { + for (int size = 1; size < 1024; size *= 2) { + for (int depth = 1; depth < 500; depth += 47) { + TestLinearAllOrders(size, depth, 1); + } + } + TestLinearAllOrders(5, 5001, 1); + TestLinearAllOrders(8193, 17, 1); +} + +} // namespace ruy diff --git a/ruy/test_slow.cc b/ruy/test_slow.cc new file mode 100644 index 0000000..9f0f218 --- /dev/null +++ b/ruy/test_slow.cc @@ -0,0 +1,71 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This test contains more expensive test cases. + +#include "ruy/test.h" + +namespace ruy { + +using LhsScalar = RUY_TEST_LHSSCALAR; +using RhsScalar = RUY_TEST_RHSSCALAR; +using AccumScalar = RUY_TEST_ACCUMSCALAR; +using DstScalar = RUY_TEST_DSTSCALAR; + +using TestSetType = + TestSet>; + +TEST(RuyTest, TestBigNarrowMuls) { + for (int width : {1, 2, 3, 4, 5, 8}) { + TestRCC(width, 401, 601); + TestRCC(587, 443, width); + } + TestRCC(7, 45984, + 5); // Large enough to trigger row-sum overflows. + TestRCC(512, 256, 16); +} + +TEST(RuyTest, TestBigShallowMuls) { + TestLinearAllOrders(501, 1, 321); + TestLinearAllOrders(301, 5, 403); + TestLinearAllOrders(256, 32, 512); +} + +TEST(RuyTest, TestBigMuls) { + TestRCC(225, 303, 199); + TestLinearAllOrders(256, 192, 128); +} + +TEST(RuyTest, TestBigPowerOfTwoDepthWithAvoidAliasing) { + // Important to test some power-of-two depths: that's when the + // RUY_AVOID_ALIASING optimization kicks in and makes packed matrices + // strided, exposing bugs in kernels mixing up size and stride. + // Moreover, it's important that the test matrices be sufficiently wide + // that they will result in multiple blocks, exposing bugs in the + // computation of the base address of each block. + TestLinearAllOrders(70, 1024, 80); + TestLinearAllOrders(60, 2048, 70); + TestLinearAllOrders(40, 4096, 50); +} + +TEST(RuyTest, TestGEMV) { + for (int size = 1025; size <= 1409; size += 384) { + for (int depth = 350; depth < 500; depth += 47) { + TestLinearAllOrders(size, depth, 1); + } + } +} + +} // namespace ruy diff --git a/ruy/test_special_specs.cc b/ruy/test_special_specs.cc new file mode 100644 index 0000000..a621d0e --- /dev/null +++ b/ruy/test_special_specs.cc @@ -0,0 +1,163 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This test covers non-basic specs. 
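+// "Non-basic" means Spec classes overriding BasicSpec defaults: loop
+// structure (kLoopStructure), zero-point support (kZeroPointSupport),
+// layout support (kLayoutSupport = kRCC), and the standard C++ kernel
+// layout / cache-size hooks, each exercised by a test below.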
+ +#include "ruy/test.h" + +namespace ruy { + +template +struct LoopStructureSpec : BasicSpec { + static constexpr LoopStructure kLoopStructure = tLoopStructure; +}; + +template +struct ZeroPointSupportSpec : BasicSpec { + static constexpr ZeroPointSupport kZeroPointSupport = tZeroPointSupport; +}; + +template +struct RCCSpec : BasicSpec { + static constexpr LayoutSupport kLayoutSupport = LayoutSupport::kRCC; +}; + +template +struct StandardCppKernelLayoutSpec : BasicSpec { + using StandardCppKernelLhsLayout = LhsKernelLayout; + using StandardCppKernelRhsLayout = RhsKernelLayout; + static int local_data_cache_size() { return 1; } + static int shared_data_cache_size() { return 1; } +}; + +using LhsScalar = RUY_TEST_LHSSCALAR; +using RhsScalar = RUY_TEST_RHSSCALAR; +using AccumScalar = RUY_TEST_ACCUMSCALAR; +using DstScalar = RUY_TEST_DSTSCALAR; + +template +void TestLoopStructure() { + using SpecType = LoopStructureSpec; + using TestSetType = TestSet; + for (int size = 1; size < 10; size++) { + TestLinearAllOrders(size, size, size); + } + TestLinearAllOrders(3, 5, 78); + TestLinearAllOrders(19, 91, 7); + TestLinearAllOrders(71, 26, 44); + TestLinearAllOrders(81, 93, 72); +} + +TEST(TestSpecialSpecs, LoopStructure) { + static_assert(BasicSpec::kLoopStructure == + LoopStructure::kAuto, + ""); + static_assert(BasicSpec::kLoopStructure == LoopStructure::kAuto, + ""); + TestLoopStructure(); + TestLoopStructure(); +} + +template +void TestZeroPointSupport(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point, + DstScalar dst_zero_point, + ExpectedOutcome expected_outcome) { + using SpecType = + ZeroPointSupportSpec; + using TestSetType = TestSet; + TestSetType test_set; + test_set.rows = 11; + test_set.depth = 12; + test_set.cols = 13; + test_set.lhs_order = Order::kRowMajor; + test_set.rhs_order = Order::kColMajor; + test_set.dst_order = Order::kColMajor; + test_set.layout_style = LayoutStyle::kPackedLinear; + test_set.expected_outcome = expected_outcome; + test_set.lhs_zero_point = lhs_zero_point; + test_set.rhs_zero_point = rhs_zero_point; + test_set.dst_zero_point = dst_zero_point; + test_set.use_specified_zero_points = true; + test_set.Run(); +} + +TEST(TestSpecialSpecs, ZeroPointSupport) { + // Sanity check + RUY_CHECK_EQ(SymmetricZeroPoint(), 128); + RUY_CHECK_EQ(SymmetricZeroPoint(), 0); + + if (std::is_floating_point::value) { + return; + } + + TestZeroPointSupport( + SymmetricZeroPoint(), SymmetricZeroPoint(), + SymmetricZeroPoint(), ExpectedOutcome::kSuccess); + TestZeroPointSupport( + SymmetricZeroPoint() - 1, SymmetricZeroPoint(), + SymmetricZeroPoint(), ExpectedOutcome::kSuccess); + TestZeroPointSupport( + SymmetricZeroPoint(), SymmetricZeroPoint(), + SymmetricZeroPoint(), ExpectedOutcome::kSuccess); + TestZeroPointSupport( + SymmetricZeroPoint() + 1, SymmetricZeroPoint(), + SymmetricZeroPoint(), ExpectedOutcome::kDeath); + TestZeroPointSupport( + SymmetricZeroPoint(), SymmetricZeroPoint() + 1, + SymmetricZeroPoint(), ExpectedOutcome::kDeath); + TestZeroPointSupport( + SymmetricZeroPoint(), SymmetricZeroPoint(), + SymmetricZeroPoint() - 1, ExpectedOutcome::kDeath); +} + +TEST(TestSpecialSpecs, RCC) { + using RCCSpec = RCCSpec; + using RCCTestSet = TestSet; + TestRCC(81, 93, 72); + TestNonRCC(81, 93, 72, ExpectedOutcome::kDeath); +} + +template +void TestStandardCppKernelLayout() { + using SpecType = + StandardCppKernelLayoutSpec; + using TestSetType = TestSet; + for (int size = 1; size < 10; size++) { + TestLinearAllOrders(size, size, size); + } + TestLinearAllOrders(87, 34, 56); + 
TestLinearAllOrders(123, 234, 78); +} + +TEST(TestSpecialSpecs, StandardCppKernelLayoutTrivial1x1) { + TestStandardCppKernelLayout, + FixedKernelLayout>(); +} + +TEST(TestSpecialSpecs, StandardCppKernelLayoutSquare4x4) { + TestStandardCppKernelLayout, + FixedKernelLayout>(); +} + +TEST(TestSpecialSpecs, StandardCppKernelLayoutRectangular4x8) { + TestStandardCppKernelLayout, + FixedKernelLayout>(); +} + +} // namespace ruy diff --git a/ruy/thread_pool.cc b/ruy/thread_pool.cc new file mode 100644 index 0000000..d09bf1e --- /dev/null +++ b/ruy/thread_pool.cc @@ -0,0 +1,200 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/thread_pool.h" + +#include +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) +#include +#include +#include +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) + +#include "ruy/check_macros.h" +#include "ruy/wait.h" + +namespace ruy { + +// A worker thread. +class Thread { + public: + enum class State { + Startup, // The initial state before the thread main loop runs. + Ready, // Is not working, has not yet received new work to do. + HasWork, // Has work to do. + ExitAsSoonAsPossible // Should exit at earliest convenience. + }; + + explicit Thread(BlockingCounter* counter_to_decrement_when_ready) + : task_(nullptr), + state_(State::Startup), + counter_to_decrement_when_ready_(counter_to_decrement_when_ready) { + thread_.reset(new std::thread(ThreadFunc, this)); + } + + ~Thread() { + ChangeState(State::ExitAsSoonAsPossible); + thread_->join(); + } + + // Changes State; may be called from either the worker thread + // or the master thread; however, not all state transitions are legal, + // which is guarded by assertions. + // + // The Task argument is to be used only with new_state==HasWork. + // It specifies the Task being handed to this Thread. + void ChangeState(State new_state, Task* task = nullptr) { + state_mutex_.lock(); + State old_state = state_.load(std::memory_order_relaxed); + RUY_DCHECK_NE(old_state, new_state); + switch (old_state) { + case State::Startup: + RUY_DCHECK_EQ(new_state, State::Ready); + break; + case State::Ready: + RUY_DCHECK(new_state == State::HasWork || + new_state == State::ExitAsSoonAsPossible); + break; + case State::HasWork: + RUY_DCHECK(new_state == State::Ready || + new_state == State::ExitAsSoonAsPossible); + break; + default: + abort(); + } + switch (new_state) { + case State::Ready: + if (task_) { + // Doing work is part of reverting to 'ready' state. 
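+        // The new Ready state only becomes observable to other threads after
+        // the task has run: state_ is stored and state_cond_ notified further
+        // down, and the master's BlockingCounter is decremented only after
+        // that (outside the mutex), so ThreadPool::ExecuteImpl cannot return
+        // before this task has completed.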
+ task_->Run(); + task_ = nullptr; + } + break; + case State::HasWork: + RUY_DCHECK(!task_); + task_ = task; + break; + default: + break; + } + state_.store(new_state, std::memory_order_relaxed); + state_cond_.notify_all(); + state_mutex_.unlock(); + if (new_state == State::Ready) { + counter_to_decrement_when_ready_->DecrementCount(); + } + } + + static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); } + + // Called by the master thead to give this thread work to do. + void StartWork(Task* task) { ChangeState(State::HasWork, task); } + + private: + // Thread entry point. + void ThreadFuncImpl() { + ChangeState(State::Ready); + + // Thread main loop + while (true) { + // In the 'Ready' state, we have nothing to do but to wait until + // we switch to another state. + const auto& condition = [this]() { + return state_.load(std::memory_order_acquire) != State::Ready; + }; + Wait(condition, &state_cond_, &state_mutex_); + + // Act on new state. + switch (state_.load(std::memory_order_acquire)) { + case State::HasWork: + // Got work to do! So do it, and then revert to 'Ready' state. + ChangeState(State::Ready); + break; + case State::ExitAsSoonAsPossible: + return; + default: + abort(); + } + } + } + + // The underlying thread. + std::unique_ptr thread_; + + // The task to be worked on. + Task* task_; + + // The condition variable and mutex guarding state changes. + std::condition_variable state_cond_; + std::mutex state_mutex_; + + // The state enum tells if we're currently working, waiting for work, etc. + // Its concurrent accesses by the thread and main threads are guarded by + // state_mutex_, and can thus use memory_order_relaxed. This still needs + // to be a std::atomic because we use WaitForVariableChange. + std::atomic state_; + + // pointer to the master's thread BlockingCounter object, to notify the + // master thread of when this thread switches to the 'Ready' state. + BlockingCounter* const counter_to_decrement_when_ready_; +}; + +void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) { + RUY_DCHECK_GE(task_count, 1); + + // Case of 1 thread: just run the single task on the current thread. + if (task_count == 1) { + (tasks + 0)->Run(); + return; + } + + // Task #0 will be run on the current thread. + CreateThreads(task_count - 1); + counter_to_decrement_when_ready_.Reset(task_count - 1); + for (int i = 1; i < task_count; i++) { + auto task_address = reinterpret_cast(tasks) + i * stride; + threads_[i - 1]->StartWork(reinterpret_cast(task_address)); + } + + // Execute task #0 immediately on the current thread. + (tasks + 0)->Run(); + + // Wait for the threads submitted above to finish. + counter_to_decrement_when_ready_.Wait(); +} + +// Ensures that the pool has at least the given count of threads. +// If any new thread has to be created, this function waits for it to +// be ready. +void ThreadPool::CreateThreads(int threads_count) { + if (threads_.size() >= threads_count) { + return; + } + counter_to_decrement_when_ready_.Reset(threads_count - threads_.size()); + while (threads_.size() < threads_count) { + threads_.push_back(new Thread(&counter_to_decrement_when_ready_)); + } + counter_to_decrement_when_ready_.Wait(); +} + +ThreadPool::~ThreadPool() { + for (auto w : threads_) { + delete w; + } +} + +} // end namespace ruy diff --git a/ruy/thread_pool.h b/ruy/thread_pool.h new file mode 100644 index 0000000..04c201c --- /dev/null +++ b/ruy/thread_pool.h @@ -0,0 +1,102 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file is a fork of gemmlowp's multi_thread_gemm.h, under Apache 2.0 +// license. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ + +#include + +#include "ruy/blocking_counter.h" + +namespace ruy { + +// A workload for a thread. +struct Task { + virtual ~Task() {} + virtual void Run() = 0; +}; + +class Thread; + +// A simple pool of threads, that only allows the very +// specific parallelization pattern that we use here: +// One thread, which we call the 'main thread', calls Execute, distributing +// a Task each to N threads, being N-1 'worker threads' and the main thread +// itself. After the main thread has completed its own Task, it waits for +// the worker threads to have all completed. That is the only synchronization +// performed by this ThreadPool. +// +// In particular, there is a naive 1:1 mapping of Tasks to threads. +// This ThreadPool considers it outside of its own scope to try to work +// with fewer threads than there are Tasks. The idea is that such N:M mappings +// of tasks to threads can be implemented as a higher-level feature on top of +// the present low-level 1:1 threadpool. For example, a user might have a +// Task subclass referencing a shared atomic counter indexing into a vector of +// finer-granularity subtasks. Different threads would then concurrently +// increment this atomic counter, getting each their own subtasks to work on. +// That approach is the one used in ruy's multi-thread matrix multiplication +// implementation --- see ruy's TrMulTask. +class ThreadPool { + public: + ThreadPool() {} + + ~ThreadPool(); + + // Executes task_count tasks on task_count threads. + // Grows the threadpool as needed to have at least (task_count-1) threads. + // The 0-th task is run on the thread on which Execute is called: that + // is by definition what we call the "main thread". Synchronization of all + // threads is performed before this function returns. + // + // As explained in the class comment, there is a 1:1 mapping of tasks to + // threads. If you need something smarter than that, for instance if you + // want to run an unbounded number of tasks on a bounded number of threads, + // then you need something higher-level than this ThreadPool, that can + // be layered on top of it by appropriately subclassing Tasks. + // + // TaskType must be a subclass of ruy::Task. That is implicitly guarded by + // the static_cast in this inline implementation. + template + void Execute(int task_count, TaskType* tasks) { + ExecuteImpl(task_count, sizeof(TaskType), static_cast(tasks)); + } + + private: + // Ensures that the pool has at least the given count of threads. + // If any new thread has to be created, this function waits for it to + // be ready. + void CreateThreads(int threads_count); + + // Non-templatized implementation of the public Execute method. 
+ // See the inline implementation of Execute for how this is used. + void ExecuteImpl(int task_count, int stride, Task* tasks); + + // copy construction disallowed + ThreadPool(const ThreadPool&) = delete; + + // The threads in this pool. They are owned by the pool: + // the pool creates threads and destroys them in its destructor. + std::vector threads_; + + // The BlockingCounter used to wait for the threads. + BlockingCounter counter_to_decrement_when_ready_; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ diff --git a/ruy/time.h b/ruy/time.h new file mode 100644 index 0000000..9dba75e --- /dev/null +++ b/ruy/time.h @@ -0,0 +1,81 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_ + +#include // NOLINT(build/c++11) +#include // IWYU pragma: keep +#include // NOLINT(build/c++11) + +#ifdef __linux__ +#include +// IWYU pragma: no_include + +#include +#endif + +namespace ruy { + +using InternalDefaultClock = std::chrono::steady_clock; + +using TimePoint = InternalDefaultClock::time_point; +using Duration = InternalDefaultClock::duration; + +template +Duration DurationFromSeconds(RepresentationType representation) { + return std::chrono::duration_cast( + std::chrono::duration(representation)); +} + +template +Duration DurationFromMilliseconds(RepresentationType representation) { + return std::chrono::duration_cast( + std::chrono::duration(representation)); +} + +template +Duration DurationFromNanoseconds(RepresentationType representation) { + return std::chrono::duration_cast( + std::chrono::duration(representation)); +} + +inline float ToFloatSeconds(const Duration& duration) { + return std::chrono::duration_cast>(duration) + .count(); +} + +inline std::int64_t ToInt64Nanoseconds(const Duration& duration) { + return std::chrono::duration_cast< + std::chrono::duration>(duration) + .count(); +} + +inline TimePoint Now() { return InternalDefaultClock::now(); } + +inline TimePoint CoarseNow() { +#ifdef __linux__ + timespec t; + clock_gettime(CLOCK_MONOTONIC_COARSE, &t); + return TimePoint( + DurationFromNanoseconds(1000000000LL * t.tv_sec + t.tv_nsec)); +#else + return Now(); +#endif +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_ diff --git a/ruy/trace.cc b/ruy/trace.cc new file mode 100644 index 0000000..1822cdb --- /dev/null +++ b/ruy/trace.cc @@ -0,0 +1,325 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/trace.h" + +#include +#include // IWYU pragma: keep +#include +#include +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/side_pair.h" +#include "ruy/time.h" + +namespace ruy { + +#ifdef RUY_TRACE + +enum class TraceEvent : std::uint8_t { + kNone, + kThreadStart, + kThreadLoopStart, + kThreadEnd, + kBlockReserved, + kBlockPackedLhs, + kBlockPackedRhs, + kBlockFinished +}; + +struct TraceEntry { + TimePoint time_point; + TraceEvent event; + // ruy-internal thread id i.e. contiguous index into array of threads, + // with 0 designating the main thread. + std::uint16_t thread_id = 0; + // Additional parameters whose meaning depends on the 'event' type. + std::uint32_t params[1]; +}; + +struct Trace { + BlockMap block_map; + // During recording, to avoid having to use locks or atomics, we let + // each thread append to its own specific vector. + std::vector> thread_specific_entries; + // Global vector of entries into which we coalesce thread_specific_entries + // after recording is finished, when dumping a trace. See + // AggregateThreadSpecificEntries. + std::vector entries; + TimePoint time_start; + TimePoint time_execute; + TimePoint time_end; +}; + +namespace { + +// Coalesce Trace::thread_specific_entries into Trace::entries. +void AggregateThreadSpecificEntries(Trace* trace) { + RUY_CHECK(trace->entries.empty()); + for (auto& thread_specific_entries_vector : trace->thread_specific_entries) { + for (const TraceEntry& entry : thread_specific_entries_vector) { + trace->entries.push_back(entry); + } + thread_specific_entries_vector.clear(); + } +} + +// Sort Trace::entries by ascending time. In case of equal timepoints, +// sort by some semi-arbitrary ordering of event types. +void Sort(Trace* trace) { + std::sort(std::begin(trace->entries), std::end(trace->entries), + [](const TraceEntry& a, const TraceEntry& b) -> bool { + return a.time_point < b.time_point || + (a.time_point == b.time_point && + static_cast(a.event) < static_cast(b.event)); + }); +} + +// Dump a trace. Assumes that AggregateThreadSpecificEntries and Sort have +// already been called on it. +// +// On some architectures long long ints are not same as std::int64_t, and +// time is printed as %lld, so static_casts are necessary. +void Dump(const Trace& trace) { + const char* trace_filename = getenv("RUY_TRACE_FILE"); + FILE* trace_file = trace_filename ? 
fopen(trace_filename, "w") : stderr;
+  if (!trace_file) {
+    fprintf(stderr, "Failed to open %s for write, errno=%d\n", trace_filename,
+            errno);
+    RUY_CHECK(false);
+  }
+  fprintf(trace_file, "thread_count:%d\n", trace.block_map.thread_count);
+  fprintf(trace_file, "rows:%d\n", trace.block_map.dims[Side::kLhs]);
+  fprintf(trace_file, "cols:%d\n", trace.block_map.dims[Side::kRhs]);
+  fprintf(trace_file, "Execute: %lld\n",
+          static_cast<long long int>(
+              ToInt64Nanoseconds(trace.time_execute - trace.time_start)));
+  for (const TraceEntry& entry : trace.entries) {
+    long long int time = static_cast<long long int>(
+        ToInt64Nanoseconds(entry.time_point - trace.time_start));
+    switch (entry.event) {
+      case TraceEvent::kThreadStart:
+        fprintf(trace_file, "ThreadStart: %lld, %d\n", time, entry.thread_id);
+        break;
+      case TraceEvent::kThreadLoopStart:
+        fprintf(trace_file, "ThreadLoopStart: %lld, %d\n", time,
+                entry.thread_id);
+        break;
+      case TraceEvent::kThreadEnd:
+        fprintf(trace_file, "ThreadEnd: %lld, %d\n", time, entry.thread_id);
+        break;
+      case TraceEvent::kBlockReserved: {
+        std::uint32_t block_id = entry.params[0];
+        SidePair<int> block;
+        GetBlockByIndex(trace.block_map, block_id, &block);
+        SidePair<int> start, end;
+        GetBlockMatrixCoords(trace.block_map, block, &start, &end);
+        fprintf(trace_file,
+                "BlockReserved: %lld, %d, %d, %d, %d, %d, %d, %d, %d\n", time,
+                entry.thread_id, block_id, block[Side::kLhs], block[Side::kRhs],
+                start[Side::kLhs], start[Side::kRhs], end[Side::kLhs],
+                end[Side::kRhs]);
+        break;
+      }
+      case TraceEvent::kBlockPackedLhs: {
+        std::uint32_t block = entry.params[0];
+        int start, end;
+        GetBlockMatrixCoords(Side::kLhs, trace.block_map, block, &start, &end);
+        fprintf(trace_file, "BlockPackedLhs: %lld, %d, %d, %d, %d\n", time,
+                entry.thread_id, block, start, end);
+        break;
+      }
+      case TraceEvent::kBlockPackedRhs: {
+        std::uint32_t block = entry.params[0];
+        int start, end;
+        GetBlockMatrixCoords(Side::kRhs, trace.block_map, block, &start, &end);
+        fprintf(trace_file, "BlockPackedRhs: %lld, %d, %d, %d, %d\n", time,
+                entry.thread_id, block, start, end);
+        break;
+      }
+      case TraceEvent::kBlockFinished: {
+        std::uint32_t block_id = entry.params[0];
+        SidePair<int> block;
+        GetBlockByIndex(trace.block_map, block_id, &block);
+        fprintf(trace_file, "BlockFinished: %lld, %d, %d, %d, %d\n", time,
+                entry.thread_id, block_id, block[Side::kLhs],
+                block[Side::kRhs]);
+        break;
+      }
+      default:
+        RUY_CHECK(false);
+    }
+  }
+  fprintf(trace_file, "End: %lld\n",
+          static_cast<long long int>(
+              ToInt64Nanoseconds(trace.time_end - trace.time_start)));
+  if (trace_filename) {
+    fclose(trace_file);
+  }
+}
+
+}  // anonymous namespace
+
+// Get a Trace object to record to, or null if tracing is not enabled.
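+// Tracing is enabled by defining the RUY_TRACE environment variable, and can
+// be restricted to matmuls of a specific shape with RUY_TRACE_FILTER_ROWS,
+// RUY_TRACE_FILTER_DEPTH and RUY_TRACE_FILTER_COLS; the resulting dump is
+// written to the file named by RUY_TRACE_FILE, defaulting to stderr (see Dump
+// above).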
+Trace* NewTraceOrNull(TracingContext* tracing, int rows, int depth, int cols) { + if (!tracing->initialized) { + tracing->initialized = true; + tracing->enabled = getenv("RUY_TRACE"); + if (!tracing->enabled) { + return nullptr; + } + if (getenv("RUY_TRACE_FILTER_ROWS")) { + tracing->filter_shape_rows = std::stoi(getenv("RUY_TRACE_FILTER_ROWS")); + } + if (getenv("RUY_TRACE_FILTER_DEPTH")) { + tracing->filter_shape_depth = std::stoi(getenv("RUY_TRACE_FILTER_DEPTH")); + } + if (getenv("RUY_TRACE_FILTER_COLS")) { + tracing->filter_shape_cols = std::stoi(getenv("RUY_TRACE_FILTER_COLS")); + } + } + if (!tracing->enabled) { + return nullptr; + } + if (tracing->filter_shape_rows && rows != tracing->filter_shape_rows) { + return nullptr; + } + if (tracing->filter_shape_depth && depth != tracing->filter_shape_depth) { + return nullptr; + } + if (tracing->filter_shape_cols && cols != tracing->filter_shape_cols) { + return nullptr; + } + // Delete any existing trace. + delete tracing->trace; + // Create a new one. + tracing->trace = new Trace; + return tracing->trace; +} + +// The trace recorded on a context is finalized and dumped by +// this TracingContext destructor. +// +// The idea of dumping on context destructor is that typically one wants to +// run many matrix multiplications, e.g. to hit a steady state in terms of +// performance characteristics, but only trace the last repetition of the +// workload, when that steady state was attained. +TracingContext::~TracingContext() { + if (trace) { + AggregateThreadSpecificEntries(trace); + Sort(trace); + Dump(*trace); + } + delete trace; +} + +void TraceRecordStart(Trace* trace) { + if (trace) { + trace->time_start = Now(); + } +} + +void TraceRecordExecute(const BlockMap& block_map, Trace* trace) { + if (trace) { + trace->time_execute = Now(); + trace->block_map = block_map; + trace->thread_specific_entries.resize(block_map.thread_count); + for (int thread = 0; thread < block_map.thread_count; thread++) { + trace->thread_specific_entries[thread].clear(); + // Reserve some large size to avoid frequent heap allocations + // affecting the recorded timings. + trace->thread_specific_entries[thread].reserve(16384); + } + } +} + +void TraceRecordEnd(Trace* trace) { + if (trace) { + trace->time_end = Now(); + } +} + +void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace) { + if (trace) { + TraceEntry entry; + entry.event = TraceEvent::kThreadStart; + entry.time_point = Now(); + entry.thread_id = thread_id; + trace->thread_specific_entries[thread_id].push_back(entry); + } +} + +void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace) { + if (trace) { + TraceEntry entry; + entry.event = TraceEvent::kThreadLoopStart; + entry.time_point = Now(); + entry.thread_id = thread_id; + trace->thread_specific_entries[thread_id].push_back(entry); + } +} + +void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id, + Trace* trace) { + if (trace) { + TraceEntry entry; + entry.event = TraceEvent::kBlockReserved; + entry.time_point = Now(); + entry.thread_id = thread_id; + entry.params[0] = block_id; + trace->thread_specific_entries[thread_id].push_back(entry); + } +} + +void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block, + Trace* trace) { + if (trace) { + TraceEntry entry; + entry.event = side == Side::kLhs ? 
TraceEvent::kBlockPackedLhs + : TraceEvent::kBlockPackedRhs; + entry.time_point = Now(); + entry.thread_id = thread_id; + entry.params[0] = block; + trace->thread_specific_entries[thread_id].push_back(entry); + } +} + +void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id, + Trace* trace) { + if (trace) { + TraceEntry entry; + entry.event = TraceEvent::kBlockFinished; + entry.time_point = Now(); + entry.thread_id = thread_id; + entry.params[0] = block_id; + trace->thread_specific_entries[thread_id].push_back(entry); + } +} + +void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace) { + if (trace) { + TraceEntry entry; + entry.event = TraceEvent::kThreadEnd; + entry.time_point = Now(); + entry.thread_id = thread_id; + trace->thread_specific_entries[thread_id].push_back(entry); + } +} + +#endif + +} // namespace ruy diff --git a/ruy/trace.h b/ruy/trace.h new file mode 100644 index 0000000..d2cc51d --- /dev/null +++ b/ruy/trace.h @@ -0,0 +1,73 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_ + +#include + +#include "ruy/block_map.h" +#include "ruy/side_pair.h" + +namespace ruy { + +struct Trace; + +#ifdef RUY_TRACE + +struct TracingContext { + bool initialized = false; + bool enabled = false; + int filter_shape_rows = 0; + int filter_shape_cols = 0; + int filter_shape_depth = 0; + Trace* trace = nullptr; + ~TracingContext(); +}; + +Trace* NewTraceOrNull(TracingContext* context, int rows, int depth, int cols); +void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace); +void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace); +void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id, + Trace* trace); +void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block, + Trace* trace); +void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id, + Trace* trace); +void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace); +void TraceRecordStart(Trace* trace); +void TraceRecordExecute(const BlockMap& block_map, Trace* trace); +void TraceRecordEnd(Trace* trace); + +#else + +struct TracingContext {}; + +inline Trace* NewTraceOrNull(TracingContext*, int, int, int) { return nullptr; } +inline void TraceRecordThreadStart(std::uint32_t, Trace*) {} +inline void TraceRecordThreadLoopStart(std::uint32_t, Trace*) {} +inline void TraceRecordBlockReserved(std::uint32_t, std::uint32_t, Trace*) {} +inline void TraceRecordBlockPacked(std::uint32_t, Side, int, Trace*) {} +inline void TraceRecordBlockFinished(std::uint32_t, std::uint32_t, Trace*) {} +inline void TraceRecordThreadEnd(std::uint32_t, Trace*) {} +inline void TraceRecordStart(Trace*) {} +inline void TraceRecordExecute(const BlockMap&, Trace*) {} +inline void TraceRecordEnd(Trace*) {} + +#endif + +} // namespace ruy + 
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_ diff --git a/ruy/trmul.cc b/ruy/trmul.cc new file mode 100644 index 0000000..a3ba46a --- /dev/null +++ b/ruy/trmul.cc @@ -0,0 +1,401 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/trmul.h" + +#include +#include +#include +#include +#include + +#include "ruy/allocator.h" +#include "ruy/block_map.h" +#include "ruy/check_macros.h" +#include "ruy/common.h" +#include "ruy/internal_matrix.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/side_pair.h" +#include "ruy/size_util.h" +#include "ruy/spec.h" +#include "ruy/thread_pool.h" +#include "ruy/trace.h" +#include "ruy/tune.h" + +namespace ruy { + +namespace { + +enum class PackingStatus : std::uint8_t { kNotStarted, kInProgress, kFinished }; + +struct TrMulTask final : Task { + TrMulTask(TrMulParams* params_, const BlockMap& block_map_, + std::atomic* atomic_block_id_, int thread_id_, + bool need_atomics_, + SidePair*> packing_status_, + TuningResolver* tuning_resolver_, Allocator* local_allocator_, + Trace* trace_) + : params(params_), + block_map(block_map_), + atomic_block_id(atomic_block_id_), + thread_id(thread_id_), + need_atomics(need_atomics_), + packing_status(packing_status_), + tuning_resolver(tuning_resolver_), + local_allocator(local_allocator_), + trace(trace_), + local_packed{nullptr, nullptr} {} + + void Run() override { + TraceRecordThreadStart(thread_id, trace); + + for (Side side : {Side::kLhs, Side::kRhs}) { + if (!params->is_prepacked[side]) { + const int size = NumBlocksPerSide(side, block_map); + local_allocator->Allocate(size, &local_packed[side]); + memset(local_packed[side], 0, size * sizeof(bool)); + } + } + + const int num_blocks = NumBlocks(block_map); + + const Tuning tuning = tuning_resolver->Resolve(); + + TraceRecordThreadLoopStart(thread_id, trace); + + SidePair block; + SidePair start; + SidePair end; + + // Each thread starts by initially reserving the block whose id + // is the thread id. + int block_id = thread_id; + TraceRecordBlockReserved(thread_id, block_id, trace); + + while (block_id < num_blocks) { + // Reserve the next block to handle. In order to hide the latency + // (typically comparable to an access to the level of data cache that + // is shared among CPU cores, e.g. 60 cycles on an ARM CPU as of 2019) + // of this atomic operation, we structure this code so as to avoid + // immediately depending on the `next_n` result. + const int next_block_id = + atomic_block_id->fetch_add(1, std::memory_order_relaxed); + TraceRecordBlockReserved(thread_id, next_block_id, trace); + // Get coordinates of the current block to handle, in "block space". + GetBlockByIndex(block_map, block_id, &block); + // Get coordinates of the current block to handle, in matrix space. 
+ GetBlockMatrixCoords(block_map, block, &start, &end); + // Maybe pack the current LHS/RHS block, if not already packed. + EnsurePacked(block, start, end, tuning); + // Actually do matrix multiplication work + params->RunKernel(tuning, start, end); + TraceRecordBlockFinished(thread_id, block_id, trace); + // Move on to the next block as obtained by the atomic increment + // at the start of this while loop iteration. + block_id = next_block_id; + } + + local_allocator->FreeAll(); + + TraceRecordThreadEnd(thread_id, trace); + } + + private: + // Tries to pack a block, without blocking. + // If the block was already packed, returns true. + // If the block was not started packing, packs it and returns true. + // If the block was being packed by another thread, returns false. + bool TryPack(Side side, int block, int start, int end, Tuning tuning) { + if (params->is_prepacked[side]) { + return true; + } + if (!local_packed[side][block]) { + if (need_atomics) { + // Explanation of this compare_exchange_strong operation: + // This atomically performs all of the following: + // 1. Read `status` with "acquire" memory order. + // * That this read uses "acquire" is because both memory orders + // specified have "acquire" as their read-component. + // 2. Compare (bitwise) with `exchanged_status`. + // 3. If equal, stores the value kInProgress to `status` with "release" + // memory order, and returns true, so we take this 'if' branch. + // * That this store uses "release" is because of the _rel part in + // memory_order_acq_rel passed as the first memory order argument. + // 4. If not equal, stores the loaded value of `status` to + // `exchanged_status` with "relaxed" semantics, and returns false, + // so we take the 'else' branch. + // * That this store uses "relaxed" is because the second memory + // order argument, memory_order_acquire, implies no particular + // store semantics. "relaxed" is acceptable here because this + // stores to a local stack variable. + // + // Rationale for compare_exchange_strong as opposed to + // compare_exchange_weak: + // The spurious-failure case with compare_exchange_weak will actually + // happen a lot here, because the atomic 'status' bytes are stored + // contiguously in arrays and neighboring values will be accessed + // by multiple threads concurrently. On a typical ARM CPU, an exclusives + // reservation granule is 64 bytes, so a lot of false-sharing may + // happen. Using compare_exchange_weak would thus result in often having + // TryPack return 'false' when it could instead have done the packing + // work and returned 'true'. Heuristically, that is not a good thing. + // Moreover, this changes the TryPack contract, loosening it and making + // it harder for the caller to reason about. Finally, the overhead of + // atomic operations is mitigated by the enclosing check on + // local_packed, so maybe the overhead of compare_exchange_strong isn't + // such a problem. But we don't really know for sure, that would be + // interesting to experiment more with. + PackingStatus exchanged_status = PackingStatus::kNotStarted; + std::atomic& status = packing_status[side][block]; + if (status.compare_exchange_strong( + exchanged_status, PackingStatus::kInProgress, + std::memory_order_acq_rel, std::memory_order_acquire)) { + // In this branch, the status was kNotStarted and we just atomically + // changed it to kInProgress as we are about to handle the packing + // ourselves. 
+ params->RunPack(side, tuning, start, end); + TraceRecordBlockPacked(thread_id, side, block, trace); + status.store(PackingStatus::kFinished, std::memory_order_release); + } else if (exchanged_status == PackingStatus::kInProgress) { + // Another thread is currently packing this block. + return false; + } + RUY_DCHECK(status.load(std::memory_order_acquire) == + PackingStatus::kFinished); + } else { + // Single-threaded case: no need for expensive atomics, local_packed + // is the truth already. + params->RunPack(side, tuning, start, end); + TraceRecordBlockPacked(thread_id, side, block, trace); + } + local_packed[side][block] = true; + } + return true; + } + + // Ensures that both the LHS and RHS blocks required by the specified block + // are packed. In the event that they are already being packed on another + // threads, this function may perform the packing of some other block while + // waiting for that other thread to finish packing the requested block. + void EnsurePacked(const SidePair& block, const SidePair& start, + const SidePair& end, Tuning tuning) { +#if RUY_OPT_ENABLED(RUY_OPT_PACK_AHEAD) + SidePair next_runahead_block{block[Side::kLhs] + 1, + block[Side::kRhs] + 1}; + Side next_runahead_side = Side::kLhs; +#endif + while (true) { + bool both_sides_packed = true; + for (Side side : {Side::kLhs, Side::kRhs}) { + both_sides_packed &= + TryPack(side, block[side], start[side], end[side], tuning); + } + if (both_sides_packed) { + break; + } +#if RUY_OPT_ENABLED(RUY_OPT_PACK_AHEAD) + const Side runahead_side = next_runahead_side; + const int runahead_block = next_runahead_block[runahead_side]; + next_runahead_side = + next_runahead_side == Side::kLhs ? Side::kRhs : Side::kLhs; + if (runahead_block >= NumBlocksPerSide(runahead_side, block_map)) { + continue; + } + int runahead_block_start, runahead_block_end; + GetBlockMatrixCoords(runahead_side, block_map, runahead_block, + &runahead_block_start, &runahead_block_end); + TryPack(runahead_side, runahead_block, runahead_block_start, + runahead_block_end, tuning); + next_runahead_block[runahead_side] = runahead_block + 1; +#endif + } + } + + TrMulParams* params; + const BlockMap& block_map; + std::atomic* atomic_block_id; + int thread_id; + bool need_atomics; + SidePair*> packing_status; + TuningResolver* tuning_resolver; + Allocator* local_allocator; + Trace* trace; + + // Local indicators of packedness to avoid the overhead of atomic ops. + SidePair local_packed; +}; + +void AllocatePMatrix(Allocator* allocator, PMatrix* packed) { + packed->data = allocator->AllocateBytes(DataSize(*packed)); + packed->sums = allocator->AllocateBytes(SumsSize(*packed)); +} + +int GetThreadCount(Context* context, int rows, int cols, int depth) { +#if RUY_PLATFORM(EMSCRIPTEN) + // b/139927184, std::thread constructor raises exception + return 1; +#endif + // Empirically determined rule for reasonable number of + // threads to use. This is proportional to the number of arithmetic ops + // in this Mul (product of the 3 sizes). 
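+  // For example (illustrative numbers, not from the original comment): with
+  // kDivisorLog2 = 15 below, a 512x512x512 Mul gives
+  // guess_log2 = 9 + 9 + 9 - 15 = 12, so up to 1 << 12 = 4096 threads before
+  // the max_num_threads cap, while a 64x64x64 Mul gives guess_log2 = 3,
+  // i.e. at most 8 threads.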
+ static constexpr int kDivisorLog2 = 15; + const int guess_log2 = std::max( + 0, ceil_log2(rows) + ceil_log2(cols) + ceil_log2(depth) - kDivisorLog2); + return std::min(1 << guess_log2, context->max_num_threads); +} + +LoopStructure GetLoopStructure(int tentative_thread_count, int rows, int cols, + int depth, int lhs_scalar_size, + int rhs_scalar_size, int local_data_cache_size, + int shared_data_cache_size) { + if (tentative_thread_count == 1) { + const BlockMapTraversalOrder traversal_order = + GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size, + local_data_cache_size, shared_data_cache_size); + // If we are in the GEMV case or the block_map would be using linear + // traversal anyway, use the simple loop. + if ((cols == 1) || traversal_order == BlockMapTraversalOrder::kLinear) { + return LoopStructure::kSimple; + } + } + return LoopStructure::kGeneral; +} + +} // namespace + +void TrMul(TrMulParams* params, Context* context) { + profiler::ScopeLabel label( + "TrMul (Path=0x%x, max_num_threads=%d, is_prepacked=(%d,%d))", + static_cast(params->path), context->max_num_threads, + params->is_prepacked[Side::kLhs], params->is_prepacked[Side::kRhs]); + + PMatrix& packed_lhs = params->packed[Side::kLhs]; + PMatrix& packed_rhs = params->packed[Side::kRhs]; + DMatrix& lhs = params->src[Side::kLhs]; + DMatrix& rhs = params->src[Side::kRhs]; + + const int rows = lhs.layout.cols; + const int cols = rhs.layout.cols; + const int depth = lhs.layout.rows; + + const int tentative_thread_count = GetThreadCount(context, rows, cols, depth); + const auto loop_structure = GetLoopStructure( + tentative_thread_count, rows, cols, depth, lhs.data_type.size, + rhs.data_type.size, params->local_data_cache_size, + params->shared_data_cache_size); + Allocator* allocator = context->GetMainAllocator(); + + // Allocate packed matrices + for (Side side : {Side::kLhs, Side::kRhs}) { + if (!params->is_prepacked[side]) { + AllocatePMatrix(allocator, ¶ms->packed[side]); + } + } + + // Case of running this TrMul as a simple loop. + // This is a good place to start reading this function: all the rest + // of this function is just an optimized, but functionally equivalent, + // version of that. + if (loop_structure == LoopStructure::kSimple) { + profiler::ScopeLabel label_simple("TrMulImpl, simple loop"); + Tuning tuning = context->GetMainThreadTuning(); + + const SidePair origin{0, 0}; + const SidePair rounded_dims{packed_lhs.layout.cols, + packed_rhs.layout.cols}; + for (Side side : {Side::kLhs, Side::kRhs}) { + if (!params->is_prepacked[side]) { + params->RunPack(side, tuning, origin[side], rounded_dims[side]); + } + } + params->RunKernel(tuning, origin, rounded_dims); + + allocator->FreeAll(); + return; + } + + profiler::ScopeLabel label_general("TrMulImpl, general case"); + + auto* trace = NewTraceOrNull(&context->tracing, rows, depth, cols); + TraceRecordStart(trace); + + // Initialize block map. + BlockMap block_map; + MakeBlockMap(packed_lhs.layout.cols, packed_rhs.layout.cols, depth, + packed_lhs.layout.kernel.cols, packed_rhs.layout.kernel.cols, + packed_lhs.data_type.size, packed_rhs.data_type.size, + tentative_thread_count, params->path, + params->local_data_cache_size, params->shared_data_cache_size, + &block_map); + + // Initialize per-thread state. 
+ const int thread_count = block_map.thread_count; + const bool need_atomics = thread_count > 1; + context->EnsureNPerThreadStates(thread_count); + for (auto& per_thread_state : context->per_thread_states) { + per_thread_state->tuning_resolver.SetTuning(context->explicit_tuning); + } + + // In the need_atomics case, allocate and initialize atomic values tracking + // the packing status of blocks. + SidePair*> packing_status{nullptr, nullptr}; + if (need_atomics) { + for (Side side : {Side::kLhs, Side::kRhs}) { + if (!params->is_prepacked[side]) { + const int size = NumBlocksPerSide(side, block_map); + allocator->Allocate(size, &packing_status[side]); + for (int i = 0; i < size; i++) { + packing_status[side][i].store(PackingStatus::kNotStarted, + std::memory_order_relaxed); + } + } + } + } + + // Create the atomic block id, allocate it using Allocator so that + // we get the alignment ensuring that it sits alone in its exclusives + // reservation granule. + std::atomic* atomic_block_id; + allocator->Allocate(1, &atomic_block_id); + + // Create task objects. + TrMulTask* tasks; + allocator->Allocate(thread_count, &tasks); + + atomic_block_id->store(thread_count); + + for (int i = 0; i < thread_count; i++) { + new (tasks + i) TrMulTask(params, block_map, atomic_block_id, i, + need_atomics, packing_status, + &context->per_thread_states[i]->tuning_resolver, + &context->per_thread_states[i]->allocator, trace); + } + + // Do the computation. + TraceRecordExecute(block_map, trace); + context->workers_pool.Execute(thread_count, tasks); + + // Finish up. + for (int i = 0; i < thread_count; i++) { + tasks[i].~TrMulTask(); + } + + allocator->FreeAll(); + TraceRecordEnd(trace); +} + +} // namespace ruy diff --git a/ruy/trmul.h b/ruy/trmul.h new file mode 100644 index 0000000..f50bb0c --- /dev/null +++ b/ruy/trmul.h @@ -0,0 +1,38 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// As a matrix multiplication library, Ruy offers a Mul entry point, performing +// matrix multiplication. For implementation purposes, it is much nicer to +// be dealing with the transpose-and-multiply operation, doing +// Destination = Transpose(LHS) * RHS +// Indeed, the latter is performing dot-products between the *columns* of LHS +// and the columns of RHS, whereas a plain matrix multiplication is performing +// dot-products between the *rows* of LHS and the columns of RHS. +// That is why TrMul is nicer to implement, allowing for a more symmetric +// treatment of LHS and RHS. 
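+//
+// As a rough illustration (not ruy's actual kernels), a reference TrMul on
+// column-major operands is simply
+//   for each (r, c):  dst(r, c) = sum over d of lhs(d, r) * rhs(d, c)
+// i.e. the inner loop walks down a column of LHS and a column of RHS at the
+// same time, which is what makes the LHS/RHS treatment symmetric; a plain
+// Mul(LHS, RHS) is then TrMul(Transpose(LHS), RHS).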
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_ + +#include "ruy/context.h" +#include "ruy/trmul_params.h" + +namespace ruy { + +void TrMul(TrMulParams* params, Context* context); + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_ diff --git a/ruy/trmul_params.h b/ruy/trmul_params.h new file mode 100644 index 0000000..47537b7 --- /dev/null +++ b/ruy/trmul_params.h @@ -0,0 +1,67 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_ + +#include "ruy/internal_matrix.h" +#include "ruy/side_pair.h" +#include "ruy/tune.h" + +namespace ruy { + +using RunKernelFn = void(Tuning, const SidePair&, void*, + const SidePair&, const SidePair&, DMatrix*); + +using RunPackFn = void(Tuning, const DMatrix&, PMatrix*, int, int); + +// Type-erased data needed for implementing TrMul. +struct TrMulParams { + TrMulParams() : run_pack{nullptr, nullptr}, is_prepacked{false, false} {} + // Helper functions for invoking the function pointers. + void RunPack(Side side, Tuning tuning, int start, int end) { + run_pack[side](tuning, src[side], &packed[side], start, end); + } + void RunKernel(Tuning tuning, const SidePair& start, + const SidePair& end) { + run_kernel(tuning, packed, spec, start, end, &dst); + } + + // path id, can be useful info for some fine-tuning, e.g. to guess reasonable + // cache sizes when not runtime-detectable. + Path path; + + // See Spec::local_data_cache_size(). + int local_data_cache_size = 0; + // See Spec::shared_data_cache_size(). + int shared_data_cache_size = 0; + + // Function pointers to type-erased entry points for kernels and packers. + SidePair run_pack; + RunKernelFn* run_kernel = nullptr; + + // Matrices and packed matrices. + SidePair src; + DMatrix dst; + SidePair packed; + SidePair is_prepacked; + + // Type-erased Spec. + void* spec = nullptr; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_ diff --git a/ruy/tune.cc b/ruy/tune.cc new file mode 100644 index 0000000..a89242f --- /dev/null +++ b/ruy/tune.cc @@ -0,0 +1,161 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/tune.h" + +#include +#include + +namespace ruy { + +#ifdef RUY_IMPLEMENT_TUNING + +namespace { + +void PoorlyOrderedKernel(int iters) { + asm volatile( + "mov w0, %w[iters]\n" + "1:\n" + "subs w0, w0, #1\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "bne 1b\n" ::[iters] "r"(iters) + : "cc", "x0", "v0", "v1", "v2", "v3"); +} + +void NicelyOrderedKernel(int iters) { + asm volatile( + "mov w0, %w[iters]\n" + "1:\n" + "subs w0, w0, #1\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "bne 1b\n" ::[iters] "r"(iters) + : "cc", "x0", "v0", "v1", "v2", "v3"); +} + +} // namespace + +float TuningResolver::EvalRatio() { + // With the current settings, 400 iterations and 4 repeats, this test has + // a latency of roughly 80 microseconds on a Cortex-A53 at 1.4 GHz. + static constexpr int kLoopIters = 400; + static constexpr int kRepeats = 4; + + Duration timing_poorly_ordered = Duration::max(); + Duration timing_nicely_ordered = Duration::max(); + + for (int r = 0; r < kRepeats; r++) { + TimePoint t0 = Now(); + PoorlyOrderedKernel(kLoopIters); + TimePoint t1 = Now(); + NicelyOrderedKernel(kLoopIters); + TimePoint t2 = Now(); + timing_poorly_ordered = std::min(timing_poorly_ordered, t1 - t0); + timing_nicely_ordered = std::min(timing_nicely_ordered, t2 - t1); + } + + return ToFloatSeconds(timing_nicely_ordered) / + ToFloatSeconds(timing_poorly_ordered); +} + +float TuningResolver::ThresholdRatio() { + // Empirically (see :tune_tool) determined threshold to distinguish in-order + // Cortex-A53/A55 cores from out-of-order Cortex-A57/A73/A75/A76 cores. Based + // on these experimental results, which were obtained with much lower + // (kLoopIters=1000, kRepeats=1) so as to make them resilient to noise, we + // have: + // + // CPU core type | in/out of order | observed ratio + // --------------+-----------------+----------------------------------------- + // Cortex-A53 | in-order | 0.32 -- 0.329 + // Cortex-A55 | in-order | 0.319 -- 0.325 + // Cortex-A55r1 | in-order | 0.319 -- 0.325 + // Cortex-A57 | out-of-order | 0.99 -- 1.01 + // Cortex-A73 | out-of-order | 0.922 -- 0.927 + // Cortex-A75 | out-of-order | 0.921 -- 0.93 + // Cortex-A76 | out-of-order | 1 + // Kryo (pixel1) | out-of-order | 0.73 -- 0.76 + // + // Thus the allowable range for the threshold is [0.35 .. 0.70]. + // We pick a value closer to the upper bound because really any out-of-order + // CPU should by definition produce a ratio close to 1. 
+ return 0.65f; +} + +Tuning TuningResolver::ResolveNow() { + const bool is_probably_inorder = EvalRatio() < ThresholdRatio(); + return is_probably_inorder ? Tuning::kInOrder : Tuning::kOutOfOrder; +} + +#else // not defined RUY_IMPLEMENT_TUNING + +float TuningResolver::EvalRatio() { return 0; } +float TuningResolver::ThresholdRatio() { return 0; } + +Tuning TuningResolver::ResolveNow() { return Tuning::kOutOfOrder; } + +#endif + +TuningResolver::TuningResolver() + : expiry_duration_(DurationFromMilliseconds(250)) {} + +Tuning TuningResolver::Resolve() { +#ifdef RUY_IMPLEMENT_TUNING + if (unresolved_tuning_ != Tuning::kAuto) { + return unresolved_tuning_; + } + TimePoint new_timepoint = CoarseNow(); + if (last_resolved_tuning_ != Tuning::kAuto && + (new_timepoint - last_resolved_timepoint_) < expiry_duration_) { + return last_resolved_tuning_; + } + last_resolved_timepoint_ = new_timepoint; + last_resolved_tuning_ = ResolveNow(); + return last_resolved_tuning_; +#else + return Tuning::kOutOfOrder; +#endif +} + +} // namespace ruy diff --git a/ruy/tune.h b/ruy/tune.h new file mode 100644 index 0000000..e6a0ee8 --- /dev/null +++ b/ruy/tune.h @@ -0,0 +1,163 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Library doing minimal CPU detection to decide what to tune asm code for. +// +// # Tuning vs Path +// +// Tunings are merely local variations of optimized code paths, that are +// drop-in replacements for each other --- the input and output data layouts +// are identical. By contrast, what ruy calls a Path dictates its own +// data layouts. For example, Path::kNeonDotprod will use different +// layouts compared to Path::kNeon; but within each, different tunings +// will share that same layout. +// +// # Tuning is for now only based on 1 bit: OutOfOrder / InOrder +// +// In practice, each of our asm code paths only needs one bit information to +// decide on tuning: whether the CPU is out-of-order or in-order. +// That is because out-of-order CPUs are by definition relatively insensitive +// to small-scale asm details (which is what "tuning" is about); and for each +// asm code path, there tends to be one main in-order CPU architecture that +// we focus our tuning effort on. Examples: +// * For Path::kNeon, the main in-order CPU is Cortex-A53/A55 (pre-dotprod) +// * For Path::kNeonDotprod, the main in-order CPU is Cortex-A55r1 (dotprod) +// +// Because having tuned code paths is a compromise of efficiency gains +// versus implementation effort and code size, we are happy to stop at just this +// single bit of information, OutOfOrder/InOrder, at least in the current CPU +// landscape. This could change in the future. +// +// # Implementation notes and alternatives. +// +// The current implementation uses a nano-benchmark, see tune.cc. +// That is why it's quite expensive, making caching / +// statefulness necessary (see TuningResolver class comment). 
+//
+// An interesting alternative, which was explained to us by Marat Dukhan
+// (maratek@) after this was implemented, would be to use the
+// getcpu(2) system call on Linux. This returns a
+// numeric CPU identifier that could be mapped to an OutOfOrder/InOrder
+// classification given additional information about the CPU. Such
+// additional information could be obtained by the cpuinfo library,
+//   https://github.com/pytorch/cpuinfo
+// which obtains this information mainly from parsing /proc/cpuinfo.
+// Pros:
+//   * Would remove the need for the relatively expensive nano-benchmark
+//     (dozens of microseconds, which have to be reevaluated again several
+//     times per second).
+//   * Would conceivably be more reliable.
+// Cons:
+//   * Linux-specific.
+//   * Modest binary size increase (Marat mentioned the cpuinfo lib is 20k).
+//   * Won't support exactly 100% of devices (nonstandard /proc/cpuinfo etc).
+//
+// We could also have both:
+//   * Maybe by trying getcpu first if supported, then falling back to a
+//     nano-benchmark.
+//   * Maybe using getcpu in conjunction with the nano-benchmark to cache
+//     per-CPU-id nano-benchmark results.
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_
+
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/time.h"
+
+// Tuning is only implemented on NEON_64 at the moment (see assembly code
+// in the nano-benchmark) and not on Apple (some Apple CPUs produce incorrect
+// results on in-order-tuned kernels combining ARM and NEON load instructions
+// and NEON `ins` instructions).
+//
+// When tuning is not implemented, we simply always use Tuning::kOutOfOrder.
+#if RUY_OPT_ENABLED(RUY_OPT_TUNING) && RUY_PLATFORM(NEON_64) && \
+    !RUY_PLATFORM(APPLE)
+#define RUY_IMPLEMENT_TUNING
+#endif
+
+namespace ruy {
+
+enum class Tuning {
+  // kAuto means please use auto-detection. It's the default in the
+  // user-visible parts (see Context). It's meant to be resolved to an
+  // actual tuning at some point by means of TuningResolver.
+  kAuto,
+  // Target an out-of-order CPU. Example: ARM Cortex-A75.
+  kOutOfOrder,
+  // Target an in-order CPU. Example: ARM Cortex-A55.
+  kInOrder
+};
+
+// Why a TuningResolver class?
+//
+// Ideally, this library would offer a single function,
+//   Tuning GetCurrentCPUTuning();
+//
+// However, determining information about the current CPU is not necessarily
+// cheap, so we currently cache that and only invalidate/reevaluate after
+// a fixed amount of time. This need to store state is why this library
+// has to expose a class, TuningResolver, not just a function.
+class TuningResolver {
+ public:
+  TuningResolver();
+
+  // Allows the user to specify an explicit Tuning value, bypassing auto
+  // detection; or to specify Tuning::kAuto, reverting to auto detection.
+  void SetTuning(Tuning tuning) { unresolved_tuning_ = tuning; }
+
+  // Get an actual tuning --- that is the function that this class wanted to
+  // be.
+  Tuning Resolve();
+
+ private:
+  TuningResolver(const TuningResolver&) = delete;
+
+  // TuneTool is a demo/tool used to tweak the tuning implementation to
+  // specific devices. It needs to access some finer granularity information
+  // than just the Tuning returned by Resolve. Nothing else should need
+  // access to that.
+  friend class TuneTool;
+  // Actually runs a nano-benchmark, producing a real number called 'ratio'
+  // whose meaning is generally opaque / implementation defined.
Typically, + // this would be the ratio between the latencies of two different + // pieces of asm code differing only by the ordering of instructions, + // revealing whether the CPU cares about such ordering details. + // An implementation may just return a dummy value if it is not based on + // such nanobenchmarking / ratio evaluation. + float EvalRatio(); + // Empirically determined threshold on ratio values delineating + // out-of-order (ratios closer to 1) from in-order (ratios farther from 1). + // An implementation may just return a dummy value if it is not based on + // such nanobenchmarking / ratio evaluation. + float ThresholdRatio(); + // Perform the tuning resolution now. That may typically use EvalRatio and + // ThresholdRatio, but an implementation may use a different approach instead. + Tuning ResolveNow(); + + // The tuning as specified by the user, before actual resolution happens + // i.e. before querying any specifics of the current CPU. + // The default value kAuto means try to auto-detect. Other values mean + // bypass auto-detect, use explicit value instead. See SetTuning(). + Tuning unresolved_tuning_ = Tuning::kAuto; + // Cached last resolved tuning. + Tuning last_resolved_tuning_ = Tuning::kAuto; + // Timepoint of cached last resolved tuning, for invalidation purposes. + TimePoint last_resolved_timepoint_; + // Cached last resolved tunings that are older than this age are invalid. + const Duration expiry_duration_; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_ diff --git a/ruy/tune_test.cc b/ruy/tune_test.cc new file mode 100644 index 0000000..ebd86e0 --- /dev/null +++ b/ruy/tune_test.cc @@ -0,0 +1,53 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/tune.h" + +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) + +#include "testing/base/public/gunit.h" + +namespace ruy { +namespace { + +TEST(TuneTest, TuneTest) { + TuningResolver tuning_resolver; + ASSERT_FALSE(tuning_resolver.Resolve() == Tuning::kAuto); + // 1 second is likely higher than TuningResolver's internal cache expiry, + // exercising the logic invalidating earlier tuning resolutions. + std::this_thread::sleep_for(std::chrono::seconds(1)); + ASSERT_FALSE(tuning_resolver.Resolve() == Tuning::kAuto); + + tuning_resolver.SetTuning(Tuning::kAuto); + +#ifdef RUY_IMPLEMENT_TUNING + for (auto tuning : {Tuning::kOutOfOrder, Tuning::kInOrder}) { + tuning_resolver.SetTuning(tuning); + ASSERT_TRUE(tuning_resolver.Resolve() == tuning); + // See above comment about 1 second. 
+ std::this_thread::sleep_for(std::chrono::seconds(1)); + ASSERT_TRUE(tuning_resolver.Resolve() == tuning); + } +#endif +} + +} // namespace +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/tune_tool.cc b/ruy/tune_tool.cc new file mode 100644 index 0000000..0b6e4ab --- /dev/null +++ b/ruy/tune_tool.cc @@ -0,0 +1,56 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Self-contained tool used to tune the tune code --- see the +// threshold ratios used in tune.cc. + +#include // NOLINT(build/c++11) +#include +#include // NOLINT(build/c++11) + +#include "ruy/tune.h" + +#ifdef _WIN32 +#define getpid() 0 +#else +#include +#endif + +namespace ruy { + +class TuneTool { + public: + static void Query(float* eval, float* threshold) { + TuningResolver resolver; + *eval = resolver.EvalRatio(); + *threshold = resolver.ThresholdRatio(); + } +}; + +} // namespace ruy + +int main() { + // Infinite loop: the user can hit Ctrl-C + while (true) { + float eval; + float threshold; + ruy::TuneTool::Query(&eval, &threshold); + printf("[%d] eval=%.3f %c threshold=%.3f ==> probably %s...\n", getpid(), + eval, eval < threshold ? '<' : '>', threshold, + eval < threshold ? "in-order" : "out-of-order"); + fflush(stdout); + std::this_thread::sleep_for(std::chrono::seconds(1)); + } +} diff --git a/ruy/wait.cc b/ruy/wait.cc new file mode 100644 index 0000000..d8156bc --- /dev/null +++ b/ruy/wait.cc @@ -0,0 +1,69 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/wait.h" + +#include // NOLINT(build/c++11) + +namespace ruy { + +void Wait(const std::function& condition, const Duration& spin_duration, + std::condition_variable* condvar, std::mutex* mutex) { + // First, trivial case where the `condition` is already true; + if (condition()) { + return; + } + + // Then try busy-waiting. + const TimePoint wait_start = Now(); + while (Now() - wait_start < spin_duration) { + if (condition()) { + return; + } + } + + // Finally, do real passive waiting. 
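+  // Note: condvar->wait(lock, condition) only blocks after re-checking
+  // `condition` while holding the mutex, so a notification sent by another
+  // thread that holds the same mutex cannot be missed here.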
+ std::unique_lock lock(*mutex); + condvar->wait(lock, condition); +} + +void Wait(const std::function& condition, + std::condition_variable* condvar, std::mutex* mutex) { + // This value was empirically derived with some microbenchmark, we don't have + // high confidence in it. + // + // TODO(b/135595069): make this value configurable at runtime. + // I almost wanted to file another bug to ask for experimenting in a more + // principled way to tune this value better, but this would have to be tuned + // on real end-to-end applications and we'd expect different applications to + // require different tunings. So the more important point is the need for + // this to be controllable by the application. + // + // That this value means that we may be sleeping substantially longer + // than a scheduler timeslice's duration is not necessarily surprising. The + // idea is to pick up quickly new work after having finished the previous + // workload. When it's new work within the same GEMM as the previous work, the + // time interval that we might be busy-waiting is very small, so for that + // purpose it would be more than enough to sleep for 1 ms. + // That is all what we would observe on a GEMM benchmark. However, in a real + // application, after having finished a GEMM, we might do unrelated work for + // a little while, then start on a new GEMM. In that case the wait interval + // may be a little longer. There may also not be another GEMM for a long time, + // in which case we'll end up passively waiting below. + const Duration spin_duration = DurationFromMilliseconds(2); + Wait(condition, spin_duration, condvar, mutex); +} + +} // namespace ruy diff --git a/ruy/wait.h b/ruy/wait.h new file mode 100644 index 0000000..900ec8d --- /dev/null +++ b/ruy/wait.h @@ -0,0 +1,73 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_ + +#include // NOLINT(build/c++11) +#include +#include // NOLINT(build/c++11) + +#include "ruy/time.h" + +namespace ruy { + +// Waits until some evaluation of `condition` has returned true. +// +// There is no guarantee that calling `condition` again after this function +// has returned would still return true. The only +// contract is that at some point during the execution of that function, +// `condition` has returned true. +// +// First does some spin-waiting for the specified `spin_duration`, +// then falls back to passive waiting for the given condvar, guarded +// by the given mutex. At this point it will try to acquire the mutex lock, +// around the waiting on the condition variable. +// Therefore, this function expects that the calling thread hasn't already +// locked the mutex before calling it. +// This function will always release the mutex lock before returning. 
+// +// The idea of doing some initial spin-waiting is to help get +// better and more consistent multithreading benefits for small GEMM sizes. +// Spin-waiting help ensuring that if we need to wake up soon after having +// started waiting, then we can wake up quickly (as opposed to, say, +// having to wait to be scheduled again by the OS). On the other hand, +// we must still eventually revert to passive waiting for longer waits +// (e.g. worker threads having finished a GEMM and waiting until the next GEMM) +// so as to avoid permanently spinning. +// +// In situations where other threads might have more useful things to do with +// these CPU cores than our spin-waiting, it may be best to reduce the value +// of `spin_duration`. Setting it to zero disables the spin-waiting entirely. +// +// There is a risk that the std::function used here might use a heap allocation +// to store its context. The expected usage pattern is that these functions' +// contexts will consist of a single pointer value (typically capturing only +// [this]), and that in this case the std::function implementation will use +// inline storage, avoiding a heap allocation. However, we can't effectively +// guard that assumption, and that's not a big concern anyway because the +// latency of a small heap allocation is probably low compared to the intrinsic +// latency of what this Wait function does. +void Wait(const std::function& condition, const Duration& spin_duration, + std::condition_variable* condvar, std::mutex* mutex); + +// Convenience overload using a default `spin_duration`. +// TODO(benoitjacob): let this be controlled from the ruy API. +void Wait(const std::function& condition, + std::condition_variable* condvar, std::mutex* mutex); + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_ diff --git a/ruy/wait_test.cc b/ruy/wait_test.cc new file mode 100644 index 0000000..f0548f9 --- /dev/null +++ b/ruy/wait_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/wait.h" + +#include +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) + +#include "testing/base/public/gunit.h" +#include "ruy/platform.h" + +namespace ruy { +namespace { + +// Thread taking a `value` atomic counter and incrementing it until it equals +// `end_value`, then notifying the condition variable as long as +// `value == end_value`. If `end_value` is increased, it will then resume +// incrementing `value`, etc. Terminates if `end_value == -1`. 
+class ThreadCountingUpToValue { + public: + ThreadCountingUpToValue(const std::atomic& end_value, + std::atomic* value, + std::condition_variable* condvar, std::mutex* mutex) + : end_value_(end_value), + value_(value), + condvar_(condvar), + mutex_(mutex) {} + void operator()() { + // end_value_==-1 is how the master thread will tell us it's OK to terminate + while (end_value_.load() != -1) { + // wait until end_value is set to a higher value + while (value_->load() == end_value_.load()) { + } + // increment value as long as it's lower than end_value + while (value_->fetch_add(1) < end_value_.load() - 1) { + } + // when value has reached end_value, notify the master thread. + while (value_->load() == end_value_.load()) { + std::lock_guard lock(*mutex_); + condvar_->notify_all(); + } + } + } + + private: + const std::atomic& end_value_; + std::atomic* value_; + std::condition_variable* condvar_; + std::mutex* mutex_; +}; + +void WaitTest(const Duration& spin_duration, const Duration& delay) { +#if RUY_PLATFORM(EMSCRIPTEN) + // b/139927184, std::thread constructor raises exception + return; +#endif + std::condition_variable condvar; + std::mutex mutex; + std::atomic value(0); + std::atomic end_value(0); + ThreadCountingUpToValue thread_callable(end_value, &value, &condvar, &mutex); + std::thread thread(thread_callable); + std::this_thread::sleep_for(delay); + for (int i = 1; i < 10; i++) { + end_value.store(1000 * i); + const auto& condition = [&value, &end_value]() { + return value.load() == end_value.load(); + }; + ruy::Wait(condition, spin_duration, &condvar, &mutex); + EXPECT_EQ(value.load(), end_value.load()); + } + end_value.store(-1); + thread.join(); +} + +TEST(WaitTest, WaitTestNoSpin) { + WaitTest(DurationFromSeconds(0), DurationFromSeconds(0)); +} + +TEST(WaitTest, WaitTestSpinOneMicrosecond) { + WaitTest(DurationFromSeconds(1e-6), DurationFromSeconds(0)); +} + +TEST(WaitTest, WaitTestSpinOneMillisecond) { + WaitTest(DurationFromSeconds(1e-3), DurationFromSeconds(0)); +} + +TEST(WaitTest, WaitTestSpinOneSecond) { + WaitTest(DurationFromSeconds(1), DurationFromSeconds(0)); +} + +// Testcase to consistently reproduce the hang in b/139062384. +TEST(WaitTest, WaitTestNoSpinWithDelayBug139062384) { + WaitTest(DurationFromSeconds(0), DurationFromSeconds(1)); +} + +} // namespace +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} -- cgit v1.2.3
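
For reference, a minimal, self-contained usage sketch of the Wait API added by this patch. This is illustrative only: the main() harness, the counter target and the spin duration below are made up, not taken from the patch; only ruy::Wait, ruy::DurationFromSeconds and the header paths come from the files above.

    #include <atomic>
    #include <condition_variable>
    #include <mutex>
    #include <thread>

    #include "ruy/time.h"
    #include "ruy/wait.h"

    int main() {
      std::condition_variable condvar;
      std::mutex mutex;
      std::atomic<int> counter(0);

      std::thread worker([&] {
        for (int i = 0; i < 1000; i++) {
          counter.fetch_add(1);
        }
        // Notify while holding the mutex, after the counter reached its final
        // value, so a passively waiting thread wakes up and re-checks the
        // condition.
        std::lock_guard<std::mutex> lock(mutex);
        condvar.notify_all();
      });

      // Spin for up to ~50 microseconds, then fall back to condvar waiting.
      ruy::Wait([&] { return counter.load() == 1000; },
                ruy::DurationFromSeconds(50e-6), &condvar, &mutex);

      worker.join();
      return 0;
    }

The structure mirrors wait_test.cc above: the condition is a cheap lambda capturing only references, so the std::function wrapper is expected to use inline storage, as discussed in the wait.h comment.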