From f7ea583082c670103fb2cebd6035b944c71d64c4 Mon Sep 17 00:00:00 2001
From: Benoit Jacob
Date: Fri, 27 Mar 2020 21:58:51 -0400
Subject: Move ruy's code to a ruy/ subdirectory.

The motivation is that having source files in the repository root runs into a
number of corner cases with copybara setups and with external CMake build
systems, so enclosing all code in ruy/ avoids that while generally making our
setup much more similar to that of other related projects (TensorFlow, IREE).

PiperOrigin-RevId: 303448881
---
 BUILD                                      |  954 ----
 WORKSPACE                                  |   17 +
 allocator.cc                               |   51 -
 allocator.h                                |  185 -
 allocator_test.cc                          |  103 -
 benchmark.cc                               |  196 -
 block_map.cc                               |  486 --
 block_map.h                                |  161 -
 block_map_test.cc                          |  263 -
 blocking_counter.cc                        |   49 -
 blocking_counter.h                         |   62 -
 build_defs.bzl                             |   40 -
 check_macros.h                             |  138 -
 check_macros_test.cc                       |  153 -
 common.h                                   |   73 -
 context.cc                                 |  109 -
 context.h                                  |  109 -
 context_test.cc                            |   63 -
 cpu_cache_size.h                           |   81 -
 detect_arm.cc                              |   73 -
 detect_arm.h                               |   29 -
 detect_x86.cc                              |  101 -
 detect_x86.h                               |   49 -
 dispatch.h                                 |  482 --
 example.cc                                 |  136 -
 example_advanced.cc                        |   83 -
 have_built_path_for.h                      |   32 -
 have_built_path_for_avx2.cc                |   35 -
 have_built_path_for_avx512.cc              |   35 -
 have_built_path_for_avxvnni.cc             |   39 -
 have_built_path_for_sse42.cc               |   39 -
 internal_matrix.h                          |  388 --
 kernel.h                                   |   31 -
 kernel_arm.h                               |  211 -
 kernel_arm32.cc                            | 2499 ---------
 kernel_arm64.cc                            | 7835 -----------------------------
 kernel_avx2.cc                             | 1664 ------
 kernel_avx512.cc                           | 1820 -------
 kernel_avxvnni.cc                          |  435 --
 kernel_common.h                            |  481 --
 kernel_sse42.cc                            |  428 --
 kernel_x86.h                               |  222 -
 matrix.h                                   |  182 -
 opt_set.h                                  |   51 -
 pack.h                                     |   98 -
 pack_arm.cc                                | 1936 -------
 pack_arm.h                                 |  497 --
 pack_avx2.cc                               |  816 ---
 pack_avx512.cc                             |  693 ---
 pack_avxvnni.cc                            |  478 --
 pack_common.h                              |  246 -
 pack_sse42.cc                              |  471 --
 pack_x86.h                                 |  461 --
 path.h                                     |  162 -
 platform.h                                 |  156 -
 pmu.cc                                     |  281 --
 pmu.h                                      |   44 -
 prepack.h                                  |  108 -
 prepacked_cache.cc                         |   82 -
 prepacked_cache.h                          |  130 -
 prepacked_cache_test.cc                    |  210 -
 profiler/BUILD                             |   52 -
 profiler/README.md                         |  149 -
 profiler/instrumentation.cc                |  130 -
 profiler/instrumentation.h                 |  203 -
 profiler/profiler.cc                       |  109 -
 profiler/profiler.h                        |  106 -
 profiler/test.cc                           |  167 -
 profiler/test_instrumented_library.cc      |   59 -
 profiler/test_instrumented_library.h       |   23 -
 profiler/treeview.cc                       |  248 -
 profiler/treeview.h                        |  130 -
 ruy.h                                      |   42 -
 ruy/BUILD                                  |  954 ++++
 ruy/allocator.cc                           |   51 +
 ruy/allocator.h                            |  185 +
 ruy/allocator_test.cc                      |  103 +
 ruy/benchmark.cc                           |  196 +
 ruy/block_map.cc                           |  486 ++
 ruy/block_map.h                            |  161 +
 ruy/block_map_test.cc                      |  263 +
 ruy/blocking_counter.cc                    |   49 +
 ruy/blocking_counter.h                     |   62 +
 ruy/build_defs.bzl                         |   54 +
 ruy/build_defs.bzl.opensource              |   40 +
 ruy/check_macros.h                         |  138 +
 ruy/check_macros_test.cc                   |  153 +
 ruy/common.h                               |   73 +
 ruy/context.cc                             |  109 +
 ruy/context.h                              |  109 +
 ruy/context_test.cc                        |   63 +
 ruy/cpu_cache_size.h                       |   81 +
 ruy/detect_arm.cc                          |   73 +
 ruy/detect_arm.h                           |   29 +
 ruy/detect_x86.cc                          |  101 +
 ruy/detect_x86.h                           |   49 +
 ruy/dispatch.h                             |  482 ++
 ruy/example.cc                             |  136 +
 ruy/example_advanced.cc                    |   83 +
 ruy/have_built_path_for.h                  |   32 +
 ruy/have_built_path_for_avx2.cc            |   35 +
 ruy/have_built_path_for_avx512.cc          |   35 +
 ruy/have_built_path_for_avxvnni.cc         |   39 +
 ruy/have_built_path_for_sse42.cc           |   39 +
 ruy/internal_matrix.h                      |  388 ++
 ruy/kernel.h                               |   31 +
 ruy/kernel_arm.h                           |  211 +
 ruy/kernel_arm32.cc                        | 2499 +++++++++
 ruy/kernel_arm64.cc                        | 7835 +++++++++++++++++++++++++++++
 ruy/kernel_avx2.cc                         | 1664 ++++++
 ruy/kernel_avx512.cc                       | 1820 +++++++
 ruy/kernel_avxvnni.cc                      |  435 ++
 ruy/kernel_common.h                        |  481 ++
 ruy/kernel_sse42.cc                        |  428 ++
 ruy/kernel_x86.h                           | 
222 + ruy/matrix.h | 182 + ruy/opt_set.h | 51 + ruy/pack.h | 98 + ruy/pack_arm.cc | 1936 +++++++ ruy/pack_arm.h | 497 ++ ruy/pack_avx2.cc | 816 +++ ruy/pack_avx512.cc | 693 +++ ruy/pack_avxvnni.cc | 478 ++ ruy/pack_common.h | 246 + ruy/pack_sse42.cc | 471 ++ ruy/pack_x86.h | 461 ++ ruy/path.h | 162 + ruy/platform.h | 156 + ruy/pmu.cc | 281 ++ ruy/pmu.h | 44 + ruy/prepack.h | 108 + ruy/prepacked_cache.cc | 82 + ruy/prepacked_cache.h | 130 + ruy/prepacked_cache_test.cc | 210 + ruy/profiler/BUILD | 52 + ruy/profiler/README.md | 149 + ruy/profiler/instrumentation.cc | 130 + ruy/profiler/instrumentation.h | 203 + ruy/profiler/profiler.cc | 109 + ruy/profiler/profiler.h | 106 + ruy/profiler/test.cc | 167 + ruy/profiler/test_instrumented_library.cc | 59 + ruy/profiler/test_instrumented_library.h | 23 + ruy/profiler/treeview.cc | 248 + ruy/profiler/treeview.h | 130 + ruy/ruy.h | 42 + ruy/ruy_advanced.h | 69 + ruy/ruy_test.bzl | 34 + ruy/ruy_test_ext.bzl | 19 + ruy/ruy_test_ext.bzl.opensource | 7 + ruy/side_pair.h | 64 + ruy/size_util.h | 93 + ruy/size_util_test.cc | 101 + ruy/spec.h | 118 + ruy/test.h | 2125 ++++++++ ruy/test_fast.cc | 110 + ruy/test_slow.cc | 71 + ruy/test_special_specs.cc | 163 + ruy/thread_pool.cc | 200 + ruy/thread_pool.h | 102 + ruy/time.h | 81 + ruy/trace.cc | 325 ++ ruy/trace.h | 73 + ruy/trmul.cc | 401 ++ ruy/trmul.h | 38 + ruy/trmul_params.h | 67 + ruy/tune.cc | 161 + ruy/tune.h | 163 + ruy/tune_test.cc | 53 + ruy/tune_tool.cc | 56 + ruy/wait.cc | 69 + ruy/wait.h | 73 + ruy/wait_test.cc | 117 + ruy_advanced.h | 69 - ruy_test.bzl | 34 - ruy_test_ext.bzl | 7 - side_pair.h | 64 - size_util.h | 93 - size_util_test.cc | 101 - spec.h | 118 - test.h | 2125 -------- test_fast.cc | 110 - test_slow.cc | 71 - test_special_specs.cc | 163 - thread_pool.cc | 200 - thread_pool.h | 102 - time.h | 81 - trace.cc | 325 -- trace.h | 73 - trmul.cc | 401 -- trmul.h | 38 - trmul_params.h | 67 - tune.cc | 161 - tune.h | 163 - tune_test.cc | 53 - tune_tool.cc | 56 - wait.cc | 69 - wait.h | 73 - wait_test.cc | 117 - 199 files changed, 33967 insertions(+), 33877 deletions(-) delete mode 100644 BUILD create mode 100644 WORKSPACE delete mode 100644 allocator.cc delete mode 100644 allocator.h delete mode 100644 allocator_test.cc delete mode 100644 benchmark.cc delete mode 100644 block_map.cc delete mode 100644 block_map.h delete mode 100644 block_map_test.cc delete mode 100644 blocking_counter.cc delete mode 100644 blocking_counter.h delete mode 100644 build_defs.bzl delete mode 100644 check_macros.h delete mode 100644 check_macros_test.cc delete mode 100644 common.h delete mode 100644 context.cc delete mode 100644 context.h delete mode 100644 context_test.cc delete mode 100644 cpu_cache_size.h delete mode 100644 detect_arm.cc delete mode 100644 detect_arm.h delete mode 100644 detect_x86.cc delete mode 100644 detect_x86.h delete mode 100644 dispatch.h delete mode 100644 example.cc delete mode 100644 example_advanced.cc delete mode 100644 have_built_path_for.h delete mode 100644 have_built_path_for_avx2.cc delete mode 100644 have_built_path_for_avx512.cc delete mode 100644 have_built_path_for_avxvnni.cc delete mode 100644 have_built_path_for_sse42.cc delete mode 100644 internal_matrix.h delete mode 100644 kernel.h delete mode 100644 kernel_arm.h delete mode 100644 kernel_arm32.cc delete mode 100644 kernel_arm64.cc delete mode 100644 kernel_avx2.cc delete mode 100644 kernel_avx512.cc delete mode 100644 kernel_avxvnni.cc delete mode 100644 kernel_common.h delete mode 100644 kernel_sse42.cc delete mode 
100644 kernel_x86.h delete mode 100644 matrix.h delete mode 100644 opt_set.h delete mode 100644 pack.h delete mode 100644 pack_arm.cc delete mode 100644 pack_arm.h delete mode 100644 pack_avx2.cc delete mode 100644 pack_avx512.cc delete mode 100644 pack_avxvnni.cc delete mode 100644 pack_common.h delete mode 100644 pack_sse42.cc delete mode 100644 pack_x86.h delete mode 100644 path.h delete mode 100644 platform.h delete mode 100644 pmu.cc delete mode 100644 pmu.h delete mode 100644 prepack.h delete mode 100644 prepacked_cache.cc delete mode 100644 prepacked_cache.h delete mode 100644 prepacked_cache_test.cc delete mode 100644 profiler/BUILD delete mode 100644 profiler/README.md delete mode 100644 profiler/instrumentation.cc delete mode 100644 profiler/instrumentation.h delete mode 100644 profiler/profiler.cc delete mode 100644 profiler/profiler.h delete mode 100644 profiler/test.cc delete mode 100644 profiler/test_instrumented_library.cc delete mode 100644 profiler/test_instrumented_library.h delete mode 100644 profiler/treeview.cc delete mode 100644 profiler/treeview.h delete mode 100644 ruy.h create mode 100644 ruy/BUILD create mode 100644 ruy/allocator.cc create mode 100644 ruy/allocator.h create mode 100644 ruy/allocator_test.cc create mode 100644 ruy/benchmark.cc create mode 100644 ruy/block_map.cc create mode 100644 ruy/block_map.h create mode 100644 ruy/block_map_test.cc create mode 100644 ruy/blocking_counter.cc create mode 100644 ruy/blocking_counter.h create mode 100644 ruy/build_defs.bzl create mode 100644 ruy/build_defs.bzl.opensource create mode 100644 ruy/check_macros.h create mode 100644 ruy/check_macros_test.cc create mode 100644 ruy/common.h create mode 100644 ruy/context.cc create mode 100644 ruy/context.h create mode 100644 ruy/context_test.cc create mode 100644 ruy/cpu_cache_size.h create mode 100644 ruy/detect_arm.cc create mode 100644 ruy/detect_arm.h create mode 100644 ruy/detect_x86.cc create mode 100644 ruy/detect_x86.h create mode 100644 ruy/dispatch.h create mode 100644 ruy/example.cc create mode 100644 ruy/example_advanced.cc create mode 100644 ruy/have_built_path_for.h create mode 100644 ruy/have_built_path_for_avx2.cc create mode 100644 ruy/have_built_path_for_avx512.cc create mode 100644 ruy/have_built_path_for_avxvnni.cc create mode 100644 ruy/have_built_path_for_sse42.cc create mode 100644 ruy/internal_matrix.h create mode 100644 ruy/kernel.h create mode 100644 ruy/kernel_arm.h create mode 100644 ruy/kernel_arm32.cc create mode 100644 ruy/kernel_arm64.cc create mode 100644 ruy/kernel_avx2.cc create mode 100644 ruy/kernel_avx512.cc create mode 100644 ruy/kernel_avxvnni.cc create mode 100644 ruy/kernel_common.h create mode 100644 ruy/kernel_sse42.cc create mode 100644 ruy/kernel_x86.h create mode 100644 ruy/matrix.h create mode 100644 ruy/opt_set.h create mode 100644 ruy/pack.h create mode 100644 ruy/pack_arm.cc create mode 100644 ruy/pack_arm.h create mode 100644 ruy/pack_avx2.cc create mode 100644 ruy/pack_avx512.cc create mode 100644 ruy/pack_avxvnni.cc create mode 100644 ruy/pack_common.h create mode 100644 ruy/pack_sse42.cc create mode 100644 ruy/pack_x86.h create mode 100644 ruy/path.h create mode 100644 ruy/platform.h create mode 100644 ruy/pmu.cc create mode 100644 ruy/pmu.h create mode 100644 ruy/prepack.h create mode 100644 ruy/prepacked_cache.cc create mode 100644 ruy/prepacked_cache.h create mode 100644 ruy/prepacked_cache_test.cc create mode 100644 ruy/profiler/BUILD create mode 100644 ruy/profiler/README.md create mode 100644 
ruy/profiler/instrumentation.cc create mode 100644 ruy/profiler/instrumentation.h create mode 100644 ruy/profiler/profiler.cc create mode 100644 ruy/profiler/profiler.h create mode 100644 ruy/profiler/test.cc create mode 100644 ruy/profiler/test_instrumented_library.cc create mode 100644 ruy/profiler/test_instrumented_library.h create mode 100644 ruy/profiler/treeview.cc create mode 100644 ruy/profiler/treeview.h create mode 100644 ruy/ruy.h create mode 100644 ruy/ruy_advanced.h create mode 100644 ruy/ruy_test.bzl create mode 100644 ruy/ruy_test_ext.bzl create mode 100644 ruy/ruy_test_ext.bzl.opensource create mode 100644 ruy/side_pair.h create mode 100644 ruy/size_util.h create mode 100644 ruy/size_util_test.cc create mode 100644 ruy/spec.h create mode 100644 ruy/test.h create mode 100644 ruy/test_fast.cc create mode 100644 ruy/test_slow.cc create mode 100644 ruy/test_special_specs.cc create mode 100644 ruy/thread_pool.cc create mode 100644 ruy/thread_pool.h create mode 100644 ruy/time.h create mode 100644 ruy/trace.cc create mode 100644 ruy/trace.h create mode 100644 ruy/trmul.cc create mode 100644 ruy/trmul.h create mode 100644 ruy/trmul_params.h create mode 100644 ruy/tune.cc create mode 100644 ruy/tune.h create mode 100644 ruy/tune_test.cc create mode 100644 ruy/tune_tool.cc create mode 100644 ruy/wait.cc create mode 100644 ruy/wait.h create mode 100644 ruy/wait_test.cc delete mode 100644 ruy_advanced.h delete mode 100644 ruy_test.bzl delete mode 100644 ruy_test_ext.bzl delete mode 100644 side_pair.h delete mode 100644 size_util.h delete mode 100644 size_util_test.cc delete mode 100644 spec.h delete mode 100644 test.h delete mode 100644 test_fast.cc delete mode 100644 test_slow.cc delete mode 100644 test_special_specs.cc delete mode 100644 thread_pool.cc delete mode 100644 thread_pool.h delete mode 100644 time.h delete mode 100644 trace.cc delete mode 100644 trace.h delete mode 100644 trmul.cc delete mode 100644 trmul.h delete mode 100644 trmul_params.h delete mode 100644 tune.cc delete mode 100644 tune.h delete mode 100644 tune_test.cc delete mode 100644 tune_tool.cc delete mode 100644 wait.cc delete mode 100644 wait.h delete mode 100644 wait_test.cc diff --git a/BUILD b/BUILD deleted file mode 100644 index 9d331b8..0000000 --- a/BUILD +++ /dev/null @@ -1,954 +0,0 @@ -# Ruy is not BLAS - -load(":build_defs.bzl", "ruy_copts_avx2", "ruy_copts_avxvnni", "ruy_copts_base", "ruy_copts_skylake", "ruy_copts_sse42") -load(":ruy_test_ext.bzl", "ruy_test_ext_defines", "ruy_test_ext_deps") -load(":ruy_test.bzl", "ruy_benchmark", "ruy_test") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -config_setting( - name = "windows", - values = {"cpu": "x64_windows"}, -) - -config_setting( - name = "armeabi-v7a", - values = {"cpu": "armeabi-v7a"}, -) - -config_setting( - name = "x86_64", - values = {"cpu": "k8"}, -) - -config_setting( - name = "optimized", - values = { - "compilation_mode": "opt", - }, - visibility = ["//visibility:public"], -) - -cc_library( - name = "platform", - hdrs = ["platform.h"], - copts = ruy_copts_base(), -) - -cc_library( - name = "check_macros", - hdrs = ["check_macros.h"], - copts = ruy_copts_base(), -) - -cc_test( - name = "check_macros_test", - srcs = ["check_macros_test.cc"], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "opt_set", - hdrs = ["opt_set.h"], - copts = ruy_copts_base(), -) - -cc_library( - name = "time", - hdrs = ["time.h"], - 
copts = ruy_copts_base(), -) - -cc_library( - name = "wait", - srcs = ["wait.cc"], - hdrs = ["wait.h"], - copts = ruy_copts_base(), - deps = [":time"], -) - -cc_test( - name = "wait_test", - srcs = ["wait_test.cc"], - deps = [ - ":platform", - ":wait", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "size_util", - hdrs = ["size_util.h"], - copts = ruy_copts_base(), - deps = [":check_macros"], -) - -cc_test( - name = "size_util_test", - srcs = ["size_util_test.cc"], - deps = [ - ":size_util", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "tune", - srcs = [ - "tune.cc", - ], - hdrs = [ - "tune.h", - ], - copts = ruy_copts_base(), - deps = [ - ":opt_set", - ":platform", - ":time", - ], -) - -cc_library( - name = "prepacked_cache", - srcs = [ - "prepacked_cache.cc", - ], - hdrs = [ - "prepacked_cache.h", - ], - copts = ruy_copts_base(), - deps = [ - ":allocator", - ":matrix", - ":opt_set", - ":platform", - ":time", - "//profiler:instrumentation", - ], -) - -cc_test( - name = "tune_test", - srcs = ["tune_test.cc"], - deps = [ - ":tune", - "@com_google_googletest//:gtest", - ], -) - -cc_test( - name = "prepacked_cache_test", - srcs = ["prepacked_cache_test.cc"], - deps = [ - ":prepacked_cache", - ":ruy", - ":time", - "@com_google_googletest//:gtest", - ], -) - -cc_binary( - name = "tune_tool", - srcs = ["tune_tool.cc"], - deps = [ - ":tune", - ], -) - -cc_library( - name = "allocator", - srcs = [ - "allocator.cc", - ], - hdrs = [ - "allocator.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":size_util", - ], -) - -cc_test( - name = "allocator_test", - srcs = ["allocator_test.cc"], - deps = [ - ":allocator", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "side_pair", - hdrs = ["side_pair.h"], - copts = ruy_copts_base(), - deps = [":check_macros"], -) - -cc_library( - name = "block_map", - srcs = [ - "block_map.cc", - ], - hdrs = [ - "block_map.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":opt_set", - ":path", - ":side_pair", - ":size_util", - "//profiler:instrumentation", - ], -) - -cc_test( - name = "block_map_test", - srcs = ["block_map_test.cc"], - deps = [ - ":block_map", - ":cpu_cache_size", - ":path", - ":side_pair", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "blocking_counter", - srcs = [ - "blocking_counter.cc", - ], - hdrs = [ - "blocking_counter.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":wait", - ], -) - -cc_library( - name = "thread_pool", - srcs = [ - "thread_pool.cc", - ], - hdrs = [ - "thread_pool.h", - ], - copts = ruy_copts_base(), - deps = [ - ":blocking_counter", - ":check_macros", - ":wait", - ], -) - -cc_library( - name = "detect_arm", - srcs = [ - "detect_arm.cc", - ], - hdrs = [ - "detect_arm.h", - ], - copts = ruy_copts_base(), -) - -cc_library( - name = "detect_x86", - srcs = [ - "detect_x86.cc", - ], - hdrs = [ - "detect_x86.h", - ], - copts = ruy_copts_base(), - deps = [ - ":platform", - ], -) - -cc_library( - name = "path", - hdrs = ["path.h"], - copts = ruy_copts_base(), - deps = [ - ":platform", - ":size_util", - ], -) - -cc_library( - name = "cpu_cache_size", - hdrs = ["cpu_cache_size.h"], - copts = ruy_copts_base(), - deps = [ - ":path", - ":platform", - ], -) - -cc_library( - name = "trace", - srcs = [ - "trace.cc", - ], - hdrs = [ - "trace.h", - ], - copts = ruy_copts_base(), - deps = [ - ":block_map", - ":check_macros", - ":side_pair", - ":time", - ], -) - -cc_library( - name = "matrix", - hdrs = 
["matrix.h"], - copts = ruy_copts_base(), - deps = [":check_macros"], -) - -cc_library( - name = "spec", - hdrs = ["spec.h"], - copts = ruy_copts_base(), - deps = [ - ":cpu_cache_size", - ":matrix", - ], -) - -cc_library( - name = "internal_matrix", - hdrs = ["internal_matrix.h"], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":matrix", - ":size_util", - ], -) - -cc_library( - name = "common", - hdrs = [ - "common.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":matrix", - ":opt_set", - ":path", - ":platform", - ], -) - -cc_library( - name = "kernel_common", - hdrs = [ - "kernel.h", - "kernel_arm.h", - "kernel_common.h", - "kernel_x86.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":internal_matrix", - ":matrix", - ":opt_set", - ":path", - ":platform", - ":side_pair", - ":size_util", - ":spec", - ":tune", - "//profiler:instrumentation", - ], -) - -cc_library( - name = "pack_common", - hdrs = [ - "pack.h", - "pack_arm.h", - "pack_common.h", - "pack_x86.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":internal_matrix", - ":matrix", - ":opt_set", - ":path", - ":platform", - ":tune", - "//profiler:instrumentation", - ], -) - -cc_library( - name = "kernel_arm", - srcs = [ - "kernel_arm32.cc", - "kernel_arm64.cc", - ], - copts = ruy_copts_base(), - deps = [ - ":common", - ":kernel_common", - ":opt_set", - ":platform", - "//profiler:instrumentation", - ], -) - -cc_library( - name = "pack_arm", - srcs = [ - "pack_arm.cc", - ], - copts = ruy_copts_base(), - deps = [ - ":common", - ":opt_set", - ":pack_common", - ":platform", - "//profiler:instrumentation", - ], -) - -# AVX-512 compilation units. -# -# These must use the same compiler options. -RUY_COPTS_BUILT_FOR_AVX512 = ruy_copts_base() + ruy_copts_skylake() - -cc_library( - name = "kernel_avx512", - srcs = [ - "kernel_avx512.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX512, - deps = [ - ":check_macros", - ":kernel_common", - ":opt_set", - ":platform", - "//profiler:instrumentation", - ], -) - -cc_library( - name = "pack_avx512", - srcs = [ - "pack_avx512.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX512, - deps = [ - ":check_macros", - ":matrix", - ":opt_set", - ":pack_common", - ":path", - ":platform", - "//profiler:instrumentation", - ], -) - -cc_library( - name = "have_built_path_for_avx512", - srcs = [ - "have_built_path_for_avx512.cc", - ], - hdrs = [ - "have_built_path_for.h", - ], - copts = RUY_COPTS_BUILT_FOR_AVX512, - deps = [ - ":opt_set", - ":platform", - ], -) -# End: AVX-512 compilation units. - -# AVX2 compilation units. -# -# These must use the same compiler options. -RUY_COPTS_BUILT_FOR_AVX2 = ruy_copts_base() + ruy_copts_avx2() - -cc_library( - name = "kernel_avx2", - srcs = [ - "kernel_avx2.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX2, - deps = [ - ":check_macros", - ":kernel_common", - ":opt_set", - ":platform", - "//profiler:instrumentation", - ], -) - -cc_library( - name = "pack_avx2", - srcs = [ - "pack_avx2.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX2, - deps = [ - ":check_macros", - ":matrix", - ":opt_set", - ":pack_common", - ":path", - ":platform", - "//profiler:instrumentation", - ], -) - -cc_library( - name = "have_built_path_for_avx2", - srcs = [ - "have_built_path_for_avx2.cc", - ], - hdrs = [ - "have_built_path_for.h", - ], - copts = RUY_COPTS_BUILT_FOR_AVX2, - deps = [ - ":opt_set", - ":platform", - ], -) -# End: AVX2 compilation units. - -# SSE42 compilation units. 
-# -# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -# Optimization is not finished. In particular the dimensions of the kernel -# blocks can be changed as desired. -# -# These must use the same compiler options. -RUY_COPTS_BUILT_FOR_SSE42 = ruy_copts_base() + ruy_copts_sse42() - -cc_library( - name = "kernel_sse42", - srcs = [ - "kernel_sse42.cc", - ], - copts = RUY_COPTS_BUILT_FOR_SSE42, - deps = [ - ":check_macros", - ":kernel_common", - ":opt_set", - ":platform", - "//profiler:instrumentation", - ], -) - -cc_library( - name = "pack_sse42", - srcs = [ - "pack_sse42.cc", - ], - copts = RUY_COPTS_BUILT_FOR_SSE42, - deps = [ - ":check_macros", - ":matrix", - ":opt_set", - ":pack_common", - ":path", - ":platform", - "//profiler:instrumentation", - ], -) - -cc_library( - name = "have_built_path_for_sse42", - srcs = [ - "have_built_path_for_sse42.cc", - ], - hdrs = [ - "have_built_path_for.h", - ], - copts = RUY_COPTS_BUILT_FOR_SSE42, - deps = [ - ":opt_set", - ":platform", - ], -) -# End: SSE42 compilation units. - -# AVX-VNNI compilation units. -# -# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -# Optimization is not finished. In particular the dimensions of the kernel -# blocks can be changed as desired. -# -# These must use the same compiler options. -RUY_COPTS_BUILT_FOR_AVX_VNNI = ruy_copts_base() + ruy_copts_avxvnni() - -cc_library( - name = "kernel_avxvnni", - srcs = [ - "kernel_avxvnni.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX_VNNI, - deps = [ - ":check_macros", - ":kernel_common", - ":opt_set", - ":platform", - "//profiler:instrumentation", - ], -) - -cc_library( - name = "pack_avxvnni", - srcs = [ - "pack_avxvnni.cc", - ], - copts = RUY_COPTS_BUILT_FOR_AVX_VNNI, - deps = [ - ":check_macros", - ":matrix", - ":opt_set", - ":pack_common", - ":path", - ":platform", - "//profiler:instrumentation", - ], -) - -cc_library( - name = "have_built_path_for_avxvnni", - srcs = [ - "have_built_path_for_avxvnni.cc", - ], - hdrs = [ - "have_built_path_for.h", - ], - copts = RUY_COPTS_BUILT_FOR_AVX_VNNI, - deps = [ - ":opt_set", - ":platform", - ], -) -# End: AVX-VNNI compilation units. 
- -cc_library( - name = "kernel", - hdrs = [ - "kernel.h", - "kernel_common.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":internal_matrix", - ":kernel_arm", # fixdeps: keep - ":kernel_avx2", # fixdeps: keep - ":kernel_avx512", # fixdeps: keep - ":kernel_avxvnni", # fixdeps: keep - ":kernel_common", - ":kernel_sse42", # fixdeps: keep - ":matrix", - ":opt_set", - ":path", - ":platform", - ":side_pair", - ":size_util", - ":spec", - ":tune", - "//profiler:instrumentation", - ], -) - -cc_library( - name = "pack", - hdrs = [ - "pack.h", - "pack_common.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":internal_matrix", - ":matrix", - ":opt_set", - ":pack_arm", # fixdeps: keep - ":pack_avx2", # fixdeps: keep - ":pack_avx512", # fixdeps: keep - ":pack_avxvnni", # fixdeps: keep - ":pack_common", - ":pack_sse42", # fixdeps: keep - ":path", - ":platform", - ":tune", - "//profiler:instrumentation", - ], -) - -cc_library( - name = "have_built_path_for", - hdrs = [ - "have_built_path_for.h", - ], - deps = [ - ":have_built_path_for_avx2", - ":have_built_path_for_avx512", - ":have_built_path_for_avxvnni", - ":have_built_path_for_sse42", - ":platform", - ], -) - -cc_library( - name = "context", - srcs = [ - "context.cc", - ], - hdrs = [ - "context.h", - ], - copts = ruy_copts_base(), - deps = [ - ":allocator", - ":check_macros", - ":detect_arm", - ":detect_x86", - ":have_built_path_for", - ":path", - ":platform", - ":prepacked_cache", - ":thread_pool", - ":trace", - ":tune", - ], -) - -cc_test( - name = "context_test", - srcs = ["context_test.cc"], - deps = [ - ":context", - ":path", - ":platform", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "trmul_params", - hdrs = ["trmul_params.h"], - copts = ruy_copts_base(), - deps = [ - ":internal_matrix", - ":side_pair", - ":tune", - ], -) - -cc_library( - name = "trmul", - srcs = ["trmul.cc"], - hdrs = ["trmul.h"], - copts = ruy_copts_base(), - deps = [ - ":allocator", - ":block_map", - ":check_macros", - ":common", - ":context", - ":internal_matrix", - ":matrix", - ":opt_set", - ":side_pair", - ":size_util", - ":spec", - ":thread_pool", - ":trace", - ":trmul_params", - ":tune", - "//profiler:instrumentation", - ], -) - -# The main library. -cc_library( - name = "ruy", - srcs = [ - "dispatch.h", - "prepack.h", - ], - hdrs = [ - "ruy.h", - "ruy_advanced.h", - ], - copts = ruy_copts_base(), - deps = [ - ":check_macros", - ":common", - ":context", - ":internal_matrix", - ":kernel", - ":matrix", - ":opt_set", - ":pack", - ":path", - ":prepacked_cache", - ":side_pair", - ":size_util", - ":spec", - ":trmul", - ":trmul_params", - ":tune", - "//profiler:instrumentation", - ], -) - -# Usage examples. -cc_binary( - name = "example", - srcs = ["example.cc"], - deps = [":ruy"], -) - -# Usage examples of the advanced API. -cc_binary( - name = "example_advanced", - srcs = ["example_advanced.cc"], - deps = [":ruy"], -) - -# Small library to query PMU counters, for benchmark only -cc_library( - name = "pmu", - testonly = True, - srcs = ["pmu.cc"], - hdrs = ["pmu.h"], - copts = ruy_copts_base(), - deps = [":check_macros"], -) - -# Testing framework. 
-cc_library( - name = "test_lib", - testonly = True, - hdrs = ["test.h"], - copts = ruy_copts_base(), - # need defines, not copts, because it's controlling a header, test.h - defines = ruy_test_ext_defines(), - linkopts = select({ - ":windows": [], - "//conditions:default": ["-lm"], - }), - deps = [ - ":matrix", - ":pmu", - ":ruy", - ":spec", - ":time", - "@com_google_googletest//:gtest", - ":platform", - "//profiler:profiler", - ] + ruy_test_ext_deps(), -) - -ruy_benchmark( - name = "benchmark", - srcs = ["benchmark.cc"], - copts = ruy_copts_base(), - lhs_rhs_accum_dst = [ - ("f32", "f32", "f32", "f32"), - ("u8", "u8", "i32", "u8"), - ("i8", "i8", "i32", "u8"), - ("i8", "i8", "i32", "i8"), - ("u8", "u8", "i32", "i16"), - ("i8", "i8", "i32", "i32"), - ], - deps = [ - ":test_lib", - "//profiler:instrumentation", - ], -) - -ruy_test( - name = "test_fast", - srcs = ["test_fast.cc"], - copts = ruy_copts_base(), - lhs_rhs_accum_dst = [ - ("f32", "f32", "f32", "f32"), - ("f64", "f32", "f64", "f32"), - ("f32", "f64", "f64", "f64"), - ("u8", "u8", "i32", "u8"), - ("i8", "i8", "i32", "i8"), - ("i8", "u8", "i32", "i8"), - ("u8", "u8", "i32", "i16"), - ("i8", "i8", "i32", "i32"), - ("i8", "u8", "i32", "i32"), - ], - deps = [ - "@com_google_googletest//:gtest_main", - ":test_lib", - ], -) - -ruy_test( - name = "test_slow", - srcs = ["test_slow.cc"], - copts = ruy_copts_base(), - lhs_rhs_accum_dst = [ - ("f32", "f32", "f32", "f32"), - ("u8", "u8", "i32", "u8"), - ("i8", "i8", "i32", "i8"), - ("u8", "u8", "i32", "i16"), - ("i8", "i8", "i32", "i32"), - ], - tags = ["slow"], - deps = [ - "@com_google_googletest//:gtest_main", - ":test_lib", - ], -) - -ruy_test( - name = "test_special_specs", - srcs = ["test_special_specs.cc"], - copts = ruy_copts_base(), - lhs_rhs_accum_dst = [ - ("f32", "f32", "f32", "f32"), - ("u8", "u8", "i32", "u8"), - ("u8", "u8", "i32", "i16"), - ], - deps = [ - "@com_google_googletest//:gtest_main", - ":test_lib", - ], -) diff --git a/WORKSPACE b/WORKSPACE new file mode 100644 index 0000000..8364d80 --- /dev/null +++ b/WORKSPACE @@ -0,0 +1,17 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Workspace file for the Ruy project. + +workspace(name = "com_google_ruy") diff --git a/allocator.cc b/allocator.cc deleted file mode 100644 index a2e596a..0000000 --- a/allocator.cc +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "allocator.h" - -#include -#include - -#ifdef _WIN32 -#include -#endif - -namespace ruy { - -namespace detail { - -void *SystemAlignedAlloc(std::ptrdiff_t num_bytes) { -#ifdef _WIN32 - return _aligned_malloc(num_bytes, kMinimumBlockAlignment); -#else - void *ptr; - if (posix_memalign(&ptr, kMinimumBlockAlignment, num_bytes)) { - return nullptr; - } - return ptr; -#endif -} - -void SystemAlignedFree(void *ptr) { -#ifdef _WIN32 - _aligned_free(ptr); -#else - free(ptr); -#endif -} - -} // namespace detail - -} // namespace ruy diff --git a/allocator.h b/allocator.h deleted file mode 100644 index e2d31e4..0000000 --- a/allocator.h +++ /dev/null @@ -1,185 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_ALLOCATOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_ALLOCATOR_H_ - -#include -#include -#include -#include - -#include "check_macros.h" -#include "size_util.h" - -namespace ruy { - -namespace detail { - -inline void* VoidPtrAdd(void* p, std::ptrdiff_t offset) { - RUY_DCHECK(p); - std::uintptr_t addr = reinterpret_cast(p) + offset; - return reinterpret_cast(addr); -} - -// Minimum alignment for blocks. -// -// Considerations: -// - This needs to be at least the alignment of any usual data type. -// - It's useful that this is at least the size of a cache line to limit -// possible cache side effects (if only on performance behavior). -// - It's useful that this is at least the size of SIMD registers, as -// some SIMD instruction sets have at least performance behavior -// differences (e.g. NEON) or even different requirements (e.g. SSE) -// based on that. -// - It's useful that this is at least the size of an "exclusive reservation -// granule" on ARM, meaning that if we use this Allocator to allocate -// an atomic variable, there will be no side effects from other things -// contending for exclusive/atomic memory accesses to it. While the -// ARM reference manual mentions that this granule size may be as large -// as 2048 bytes, in practice we observe it to be 64 bytes. It can -// be queried cheaply, at runtime, from userspace, if needed. -static constexpr std::ptrdiff_t kMinimumBlockAlignment = 64; - -// Primitive allocation functions obtaining aligned memory from the -// operating system. -void* SystemAlignedAlloc(std::ptrdiff_t num_bytes); -void SystemAlignedFree(void* ptr); - -// Specialized allocator designed to converge to a steady-state where all -// allocations are bump-ptr allocations from an already-allocated buffer. -// -// To support these constraints, this allocator only supports two -// operations. -// - AllocateAlignedBytes: allocates a pointer to storage of a specified -// size, which must be aligned to kMinimumBlockAlignment. 
-// - FreeAll: frees all previous allocations (but retains the internal -// buffer to minimize future calls into the system allocator). -// -// This class is specialized for supporting just those two operations -// under this specific steady-state usage pattern. Extending this class -// with new allocation interfaces that don't fit that pattern is probably not -// the right choice. Instead, build a new class on top of -// SystemAlignedAlloc/SystemAlignedFree. -// -// All operations happen on aligned blocks for simplicity. -class AlignedAllocator { - public: - void operator=(const AlignedAllocator&) = delete; - ~AlignedAllocator() { - FreeAll(); - SystemAlignedFree(ptr_); - } - - void* AllocateAlignedBytes(std::ptrdiff_t num_bytes) { - RUY_DCHECK_GT(num_bytes, 0); - RUY_DCHECK((num_bytes & (kMinimumBlockAlignment - 1)) == 0); - if (void* p = AllocateFast(num_bytes)) { - return p; - } - return AllocateSlow(num_bytes); - } - - void FreeAll() { - current_ = 0; - if (fallback_blocks_.empty()) { - return; - } - - // No rounding-up of the size means linear instead of logarithmic - // bound on the number of allocation in some worst-case calling patterns. - // This is considered worth it because minimizing memory usage is important - // and actual calling patterns in applications that we care about still - // reach the no-further-allocations steady state in a small finite number - // of iterations. - std::ptrdiff_t new_size = size_ + fallback_blocks_total_size_; - SystemAlignedFree(ptr_); - ptr_ = SystemAlignedAlloc(new_size); - size_ = new_size; - - for (void* p : fallback_blocks_) { - SystemAlignedFree(p); - } - fallback_blocks_.clear(); - fallback_blocks_total_size_ = 0; - } - - private: - void* AllocateFast(std::ptrdiff_t num_bytes) { - if (current_ + num_bytes > size_) { - return nullptr; - } - void* ret = VoidPtrAdd(ptr_, current_); - current_ += num_bytes; - return ret; - } - - void* AllocateSlow(std::ptrdiff_t num_bytes) { - void* p = SystemAlignedAlloc(num_bytes); - fallback_blocks_total_size_ += num_bytes; - fallback_blocks_.push_back(p); - return p; - } - - // Theory of operation: - // - // - ptr_, current_, and size_ implement a basic bump-ptr allocator. - // - // - in AllocateAlignedBytes, the fast path is just a bump-ptr - // allocation. If our bump-ptr allocator doesn't have enough space for an - // allocation, then we allocate a block from the system allocator to - // service the allocation request. We save that block in fallback_blocks_ - // and track the total size of the fallback blocks in - // fallback_blocks_total_size_. - // - // - in FreeAll, the fast path just resets the bump-ptr allocator. If - // there are any fallback blocks, we free them and reallocate the - // bump-ptr allocator's buffer so that the next sequence of allocations - // will hopefully not need any fallback blocks. - void* ptr_ = nullptr; - std::ptrdiff_t current_ = 0; - std::ptrdiff_t size_ = 0; - std::vector fallback_blocks_; - std::ptrdiff_t fallback_blocks_total_size_ = 0; -}; - -} // namespace detail - -// The main Allocator class, with a convenient interface for allocating a -// typed buffer. 
-class Allocator { - public: - void* AllocateBytes(std::ptrdiff_t num_bytes) { - if (num_bytes == 0) { - return nullptr; - } - return aligned.AllocateAlignedBytes( - round_up_pot(num_bytes, detail::kMinimumBlockAlignment)); - } - template - void Allocate(std::ptrdiff_t count, Pointer* out) { - using T = typename std::pointer_traits::element_type; - *out = static_cast(AllocateBytes(count * sizeof(T))); - } - - void FreeAll() { aligned.FreeAll(); } - - private: - detail::AlignedAllocator aligned; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_ALLOCATOR_H_ diff --git a/allocator_test.cc b/allocator_test.cc deleted file mode 100644 index 51707bd..0000000 --- a/allocator_test.cc +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "allocator.h" - -#include "testing/base/public/gunit.h" - -namespace ruy { -namespace { - -TEST(AllocatorTest, ReturnsValidMemory) { - Allocator allocator; - int *p; - allocator.Allocate(1, &p); - ASSERT_NE(p, nullptr); - - // If this is bogus memory, ASan will cause this test to fail. - *p = 42; - - allocator.FreeAll(); -} - -TEST(AllocatorTest, NoLeak) { - Allocator allocator; - // Allocate and free some ridiculously large total amount of memory, so - // that a leak will hopefully cause some sort of resource exhaustion. - // - // Despite the large number of allocations, this test is actually quite - // fast, since our fast-path allocation logic is very fast. - constexpr int kNumAllocations = 100 * 1024; - constexpr int kAllocationSize = 1024 * 1024; - for (int i = 0; i < kNumAllocations; i++) { - char *p; - allocator.Allocate(kAllocationSize, &p); - allocator.FreeAll(); - } -} - -TEST(AllocatorTest, IncreasingSizes) { - Allocator allocator; - // Allocate sizes that increase by small amounts across FreeAll calls. - for (int i = 1; i < 100 * 1024; i++) { - char *p; - allocator.Allocate(i, &p); - allocator.FreeAll(); - } -} - -TEST(AllocatorTest, ManySmallAllocations) { - Allocator allocator; - // Allocate many small allocations between FreeAll calls. - for (int i = 0; i < 10 * 1024; i += 100) { - for (int j = 0; j < i; j++) { - char *p; - allocator.Allocate(1, &p); - } - allocator.FreeAll(); - } -} - -TEST(AllocatorTest, DestructorHandlesMainBumpPtr) { - // This is a white-box test. - Allocator allocator; - allocator.AllocateBytes(1); - allocator.FreeAll(); - // After the call to FreeAll, the allocator will consolidate all of the memory - // into the main bump-ptr allocator's block, which we then expect to be freed - // in the destructor. - // - // We have no test assertions -- we primarily expect that this trigger a leak - // checker and cause the test to fail. -} - -TEST(AllocatorTest, DestructorHandlesFallbackBlocks) { - // This is a white-box test. 
- Allocator allocator; - // Since we just created the allocator, this will allocate a fallback block, - // which we then expect to be freed in the destructor. - // - // We have no test assertions -- we primarily expect that this trigger a leak - // checker and cause the test to fail. - allocator.AllocateBytes(1); -} - -} // namespace -} // namespace ruy - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/benchmark.cc b/benchmark.cc deleted file mode 100644 index ece71e1..0000000 --- a/benchmark.cc +++ /dev/null @@ -1,196 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "test.h" - -namespace ruy { - -using LhsScalar = RUY_TEST_LHSSCALAR; -using RhsScalar = RUY_TEST_RHSSCALAR; -using AccumScalar = RUY_TEST_ACCUMSCALAR; -using DstScalar = RUY_TEST_DSTSCALAR; -using TestSetType = - TestSet>; - -struct BenchmarkShape { - int rows; - int depth; - int cols; - int symm_lhs; - int symm_rhs; -}; - -template -std::vector>> BenchmarkRCC( - const BenchmarkShape& shape) { - TestSetType test_set; - test_set.rows = shape.rows; - test_set.depth = shape.depth; - test_set.cols = shape.cols; - test_set.lhs_order = Order::kRowMajor; - test_set.rhs_order = Order::kColMajor; - test_set.dst_order = Order::kColMajor; - test_set.layout_style = LayoutStyle::kPackedLinear; - test_set.benchmark = true; - const int asymmetry_lhs = shape.symm_lhs ? 0 : 1; - const int asymmetry_rhs = shape.symm_rhs ? 
0 : 1; - test_set.lhs_zero_point = SymmetricZeroPoint() + asymmetry_lhs; - test_set.rhs_zero_point = SymmetricZeroPoint() + asymmetry_rhs; - test_set.use_specified_zero_points = true; - test_set.perchannel = GetBoolEnvVarOrFalse("PERCHANNEL"); - test_set.benchmark_prepack_lhs = GetBoolEnvVarOrFalse("PREPACK_LHS"); - test_set.benchmark_prepack_rhs = GetBoolEnvVarOrFalse("PREPACK_RHS"); - test_set.Run(); - return std::move(test_set.results); -} - -std::vector ParseCommaSeparatedInts( - const std::string& comma_separated_ints) { - std::vector result; - for (std::size_t pos = 0; pos < comma_separated_ints.size();) { - std::size_t delim_pos = comma_separated_ints.find(',', pos); - if (delim_pos == std::string::npos) { - delim_pos = comma_separated_ints.size(); - } - result.push_back( - std::stoi(comma_separated_ints.substr(pos, delim_pos - pos))); - pos = delim_pos + 1; - } - return result; -} - -void Benchmark() { - const bool symm_lhs = std::is_floating_point::value || - GetBoolEnvVarOrFalse("SYMM_LHS"); - const bool symm_rhs = std::is_floating_point::value || - GetBoolEnvVarOrFalse("SYMM_RHS"); - const bool benchmark_cubic = GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC") || - GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC_LIST"); - const int explicit_rows = GetIntEnvVarOrZero("ROWS"); - const int explicit_cols = GetIntEnvVarOrZero("COLS"); - const int explicit_depth = GetIntEnvVarOrZero("DEPTH"); - - std::vector shapes; - - if (benchmark_cubic) { - std::vector sizes; - const char* benchmark_cubic_list_env = getenv("RUY_BENCHMARK_CUBIC_LIST"); - if (benchmark_cubic_list_env) { - sizes = ParseCommaSeparatedInts(benchmark_cubic_list_env); - } else { - // Often 8 is used for this multiplier, but to check teeny sizes one can - // use 1. - static constexpr int cubic_size_multiplier = 8; - for (int i = 2 * cubic_size_multiplier; - i <= (512 * cubic_size_multiplier); i *= 2) { - sizes.push_back(i); - if (i < (512 * cubic_size_multiplier)) { - sizes.push_back(i * 3 / 2); - } - } - } - for (int i : sizes) { - BenchmarkShape shape; - // Even in cubic mode, one may still override an individual dimension - // to allow testing a batch of rectangular sizes. - shape.rows = explicit_rows ? explicit_rows : i; - shape.cols = explicit_cols ? explicit_cols : i; - shape.depth = explicit_depth ? 
explicit_depth : i; - shape.symm_lhs = symm_lhs; - shape.symm_rhs = symm_rhs; - shapes.push_back(shape); - } - } else { - BenchmarkShape shape; - shape.rows = explicit_rows; - shape.cols = explicit_cols; - shape.depth = explicit_depth; - if (!shape.rows || !shape.depth || !shape.cols) { - fprintf(stderr, - "Please specify positive sizes with these env vars: ROWS, DEPTH, " - "COLS.\n"); - exit(1); - } - shape.symm_lhs = symm_lhs; - shape.symm_rhs = symm_rhs; - shapes.push_back(shape); - } - - for (int i = 0; i < shapes.size(); i++) { - const auto& shape = shapes[i]; - const auto& results = BenchmarkRCC(shape); - if (i == 0) { - if (benchmark_cubic) { - printf("size"); - for (const auto& result : results) { - if (results.size() > 1) { - printf(",%s:Gop/s", PathName(*result).c_str()); - } else { - printf(",Gop/s"); - } - if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { - printf( - ",l1_refill,l2_refill,l3_refill,l1tlb_refill,l2tlb_refill," - "mispred,frontend_stall,backend_stall"); - } - } - printf("\n"); - } else { - printf("path,shape,Gop/s\n"); - } - fflush(stdout); - } - if (benchmark_cubic) { - printf("%d", shape.rows); - for (const auto& result : results) { - printf(",%.4g", 2.0e-9 * shape.rows * shape.cols * shape.depth / - result->latency); - if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { - printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", - result->l1_refill_rate, result->l2_refill_rate, - result->l3_refill_rate, result->l1tlb_refill_rate, - result->l2tlb_refill_rate, result->mispred_rate, - result->frontend_stall_rate, result->backend_stall_rate); - } - } - printf("\n"); - fflush(stdout); - } else { - for (const auto& result : results) { - printf( - "%s,%dx%dx%d,%.4g", PathName(*result).c_str(), shape.rows, - shape.depth, shape.cols, - 2.0e-9 * shape.rows * shape.cols * shape.depth / result->latency); - if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { - printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", - result->l1_refill_rate, result->l2_refill_rate, - result->l3_refill_rate, result->l1tlb_refill_rate, - result->l2tlb_refill_rate, result->mispred_rate, - result->frontend_stall_rate, result->backend_stall_rate); - } - printf("\n"); - } - fflush(stdout); - } - } -} - -} // namespace ruy - -int main() { ruy::Benchmark(); } diff --git a/block_map.cc b/block_map.cc deleted file mode 100644 index 04ef5b2..0000000 --- a/block_map.cc +++ /dev/null @@ -1,486 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "block_map.h" - -#include -#include - -#ifdef RUY_MAKEBLOCKMAP_DEBUG -#include -#include -#include -#endif - -#include "check_macros.h" -#include "opt_set.h" -#include "profiler/instrumentation.h" -#include "size_util.h" - -namespace ruy { - -namespace { - -void DecodeTraversalLinear(int size_log2, std::uint32_t square_index, - SidePair* local_pos) { - (*local_pos)[Side::kLhs] = square_index & ((1 << size_log2) - 1); - (*local_pos)[Side::kRhs] = square_index >> size_log2; -} - -void DecodeTraversalFractalZ(std::uint32_t square_index, - SidePair* local_pos) { - const std::uint32_t n1 = square_index; - const std::uint32_t n2 = (n1 & 0x99999999u) | ((n1 & 0x44444444u) >> 1) | - ((n1 & 0x22222222u) << 1); - const std::uint32_t n4 = (n2 & 0xc3c3c3c3u) | ((n2 & 0x30303030u) >> 2) | - ((n2 & 0x0c0c0c0cu) << 2); - const std::uint32_t n8 = (n4 & 0xf00ff00fu) | ((n4 & 0x0f000f00u) >> 4) | - ((n4 & 0x00f000f0u) << 4); - const std::uint32_t n16 = (n8 & 0xff0000ffu) | ((n8 & 0x00ff0000u) >> 8) | - ((n8 & 0x0000ff00u) << 8); - (*local_pos)[Side::kLhs] = n16 & 0xffff; - (*local_pos)[Side::kRhs] = n16 >> 16; -} - -void DecodeTraversalFractalU(std::uint32_t square_index, - SidePair* local_pos) { - DecodeTraversalFractalZ(square_index, local_pos); - // Change fractal z-order to u-order - (*local_pos)[Side::kLhs] ^= (*local_pos)[Side::kRhs]; -} - -// Code inspired by the sample code in -// https://en.wikipedia.org/wiki/Hilbert_curve -// The main optimization is to avoid hard-to-predict conditional branches -// based on the bits of the square_index parameter. -void DecodeTraversalFractalHilbert(int size_log2, std::uint32_t square_index, - SidePair* local_pos) { - std::uint32_t t = square_index; - std::uint32_t x = 0; - std::uint32_t y = 0; - // Easy-to-predict for loop, the number of iterations is the same for - // an entire GEMM. - for (int sb = 0; sb < size_log2; sb++) { - std::uint32_t s = 1 << sb; - bool rx = t & 2; - bool ry = (t & 1) ^ rx; - std::uint32_t tmp = rx ? (s - 1 - x) : x; - x = ry ? x : rx ? (s - 1 - y) : y; - y = ry ? (y + s) : tmp; - x = rx ? 
(x + s) : x; - t >>= 2; - } - (*local_pos)[Side::kLhs] = y; - (*local_pos)[Side::kRhs] = x; -} - -} // end anonymous namespace - -void GetBlockByIndex(const BlockMap& block_map, int index, - SidePair* block) { - profiler::ScopeLabel label("GetBlockByIndex"); - const std::uint32_t index_u32 = index; - - const std::uint32_t num_blocks_per_local_curve = - 1u << (2 * block_map.num_blocks_base_log2); - const std::uint32_t square_index = - index_u32 & (num_blocks_per_local_curve - 1); - - const int size_log2 = block_map.num_blocks_base_log2; - SidePair local_pos; - switch (block_map.traversal_order) { - case BlockMapTraversalOrder::kFractalZ: - DecodeTraversalFractalZ(square_index, &local_pos); - break; - case BlockMapTraversalOrder::kFractalU: - DecodeTraversalFractalU(square_index, &local_pos); - break; - case BlockMapTraversalOrder::kFractalHilbert: - DecodeTraversalFractalHilbert(size_log2, square_index, &local_pos); - break; - default: - RUY_DCHECK(block_map.traversal_order == BlockMapTraversalOrder::kLinear); - DecodeTraversalLinear(size_log2, square_index, &local_pos); - break; - } - - const std::uint32_t rectangular_index = - index_u32 >> 2 * block_map.num_blocks_base_log2; - for (Side side : {Side::kLhs, Side::kRhs}) { - const std::uint32_t mask = (1u << block_map.rectangularness_log2[side]) - 1; - const int rectangular_offset = (rectangular_index & mask) - << block_map.num_blocks_base_log2; - (*block)[side] = local_pos[side] + rectangular_offset; - } -} - -BlockMapTraversalOrder GetTraversalOrder(int rows, int cols, int depth, - int lhs_scalar_size, - int rhs_scalar_size, - int local_data_cache_size, - int shared_data_cache_size) { - const int kFractalOptSets = - RUY_OPT_FRACTAL_Z | RUY_OPT_FRACTAL_U | RUY_OPT_FRACTAL_HILBERT; - const int working_set_size = - (lhs_scalar_size * rows + rhs_scalar_size * cols) * depth; - if (RUY_OPT_ENABLED(kFractalOptSets) && - (working_set_size > local_data_cache_size)) { - if (RUY_OPT_ENABLED(RUY_OPT_FRACTAL_HILBERT) && - (working_set_size > shared_data_cache_size)) { - return BlockMapTraversalOrder::kFractalHilbert; - } else if (RUY_OPT_ENABLED(RUY_OPT_FRACTAL_U)) { - return BlockMapTraversalOrder::kFractalU; - } else { - return BlockMapTraversalOrder::kFractalZ; - } - } else { - return BlockMapTraversalOrder::kLinear; - } -} - -namespace { - -int floor_log2_quotient(int num, int denom) { - if (num <= denom) { - return 0; - } - int log2_quotient = floor_log2(num) - ceil_log2(denom); - if ((denom << (log2_quotient + 1)) <= num) { - log2_quotient++; - } - return log2_quotient; -} - -// Computes the rectangularness of the matrix shape (rows, cols). This is -// essentially just the log2 of the quotient (rows / cols). The kernel_rows and -// kernel_cols only get into the picture for clamping bounds but don't affect -// the generic computation. -void GetRectangularness(int rows, int cols, int kernel_rows, int kernel_cols, - int* rows_rectangularness_log2, - int* cols_rectangularness_log2) { - *rows_rectangularness_log2 = 0; - *cols_rectangularness_log2 = 0; - - // In GEMV-ish cases, that is when kernel blocks are as narrow as the kernel - // itself, we risk having too small kernel blocks for good kernel - // amortization. We avoid that by limiting recangularness so that kernel - // blocks are not too tiny at least in that dimension. Specifically, we try to - // have at least (2^min_kernel_inner_loop_runs_log2) kernels fitting in each - // kernel block along the large dimension. 
- const int min_kernel_inner_loop_runs_log2 = 3; - if (rows > cols) { - int cols_of_kernel_inner_loop_runs_log2 = - ceil_log2(cols) - pot_log2(kernel_cols); - int min_rows_of_kernel_inner_loop_runs_log2 = - std::max(0, min_kernel_inner_loop_runs_log2 - - cols_of_kernel_inner_loop_runs_log2); - *rows_rectangularness_log2 = - std::min(floor_log2_quotient(rows, cols), - std::max(0, floor_log2(rows) - pot_log2(kernel_rows) - - min_rows_of_kernel_inner_loop_runs_log2)); - // Sanity check that we did not over-estimate rows_rectangularness_log2. - RUY_DCHECK_GE(rows >> *rows_rectangularness_log2, cols); - } else if (cols > rows) { - int rows_of_kernel_inner_loop_runs_log2 = - ceil_log2(rows) - pot_log2(kernel_rows); - int min_cols_of_kernel_inner_loop_runs_log2 = - std::max(0, min_kernel_inner_loop_runs_log2 - - rows_of_kernel_inner_loop_runs_log2); - *cols_rectangularness_log2 = - std::min(floor_log2_quotient(cols, rows), - std::max(0, floor_log2(cols) - pot_log2(kernel_cols) - - min_cols_of_kernel_inner_loop_runs_log2)); - // Sanity check that we did not over-estimate cols_rectangularness_log2. - RUY_DCHECK_GE(cols >> *cols_rectangularness_log2, rows); - } - RUY_DCHECK(!*rows_rectangularness_log2 || !*cols_rectangularness_log2); -} - -// Computes a 'multithreading score'. When multithreading, we need there to -// be at least as many tiles as there are threads, and hopefully -// substantially more than that, so we benefit from ruy's ability to -// dispatch fine-grained workloads to threads. -int GetMultithreadingScore(int block_size_log2, int rows, int cols, - int tentative_thread_count) { - const int num_full_blocks_of_rows = rows >> block_size_log2; - const int num_full_blocks_of_cols = cols >> block_size_log2; - const int candidate_num_full_blocks_log2 = floor_log2( - std::max(1, num_full_blocks_of_rows * num_full_blocks_of_cols)); - - // The values here have been tuned on ARM Cortex-A55. - // We expect this to have to be tuned differently for other CPUs. - if (tentative_thread_count == 1) { - return 0; - } else { - const int blocks_per_thread_log2 = - candidate_num_full_blocks_log2 - ceil_log2(tentative_thread_count); - if (blocks_per_thread_log2 < 0) { - return -64; - } else if (blocks_per_thread_log2 == 0) { - return -16; - } else if (blocks_per_thread_log2 == 1) { - return -8; - } else if (blocks_per_thread_log2 == 2) { - return 0; - } else if (blocks_per_thread_log2 == 3) { - return 8; - } else { - return 16; - } - } -} - -// Computes a 'cache locality score'. -int GetCacheLocalityScore(int block_size_log2, int rows, int cols, int depth, - int kernel_rows_log2, int kernel_cols_log2, - int lhs_scalar_size, int rhs_scalar_size, Path path, - int local_data_cache_size) { - // In the narrow case (e.g. matrix*vector), each byte of the big operand - // matrix (either LHS or RHS) is traversed only once, so any notion of data - // locality is irrelevant. Ignore the 'cache locality score' by forcing it to - // be 0 in that case. - if (rows <= (1 << kernel_rows_log2) || cols <= (1 << kernel_cols_log2)) { - return 0; - } - const int block_rows = std::min(1 << block_size_log2, rows); - const int block_cols = std::min(1 << block_size_log2, cols); - const int total_read_bytes = - (lhs_scalar_size * block_rows + rhs_scalar_size * block_cols) * depth; - const int total_read_bytes_log2 = ceil_log2(total_read_bytes); - const int nonlocality_log2 = - total_read_bytes_log2 - floor_log2(local_data_cache_size); - // The values here have been tuned on ARM Cortex-A55. 
- // We expect this to have to be tuned differently for other CPUs. - if (nonlocality_log2 < -1) { - return 64; - } else if (nonlocality_log2 == -1) { - return 56; - } else if (nonlocality_log2 == 0) { - return 48; - } else if (nonlocality_log2 == 1) { - return 32; - } else if (nonlocality_log2 == 2) { - return 16; - } else if (nonlocality_log2 == 3) { - return 0; - } else { - return -64; - } -} - -// Compute a 'kernel amortization score'. This is the notion that very small -// tiles result in more overhead outside of kernels, more complex memory -// access patterns and less benefits from ruy's fat kernels, so we reward -// larger blocks more than smaller ones. -int GetKernelAmortizationScore(int block_size_log2, int rows, int cols, - int kernel_rows_log2, int kernel_cols_log2) { - const int block_rows = std::min(1 << block_size_log2, rows); - const int block_cols = std::min(1 << block_size_log2, cols); - const int kernels_per_block_log2 = - floor_log2(block_rows * block_cols) - kernel_rows_log2 - kernel_cols_log2; - RUY_DCHECK_GE(kernels_per_block_log2, 0); - // The values here have been tuned on ARM Cortex-A55. - // We expect this to have to be tuned differently for other CPUs. - if (kernels_per_block_log2 == 0) { - return 0; - } else if (kernels_per_block_log2 == 1) { - return 8; - } else if (kernels_per_block_log2 == 2) { - return 16; - } else if (kernels_per_block_log2 == 3) { - return 24; - } else if (kernels_per_block_log2 == 4) { - return 32; - } else if (kernels_per_block_log2 == 5) { - return 40; - } else if (kernels_per_block_log2 == 6) { - return 48; - } else if (kernels_per_block_log2 == 7) { - return 56; - } else { - return 64; - } -} - -} // namespace - -void MakeBlockMap(int rows, int cols, int depth, int kernel_rows, - int kernel_cols, int lhs_scalar_size, int rhs_scalar_size, - int tentative_thread_count, Path path, - int local_data_cache_size, int shared_data_cache_size, - BlockMap* block_map) { - profiler::ScopeLabel label("MakeBlockMap"); - -#ifdef RUY_MAKEBLOCKMAP_DEBUG -#if RUY_MAKEBLOCKMAP_DEBUG >= 2 - static constexpr bool debug_everytime = true; -#else - static constexpr bool debug_everytime = false; -#endif - static bool firsttime = true; - if (firsttime || debug_everytime) { - fprintf(stderr, - "MakeBlockMap(rows=%d, cols=%d, depth=%d, kernel_rows=%d, " - "kernel_cols=%d, lhs_scalar_size=%d, rhs_scalar_size=%d, " - "tentative_thread_count=%d)\n", - rows, cols, depth, kernel_rows, kernel_cols, lhs_scalar_size, - rhs_scalar_size, tentative_thread_count); - } -#endif - - RUY_DCHECK_GE(rows, kernel_rows); - RUY_DCHECK_GE(cols, kernel_cols); - RUY_DCHECK_EQ(rows % kernel_rows, 0); - RUY_DCHECK_EQ(cols % kernel_cols, 0); - - block_map->traversal_order = - GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size, - local_data_cache_size, shared_data_cache_size); - - int rows_rectangularness_log2 = 0; - int cols_rectangularness_log2 = 0; - GetRectangularness(rows, cols, kernel_rows, kernel_cols, - &rows_rectangularness_log2, &cols_rectangularness_log2); - - const int kernel_rows_log2 = pot_log2(kernel_rows); - const int kernel_cols_log2 = pot_log2(kernel_cols); - const int kernel_size_log2 = std::max(kernel_cols_log2, kernel_rows_log2); - - const int size = std::min(rows, cols); - const int size_log2 = std::max(kernel_size_log2, floor_log2(size)); - - RUY_DCHECK_GE(size_log2, kernel_size_log2); - - // We are going to try candidate values for block_size_log2 ranging from - // kernel_size_log2 to (kernel_size_log2 + kMaxKernelsPerBlockLog2). 
- // For each of them we will compute a 'score' by adding individual scores - // for a few different considerations, all of which is entirely empirical. - // The values (and possibly the logic) around here are all subject to tuning - // based on benchmarks on different hardware. The current values are based - // on benchmarking on Qualcomm S855 (big and little cores), arm64, - // kNeonDotprod, 8bit quantized path. Don't read too much into it, go ahead - // and tune this as needed to achieve good performance elsewhere. Use - // the unit test, block_map_test, to encode values that should be preserved - // on specific architectures. Use RUY_MAKEBLOCKMAP_DEBUG to help tuning this. - static constexpr int kMaxKernelsPerBlockLog2 = 6; - const int max_block_size_log2 = - std::min(size_log2, kernel_size_log2 + kMaxKernelsPerBlockLog2); - int best_score = std::numeric_limits::min(); - int best_score_block_size_log2 = -1; - for (int block_size_log2 = kernel_size_log2; - block_size_log2 <= max_block_size_log2; block_size_log2++) { - const int multithreading_score = GetMultithreadingScore( - block_size_log2, rows, cols, tentative_thread_count); - const int cache_locality_score = GetCacheLocalityScore( - block_size_log2, rows, cols, depth, kernel_rows_log2, kernel_cols_log2, - lhs_scalar_size, rhs_scalar_size, path, local_data_cache_size); - const int kernel_amortization_score = GetKernelAmortizationScore( - block_size_log2, rows, cols, kernel_rows_log2, kernel_cols_log2); - const int score = - multithreading_score + cache_locality_score + kernel_amortization_score; -#ifdef RUY_MAKEBLOCKMAP_DEBUG - if (firsttime || debug_everytime) { - fprintf(stderr, - "block_size_log2=%d: score=%d multithreading_score=%d " - "cache_locality_score=%d kernel_amortization_score=%d\n", - block_size_log2, score, multithreading_score, - cache_locality_score, kernel_amortization_score); - } -#endif - if (score >= best_score) { - best_score = score; - best_score_block_size_log2 = block_size_log2; - } - } - -#ifdef RUY_MAKEBLOCKMAP_DEBUG - if (firsttime || debug_everytime) { - fprintf(stderr, "best_score_block_size_log2=%d\n", - best_score_block_size_log2); - } - - static const char* explicit_block_size_log2_env = - getenv("RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2"); - if (explicit_block_size_log2_env) { - best_score_block_size_log2 = std::stoi(explicit_block_size_log2_env); - if (firsttime || debug_everytime) { - fprintf(stderr, "Overridden best_score_block_size_log2=%d\n", - best_score_block_size_log2); - } - } - firsttime = false; -#endif - - int num_blocks_base_log2 = size_log2 - best_score_block_size_log2; - RUY_DCHECK_GE(num_blocks_base_log2, 0); - - const int num_blocks_of_rows_log2 = - num_blocks_base_log2 + rows_rectangularness_log2; - const int num_blocks_of_cols_log2 = - num_blocks_base_log2 + cols_rectangularness_log2; - - const int smallr = - round_down_pot(rows >> num_blocks_of_rows_log2, kernel_rows); - const int smallc = - round_down_pot(cols >> num_blocks_of_cols_log2, kernel_cols); - const int missr = - round_up_pot(rows - (smallr << num_blocks_of_rows_log2), kernel_rows) >> - pot_log2(kernel_rows); - const int missc = - round_up_pot(cols - (smallc << num_blocks_of_cols_log2), kernel_cols) >> - pot_log2(kernel_cols); - - block_map->dims[Side::kLhs] = rows; - block_map->dims[Side::kRhs] = cols; - block_map->kernel_dims[Side::kLhs] = kernel_rows; - block_map->kernel_dims[Side::kRhs] = kernel_cols; - block_map->num_blocks_base_log2 = num_blocks_base_log2; - block_map->rectangularness_log2[Side::kLhs] = 
rows_rectangularness_log2; - block_map->rectangularness_log2[Side::kRhs] = cols_rectangularness_log2; - block_map->small_block_dims[Side::kLhs] = smallr; - block_map->small_block_dims[Side::kRhs] = smallc; - block_map->large_blocks[Side::kLhs] = missr; - block_map->large_blocks[Side::kRhs] = missc; - // Done last: NumBlocks needs some of the block_map fields to be already set. - block_map->thread_count = - std::min(tentative_thread_count, NumBlocks(*block_map)); -} - -void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block, - int* start, int* end) { - profiler::ScopeLabel label("GetBlockMatrixCoords"); - *start = block * block_map.small_block_dims[side] + - std::min(block, block_map.large_blocks[side]) * - block_map.kernel_dims[side]; - *end = - *start + block_map.small_block_dims[side] + - (block < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0); - - RUY_DCHECK_EQ(0, *start % block_map.kernel_dims[side]); - RUY_DCHECK_EQ(0, *end % block_map.kernel_dims[side]); - RUY_DCHECK_LE(*end, block_map.dims[side]); - RUY_DCHECK_LT(*start, *end); - RUY_DCHECK_GE(*start, 0); -} - -void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair& block, - SidePair* start, SidePair* end) { - for (Side side : {Side::kLhs, Side::kRhs}) { - GetBlockMatrixCoords(side, block_map, block[side], &(*start)[side], - &(*end)[side]); - } -} - -} // namespace ruy diff --git a/block_map.h b/block_map.h deleted file mode 100644 index 18e9847..0000000 --- a/block_map.h +++ /dev/null @@ -1,161 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCK_MAP_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCK_MAP_H_ - -#include "path.h" -#include "side_pair.h" - -namespace ruy { - -enum class BlockMapTraversalOrder { - // Plain old row-by-row or column-by-column traversal. - kLinear, - // Fractal Z-order curve, https://en.wikipedia.org/wiki/Z-order_curve - kFractalZ, - // Variant of Z-order doing a U instead of a Z. - kFractalU, - // Hilbert curve, https://en.wikipedia.org/wiki/Hilbert_curve - kFractalHilbert -}; - -// A BlockMap describes a tiling of a matrix, typically the destination matrix -// of a matrix multiplication computation. As is standard in matrix -// multiplication, a tile is called a "block". -// -// Ruy subdivides work by blocks of the destination matrix: each thread fully -// computes a block at once, then moves on to another block; each block is -// produced by a single thread. -// -// This ensures that the workloads for each block are mutually independent, -// which reduces synchronization requirements. -// -// Typically, a matrix multiplication will early on create a BlockMap by -// calling MakeBlockMap. It will then query the number of blocks in that -// BlockMap by calling NumBlocks. 
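To make the small_block_dims / large_blocks bookkeeping above concrete, here is a hypothetical example (the numbers are mine, not from the patch): rows = 88, kernel_rows = 8, and suppose the row dimension ends up split into 2^2 = 4 blocks.

    // smallr = round_down_pot(88 >> 2, 8) = round_down_pot(22, 8) = 16
    // missr  = round_up_pot(88 - (16 << 2), 8) >> pot_log2(8) = 24 >> 3 = 3
    // So GetBlockMatrixCoords yields 3 "large" blocks of 16 + 8 = 24 rows
    // followed by 1 "small" block of 16 rows:
    //   block 0: rows [ 0, 24)
    //   block 1: rows [24, 48)
    //   block 2: rows [48, 72)
    //   block 3: rows [72, 88)   // 3 * 24 + 16 = 88; every boundary is a multiple of 8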
It will then create a single atomic integer -// counter indexing these blocks, called the 'index', and will distribute -// work to its N threads by ensuring that each thread works on disjoint sets -// of index values. For a given index value, the thread will call -// GetBlockByIndex to get the corresponding block, then GetBlockMatrixCoords -// to find the actual row and column numbers of this block. -// -// There are two nested levels of subdivision. On a local level, the matrix is -// tiled into a square NxN grid where N is a power of two, specifically: -// N = 2^num_blocks_base_log2. -// -// At a larger scale, around these blocks, there may be one further -// level of subdivision, in only one dimension: either along rows or along -// columns. That is used to handle arbitrarily rectangular matrices. The -// aforementioned high-level block grid is square, so it does not readily fit -// well very rectangular matrices. -// -// Taking together these two nested levels of subdivision, the effective -// tiling is by -// 2^(num_blocks_base_log2 + rows_rectangularness_log2) -// blocks in the row dimension, and by -// 2^(num_blocks_base_log2 + cols_rectangularness_log2) -// blocks in the column dimension. See NumBlocksOfRows, NumBlocksOfCols. -// -// Either rows_rectangularness_log2 or cols_rectangularness_log2 must be zero. -// -// Finally, this BlockMap is designed to operate under alignment constraints: -// two fields, kernel_rows and kernel_cols, describe the requested alignment -// of the effective grid in both dimensions. The idea is to feed matrix -// multiplication kernels with tiles that fit their width as much as possible. -// Of course, if rows (resp. cols) is not a multiple of kernel_rows (resp. -// kernel_cols) then some tile will have to have unaligned size. BlockMap -// will only allow that to happen in the last position along each axis, so -// as to minimize the overhead incurred onto the matrix multiplication kernels. -struct BlockMap { - // The number of threads to use (to distribute the blocks to). - int thread_count; - // The order in which to traverse the matrix of which this BlockMap represents - // a tiling (hereafter "the matrix"). - BlockMapTraversalOrder traversal_order; - // The dimensions of the block_map, that is, of the destination - // matrix rounded up to next multiples of kernel_dims. - SidePair dims; - // Log2 of the minimum number of subdivisions of the grid along either axis. - int num_blocks_base_log2; - // Log2 of the additional subdivision of the rows/columns axis. - SidePair rectangularness_log2; - // Requested alignment of the subdivisions of the grid along the rows/columns - // axis. - SidePair kernel_dims; - // Internal helper. Minimum number of rows/columns in each block. - SidePair small_block_dims; - // Internal helper. Number of blocks along each dimension that need to have - // their size in that dimension be given by (small_block_dims + kernel_dims) - // instead of just small_block_dims. - SidePair large_blocks; -}; - -// Returns the traversal order to be used for the given matrix multiplication -// parameters. -BlockMapTraversalOrder GetTraversalOrder(int rows, int cols, int depth, - int lhs_scalar_size, - int rhs_scalar_size, - int local_data_cache_size, - int shared_data_cache_size); - -// Create a BlockMap suitable for tiling the destination matrix in a -// matrix multiplication with the given parameters. 
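As a concrete instance of the rectangularness mechanism described earlier in this header comment, take the 256x16 destination with an 8x8 kernel from MakeBlockMapTuningTestRectangular later in this patch, which expects num_blocks_base_log2 == 0 and rectangularness_log2 == 3 on the rows side:

    // Blocks along rows: 2^(0 + 3) = 8; blocks along cols: 2^(0 + 0) = 1.
    // NumBlocks == 2^(2*0 + 3 + 0) == 8 == 8 * 1, consistent with the NumBlocks note below.
    // Each block covers 256 / 8 = 32 rows by all 16 columns.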
-void MakeBlockMap(int rows, int cols, int depth, int kernel_rows, - int kernel_cols, int lhs_scalar_size, int rhs_scalar_size, - int tentative_thread_count, Path path, - int local_data_cache_size, int shared_data_cache_size, - BlockMap* block_map); - -// Maps an integer index to a block position in the grid. -void GetBlockByIndex(const BlockMap& block_map, int index, - SidePair* block); - -// Given a block position in the grid, returns its actual -// position in the matrix that the BlockMap refers to in the dimension -// referred to by `side`: along rows if side==kLhs, along columns if -// side==kRhs. -void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block, - int* start, int* end); - -// Given a block position in the grid, returns its actual -// position in the matrix that the BlockMap refers to in terms of -// actual row/column indices. -void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair& block, - SidePair* start, SidePair* end); - -// Returns the number of grid subdivisions along the rows dimension (if -// side == kLhs) or columns dimension (if side == kRhs). -inline int NumBlocksPerSide(Side side, const BlockMap& block_map) { - return 1 << (block_map.num_blocks_base_log2 + - block_map.rectangularness_log2[side]); -} - -// Returns the overall number of blocks in -// the BlockMap. The valid index values to pass to GetBlockByIndex are the -// integers from 0 to N-1 where N is the value returned here. -// -// Note that it is always true that -// NumBlocks == NumBlocksOfRows * NumBlocksOfCols -// because either rows_rectangularness_log2 or cols_rectangularness_log2 is 0. -inline int NumBlocks(const BlockMap& block_map) { - return 1 << (2 * block_map.num_blocks_base_log2 + - block_map.rectangularness_log2[Side::kLhs] + - block_map.rectangularness_log2[Side::kRhs]); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCK_MAP_H_ diff --git a/block_map_test.cc b/block_map_test.cc deleted file mode 100644 index 3ce6d0f..0000000 --- a/block_map_test.cc +++ /dev/null @@ -1,263 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "block_map.h" - -#include -#include -#include -#include -#include - -#include "testing/base/public/gunit.h" -#include "cpu_cache_size.h" -#include "path.h" -#include "side_pair.h" - -namespace ruy { -namespace { - -#if RUY_PLATFORM(NEON_64) - -// Unless otherwise specified, these tests have been tuned on ARM Cortex-A55. 
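Putting the pieces of block_map.h together, the usage pattern its header comment describes looks roughly like the following sketch (single-threaded for simplicity; real callers distribute the index range across threads through an atomic counter, and the template arguments are my reconstruction since the diff as rendered here elides them):

    BlockMap block_map;
    MakeBlockMap(rows, cols, depth, kernel_rows, kernel_cols, lhs_scalar_size,
                 rhs_scalar_size, tentative_thread_count, path,
                 LocalDataCacheSize(path), SharedDataCacheSize(path), &block_map);
    for (int index = 0; index < NumBlocks(block_map); ++index) {
      SidePair<int> block, start, end;
      GetBlockByIndex(block_map, index, &block);
      GetBlockMatrixCoords(block_map, block, &start, &end);
      // This block covers destination rows [start[Side::kLhs], end[Side::kLhs])
      // and columns [start[Side::kRhs], end[Side::kRhs]).
    }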
-void MakeBlockMapTuningTest(int rows, int cols, int depth, int kernel_rows, - int kernel_cols, int lhs_scalar_size, - int rhs_scalar_size, int tentative_thread_count, - Path path, int expected_num_blocks_base_log2, - int expected_rectangularness_log2) { - BlockMap block_map; - MakeBlockMap(rows, cols, depth, kernel_rows, kernel_cols, lhs_scalar_size, - rhs_scalar_size, tentative_thread_count, path, - LocalDataCacheSize(path), SharedDataCacheSize(path), &block_map); - EXPECT_EQ(block_map.num_blocks_base_log2, expected_num_blocks_base_log2); - EXPECT_EQ(std::min(block_map.rectangularness_log2[Side::kLhs], - block_map.rectangularness_log2[Side::kRhs]), - 0); - EXPECT_EQ(std::max(block_map.rectangularness_log2[Side::kLhs], - block_map.rectangularness_log2[Side::kRhs]), - expected_rectangularness_log2); -} - -TEST(BlockMapTest, MakeBlockMapTuningTest8bitCubicShapesOneThreadNeonDotprod) { - MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 1, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 1, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 1, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 1, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); -} - -TEST(BlockMapTest, - MakeBlockMapTuningTest8bitCubicShapesFourThreadsNeonDotprod) { - MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 4, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 4, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 4, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 4, - Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* 
expected_num_blocks_base_log2 */ 1, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 2, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 2, - /* expected_rectangularness_log2 */ 0); -} - -TEST(BlockMapTest, MakeBlockMapTuningTest32bit) { - MakeBlockMapTuningTest(256, 256, 256, 8, 8, 4, 4, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 3, - /* expected_rectangularness_log2 */ 0); - MakeBlockMapTuningTest(4096, 4096, 4096, 8, 8, 4, 4, - /* tentative_thread_count */ 4, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 7, - /* expected_rectangularness_log2 */ 0); -} - -TEST(BlockMapTest, MakeBlockMapTuningTestRectangular) { - MakeBlockMapTuningTest(256, 16, 256, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 3); - MakeBlockMapTuningTest(24, 2400, 256, 8, 8, 1, 1, - /* tentative_thread_count */ 1, Path::kNeonDotprod, - /* expected_num_blocks_base_log2 */ 0, - /* expected_rectangularness_log2 */ 6); -} - -#endif - -int L1Distance(const SidePair& a, const SidePair& b) { - return std::abs(a[Side::kLhs] - b[Side::kLhs]) + - std::abs(a[Side::kRhs] - b[Side::kRhs]); -} - -void GetBlockByIndexSquareTest(int num_blocks_base_log2, - BlockMapTraversalOrder traversal_order) { - // Arbitrary, does not affect this test. 3 is just a typical value. - constexpr int kKernelSizeLog2 = 3; - - const int size_log2 = num_blocks_base_log2 + kKernelSizeLog2; - BlockMap block_map; - block_map.thread_count = 1; - block_map.traversal_order = traversal_order; - block_map.num_blocks_base_log2 = num_blocks_base_log2; - for (Side side : {Side::kLhs, Side::kRhs}) { - block_map.dims[side] = 1 << size_log2; - block_map.rectangularness_log2[side] = 0; - block_map.kernel_dims[side] = 1 << kKernelSizeLog2; - block_map.small_block_dims[side] = block_map.kernel_dims[side]; - block_map.large_blocks[side] = 0; - } - - const int num_blocks_per_side = 1 << num_blocks_base_log2; - const int num_blocks = num_blocks_per_side * num_blocks_per_side; - EXPECT_EQ(num_blocks, NumBlocks(block_map)); - - // Perform a full traversal of all blocks, as if computing a whole matrix - // multiplication. - // - // Used to record how many times each block was hit by the traversal. - std::vector block_hit_counts(num_blocks); - // Here we guard an assumption that all traversal orders start at (0, 0). - SidePair previous_block_coords(0, 0); - // Sum of L1 norm of the coordinate change at every step of the traversal. - std::int64_t total_l1_distance = 0; - // Number of jumps i.e. traversal steps with a L1 norm greater than 1. - int discontinuity_count = 0; - for (int block_index = 0; block_index < num_blocks; block_index++) { - SidePair block_coords; - GetBlockByIndex(block_map, block_index, &block_coords); - ++block_hit_counts[block_coords[Side::kLhs] + - num_blocks_per_side * block_coords[Side::kRhs]]; - int distance = L1Distance(block_coords, previous_block_coords); - total_l1_distance += distance; - discontinuity_count += (distance > 1); - previous_block_coords = block_coords; - } - - // Verify that each block was traversed exactly once. 
- for (int l = 0; l < num_blocks_per_side; l++) { - for (int r = 0; r < num_blocks_per_side; r++) { - EXPECT_EQ(block_hit_counts[l + num_blocks_per_side * r], 1); - } - } - - // Verify that the discontinuity_count and total_l1_distance are as expected - // for the given traversal_order. - switch (traversal_order) { - case BlockMapTraversalOrder::kFractalHilbert: - // No discontinuity at all with this space-filling continuous curve! - EXPECT_EQ(discontinuity_count, 0); - // Therefore, total_l1_distance has to be the number of blocks minus one. - EXPECT_EQ(total_l1_distance, num_blocks - 1); - break; - case BlockMapTraversalOrder::kLinear: - EXPECT_EQ(discontinuity_count, num_blocks_per_side - 1); - EXPECT_EQ(total_l1_distance, - 2 * num_blocks_per_side * (num_blocks_per_side - 1)); - break; - case BlockMapTraversalOrder::kFractalZ: - EXPECT_EQ(discontinuity_count, num_blocks > 1 ? (num_blocks / 2 - 1) : 0); - EXPECT_EQ(total_l1_distance, - 2 * num_blocks_per_side * (num_blocks_per_side - 1)); - break; - case BlockMapTraversalOrder::kFractalU: { - if (num_blocks_base_log2 == 0) { - EXPECT_EQ(discontinuity_count, 0); - EXPECT_EQ(total_l1_distance, 0); - } else { - int expected_discontinuity_count = 0; - int expected_total_l1_distance = 3; - for (int i = 2; i <= num_blocks_base_log2; i++) { - expected_discontinuity_count = 4 * expected_discontinuity_count + 2; - expected_total_l1_distance = - 4 * expected_total_l1_distance + (1 << (i + 1)) - 1; - } - EXPECT_EQ(discontinuity_count, expected_discontinuity_count); - EXPECT_EQ(total_l1_distance, expected_total_l1_distance); - } - break; - } - default: - abort(); - } -} - -TEST(BlockMapTest, GetBlockByIndexSquare) { - for (int num_blocks_base_log2 = 0; num_blocks_base_log2 <= 10; - num_blocks_base_log2++) { - for (BlockMapTraversalOrder traversal_order : - {BlockMapTraversalOrder::kLinear, BlockMapTraversalOrder::kFractalZ, - BlockMapTraversalOrder::kFractalU, - BlockMapTraversalOrder::kFractalHilbert}) { - GetBlockByIndexSquareTest(num_blocks_base_log2, traversal_order); - } - } -} - -} // namespace -} // namespace ruy - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/blocking_counter.cc b/blocking_counter.cc deleted file mode 100644 index 2bfb896..0000000 --- a/blocking_counter.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
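For intuition about the kFractalZ expectations above: a fractal Z-order (Morton) traversal de-interleaves the bits of the linear index into the two block coordinates. The following is a generic illustration of that idea, not ruy's actual GetBlockByIndex implementation (which is not part of this hunk):

    // 0 -> (0,0), 1 -> (1,0), 2 -> (0,1), 3 -> (1,1), 4 -> (2,0), ...
    void DecodeZOrder(int index, int* x, int* y) {
      *x = 0;
      *y = 0;
      for (int bit = 0; (index >> (2 * bit)) != 0; ++bit) {
        *x |= ((index >> (2 * bit)) & 1) << bit;
        *y |= ((index >> (2 * bit + 1)) & 1) << bit;
      }
    }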
-==============================================================================*/ - -#include "blocking_counter.h" - -#include "check_macros.h" -#include "wait.h" - -namespace ruy { - -void BlockingCounter::Reset(int initial_count) { - int old_count_value = count_.load(std::memory_order_relaxed); - RUY_DCHECK_EQ(old_count_value, 0); - (void)old_count_value; - count_.store(initial_count, std::memory_order_release); -} - -bool BlockingCounter::DecrementCount() { - int old_count_value = count_.fetch_sub(1, std::memory_order_acq_rel); - RUY_DCHECK_GT(old_count_value, 0); - int count_value = old_count_value - 1; - bool hit_zero = (count_value == 0); - if (hit_zero) { - std::lock_guard lock(count_mutex_); - count_cond_.notify_all(); - } - return hit_zero; -} - -void BlockingCounter::Wait() { - const auto& condition = [this]() { - return count_.load(std::memory_order_acquire) == 0; - }; - ruy::Wait(condition, &count_cond_, &count_mutex_); -} - -} // namespace ruy diff --git a/blocking_counter.h b/blocking_counter.h deleted file mode 100644 index e8c76d5..0000000 --- a/blocking_counter.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCKING_COUNTER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCKING_COUNTER_H_ - -#include -#include // NOLINT(build/c++11) // IWYU pragma: keep -#include // NOLINT(build/c++11) // IWYU pragma: keep - -namespace ruy { - -// A BlockingCounter lets one thread to wait for N events to occur. -// This is how the master thread waits for all the worker threads -// to have finished working. -// The waiting is done using a naive spinlock waiting for the atomic -// count_ to hit the value 0. This is acceptable because in our usage -// pattern, BlockingCounter is used only to synchronize threads after -// short-lived tasks (performing parts of the same GEMM). It is not used -// for synchronizing longer waits (resuming work on the next GEMM). -class BlockingCounter { - public: - BlockingCounter() : count_(0) {} - - // Sets/resets the counter; initial_count is the number of - // decrementing events that the Wait() call will be waiting for. - void Reset(int initial_count); - - // Decrements the counter; if the counter hits zero, signals - // the threads that were waiting for that, and returns true. - // Otherwise (if the decremented count is still nonzero), - // returns false. - bool DecrementCount(); - - // Waits for the N other threads (N having been set by Reset()) - // to hit the BlockingCounter. - void Wait(); - - private: - std::atomic count_; - - // The condition variable and mutex allowing to passively wait for count_ - // to reach the value zero, in the case of longer waits. 
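The master/worker pattern that the BlockingCounter class comment describes amounts to the following sketch (StartWorker is a hypothetical stand-in for however worker threads get launched; it is not a ruy API):

    ruy::BlockingCounter counter;
    counter.Reset(num_workers);    // expect num_workers decrements
    for (int i = 0; i < num_workers; ++i) {
      StartWorker([&counter] {
        // ... perform this worker's share of the GEMM ...
        counter.DecrementCount();  // returns true only for the last worker to finish
      });
    }
    counter.Wait();                // the master blocks until the count reaches zero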
- std::condition_variable count_cond_; - std::mutex count_mutex_; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_BLOCKING_COUNTER_H_ diff --git a/build_defs.bzl b/build_defs.bzl deleted file mode 100644 index 9bccccf..0000000 --- a/build_defs.bzl +++ /dev/null @@ -1,40 +0,0 @@ -"""Build definitions for Ruy.""" - -# 1. Enable -mfpu=neon unconditionally on ARM32. If it turns out that we need to support -# ARM32 without NEON then we'll implement runtime detection and dispatch at that point. -# 2. Explicitly pass -O3 on optimization configs where just "-c opt" means "optimize for code size". - -def ruy_copts_base(): - return select({ - ":armeabi-v7a": [ - "-mfpu=neon", - ], - "//conditions:default": [], - }) + select({ - ":optimized": ["-O3"], - "//conditions:default": [], - }) - -# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. -def ruy_copts_skylake(): - return [] - -# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. -def ruy_copts_avx2(): - return [] - -# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -# Optimization is not finished. In particular the dimensions of the kernel -# blocks can be changed as desired. -# -# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. -def ruy_copts_sse42(): - return [] - -# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -# Optimization is not finished. In particular the dimensions of the kernel -# blocks can be changed as desired. -# -# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. -def ruy_copts_avxvnni(): - return [] diff --git a/check_macros.h b/check_macros.h deleted file mode 100644 index 564440b..0000000 --- a/check_macros.h +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_CHECK_MACROS_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_CHECK_MACROS_H_ - -#include -#include -#include - -namespace ruy { -namespace check_macros { - -constexpr int kValueBufSize = 32; - -template -struct ToString { - static void Run(const T& value, char* buf) { - snprintf(buf, kValueBufSize, "(?)"); - } -}; - -template <> -struct ToString { - static void Run(float value, char* buf) { - snprintf(buf, kValueBufSize, "%.9g", static_cast(value)); - } -}; - -template <> -struct ToString { - static void Run(double value, char* buf) { - snprintf(buf, kValueBufSize, "%.16g", value); - } -}; - -template -struct ToString::value>::type> { - static void Run(const T& value, char* buf) { - snprintf(buf, kValueBufSize, "%lld", static_cast(value)); - } -}; - -template -struct ToString { - static void Run(T* value, char* buf) { - snprintf(buf, kValueBufSize, "%p", value); - } -}; - -template -struct ToString::value>::type> { - static void Run(const T& value, char* buf) { - snprintf(buf, kValueBufSize, "(enum value %d)", static_cast(value)); - } -}; - -inline void Failure(const char* file, int line, const char* macro, - const char* condition) { - fprintf(stderr, "%s:%d: %s condition not satisfied: %s\n", file, line, macro, - condition); - abort(); -} - -template -inline void Failure(const char* file, int line, const char* macro, - const char* lhs, const LhsType& lhs_value, const char* op, - const char* rhs, const RhsType& rhs_value) { - char lhs_value_buf[kValueBufSize]; - ToString::Run(lhs_value, lhs_value_buf); - char rhs_value_buf[kValueBufSize]; - ToString::Run(rhs_value, rhs_value_buf); - fprintf(stderr, - "%s:%d: %s condition not satisfied: [ %s %s %s ] with values [ " - "%s %s %s ].\n", - file, line, macro, lhs, op, rhs, lhs_value_buf, op, rhs_value_buf); - abort(); -} - -#define RUY_CHECK_IMPL(macro, condition) \ - do { \ - if (!(condition)) { \ - ruy::check_macros::Failure(__FILE__, __LINE__, #macro, #condition); \ - } \ - } while (false) - -#define RUY_CHECK_OP_IMPL(macro, lhs, op, rhs) \ - do { \ - const auto& lhs_value = (lhs); \ - const auto& rhs_value = (rhs); \ - if (!(lhs_value op rhs_value)) { \ - ruy::check_macros::Failure(__FILE__, __LINE__, #macro, #lhs, lhs_value, \ - #op, #rhs, rhs_value); \ - } \ - } while (false) - -#define RUY_CHECK(condition) RUY_CHECK_IMPL(RUY_CHECK, condition) -#define RUY_CHECK_EQ(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_EQ, x, ==, y) -#define RUY_CHECK_NE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_NE, x, !=, y) -#define RUY_CHECK_GE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_GE, x, >=, y) -#define RUY_CHECK_GT(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_GT, x, >, y) -#define RUY_CHECK_LE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_LE, x, <=, y) -#define RUY_CHECK_LT(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_LT, x, <, y) - -#ifdef NDEBUG -#define RUY_DCHECK(condition) -#define RUY_DCHECK_EQ(x, y) -#define RUY_DCHECK_NE(x, y) -#define RUY_DCHECK_GE(x, y) -#define RUY_DCHECK_GT(x, y) -#define RUY_DCHECK_LE(x, y) -#define RUY_DCHECK_LT(x, y) -#else -#define RUY_DCHECK(condition) RUY_CHECK(condition) -#define RUY_DCHECK_EQ(x, y) RUY_CHECK_EQ(x, y) -#define RUY_DCHECK_NE(x, y) RUY_CHECK_NE(x, y) -#define RUY_DCHECK_GE(x, y) RUY_CHECK_GE(x, y) -#define RUY_DCHECK_GT(x, y) RUY_CHECK_GT(x, y) -#define RUY_DCHECK_LE(x, y) RUY_CHECK_LE(x, y) -#define RUY_DCHECK_LT(x, y) RUY_CHECK_LT(x, y) -#endif - -} // end namespace check_macros -} // end namespace ruy - -#endif // 
TENSORFLOW_LITE_EXPERIMENTAL_RUY_CHECK_MACROS_H_ diff --git a/check_macros_test.cc b/check_macros_test.cc deleted file mode 100644 index 459513e..0000000 --- a/check_macros_test.cc +++ /dev/null @@ -1,153 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "check_macros.h" - -#include "testing/base/public/gunit.h" - -namespace { - -#define TEST_CONDITION_FOR_FAMILY(family, vacuously_succeeds, condition) \ - do { \ - if (vacuously_succeeds || (condition)) { \ - RUY_##family(condition); \ - } \ - } while (false) - -#define TEST_COMPARISON_FOR_FAMILY(family, vacuously_succeeds, op_name, x, op, \ - y) \ - do { \ - if (vacuously_succeeds || ((x)op(y))) { \ - RUY_##family##_##op_name(x, y); \ - } \ - } while (false) - -#ifdef NDEBUG -#define TEST_CONDITION(condition) \ - do { \ - TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \ - } while (false) -#define TEST_COMPARISON(op_name, x, op, y) \ - do { \ - TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \ - } while (false) -#else -#define TEST_CONDITION(condition) \ - do { \ - TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \ - TEST_CONDITION_FOR_FAMILY(DCHECK, false, condition); \ - } while (false) -#define TEST_COMPARISON(op_name, x, op, y) \ - do { \ - TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \ - TEST_COMPARISON_FOR_FAMILY(DCHECK, false, op_name, x, op, y); \ - } while (false) - -#endif - -template -void TestEqualityComparisons(const LhsType& lhs, const RhsType& rhs) { - RUY_CHECK_EQ(lhs, lhs); - TEST_COMPARISON(EQ, lhs, ==, lhs); - RUY_CHECK_EQ(lhs, lhs); - RUY_CHECK_EQ(lhs, lhs); - if (lhs == rhs) { - RUY_CHECK_EQ(lhs, rhs); - } - if (lhs != rhs) { - RUY_CHECK_NE(lhs, rhs); - } -} - -template -void TestComparisons(const LhsType& lhs, const RhsType& rhs) { - TestEqualityComparisons(lhs, rhs); - if (lhs > rhs) { - RUY_CHECK_GT(lhs, rhs); - } - if (lhs >= rhs) { - RUY_CHECK_GE(lhs, rhs); - } - if (lhs < rhs) { - RUY_CHECK_LT(lhs, rhs); - } - if (lhs <= rhs) { - RUY_CHECK_LE(lhs, rhs); - } -} - -TEST(CheckMacrosTest, IntInt) { - TestComparisons(0, 0); - TestComparisons(0, 1); - TestComparisons(1, -1); - TestComparisons(-1, 0); - TestComparisons(123, -456); - TestComparisons(std::numeric_limits::min(), - std::numeric_limits::max()); - TestComparisons(123, std::numeric_limits::max()); - TestComparisons(123, std::numeric_limits::min()); -} - -TEST(CheckMacrosTest, Uint8Uint8) { - TestComparisons(0, 0); - TestComparisons(255, 0); - TestComparisons(0, 255); - TestComparisons(12, 34); -} - -TEST(CheckMacrosTest, Uint8Int) { - TestComparisons(0, std::numeric_limits::min()); - TestComparisons(255, std::numeric_limits::min()); - TestComparisons(0, std::numeric_limits::max()); - TestComparisons(255, std::numeric_limits::max()); -} - -TEST(CheckMacrosTest, FloatFloat) { - TestComparisons(0.f, 0.f); - TestComparisons(0.f, 1.f); - TestComparisons(1.f, -1.f); - TestComparisons(-1.f, 0.f); - 
TestComparisons(123.f, -456.f); - TestComparisons(std::numeric_limits::lowest(), - std::numeric_limits::max()); - TestComparisons(123.f, std::numeric_limits::max()); - TestComparisons(123.f, std::numeric_limits::lowest()); -} - -TEST(CheckMacrosTest, IntFloat) { - TestComparisons(0, 0.f); - TestComparisons(0, 1.f); - TestComparisons(1, -1.f); - TestComparisons(-1, 0.f); - TestComparisons(123, -456.f); - TestComparisons(std::numeric_limits::lowest(), - std::numeric_limits::max()); - TestComparisons(123, std::numeric_limits::max()); - TestComparisons(123, std::numeric_limits::lowest()); -} - -TEST(CheckMacrosTest, EnumClass) { - enum class SomeEnumClass { kA, kB, kC }; - TestEqualityComparisons(SomeEnumClass::kA, SomeEnumClass::kA); - TestEqualityComparisons(SomeEnumClass::kA, SomeEnumClass::kB); - TestEqualityComparisons(SomeEnumClass::kC, SomeEnumClass::kB); -} - -} // namespace - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/common.h b/common.h deleted file mode 100644 index 157399c..0000000 --- a/common.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Miscellaneous helpers internal library. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_COMMON_H_ - -#include -#include - -#include "check_macros.h" -#include "matrix.h" -#include "opt_set.h" -#include "path.h" -#include "platform.h" - -#if RUY_OPT_ENABLED(RUY_OPT_PREFETCH_LOAD) -#define RUY_PREFETCH_LOAD(X) X -#else -#define RUY_PREFETCH_LOAD(X) -#endif - -#if RUY_OPT_ENABLED(RUY_OPT_PREFETCH_STORE) -#define RUY_PREFETCH_STORE(X) X -#else -#define RUY_PREFETCH_STORE(X) -#endif - -#define RUY_STR(s) RUY_STR_UNEXPANDED(s) -#define RUY_STR_UNEXPANDED(s) #s - -namespace ruy { - -// Helper for type-erasing a pointer. -// -// Often inside Ruy, a template parameter holds type information statically, but -// we would like to have a function signature that doesn't depend on the -// template parameters, so that we can dispatch indirectly across multiple -// implementations. This helper is at the core of such type-erasure. -// -// The opposite of this operation is just `static_cast(void_ptr)`. -template -void* ToVoidPtr(T* p) { - return const_cast(static_cast(p)); -} - -template -Scalar SymmetricZeroPoint() { - if (std::is_floating_point::value) { - return 0; - } - if (std::is_signed::value) { - return 0; - } - return std::numeric_limits::max() / 2 + 1; -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_COMMON_H_ diff --git a/context.cc b/context.cc deleted file mode 100644 index 4852abf..0000000 --- a/context.cc +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
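For clarity (the template arguments are elided in the diff as rendered here), the values that SymmetricZeroPoint in common.h evaluates to for the usual instantiations are:

    // SymmetricZeroPoint<float>()        == 0    (floating-point)
    // SymmetricZeroPoint<std::int8_t>()  == 0    (signed integer)
    // SymmetricZeroPoint<std::uint8_t>() == 128  (255 / 2 + 1, the midpoint of the uint8 range)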
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "context.h" - -#include "check_macros.h" -#include "detect_arm.h" -#include "detect_x86.h" -#include "have_built_path_for.h" -#include "platform.h" - -namespace ruy { - -void Context::SetRuntimeEnabledPaths(Path paths) { - runtime_enabled_paths_ = paths; -} - -Path Context::GetRuntimeEnabledPaths() { - // This function should always return the same value on a given machine. - // When runtime_enabled_paths_ has its initial value kNone, it performs - // some platform detection to resolve it to specific Path values. - - // Fast path: already resolved. - if (runtime_enabled_paths_ != Path::kNone) { - return runtime_enabled_paths_; - } - - // Need to resolve now. Start by considering all paths enabled. - runtime_enabled_paths_ = kAllPaths; - - // This mechanism is intended to be used for testing and benchmarking. For - // example, one can set RUY_FORCE_DISABLE_PATHS to Path::kAvx512 in order to - // evaluate AVX2 performance on an AVX-512 machine. -#ifdef RUY_FORCE_DISABLE_PATHS - runtime_enabled_paths_ = runtime_enabled_paths_ & ~(RUY_FORCE_DISABLE_PATHS); -#endif - -#if RUY_PLATFORM(ARM) - // Now selectively disable paths that aren't supported on this machine. - if ((runtime_enabled_paths_ & Path::kNeonDotprod) != Path::kNone) { - if (!DetectDotprod()) { - runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kNeonDotprod; - // Sanity check. - RUY_DCHECK((runtime_enabled_paths_ & Path::kNeonDotprod) == Path::kNone); - } - } -#endif // RUY_PLATFORM(ARM) - -#if RUY_PLATFORM(X86) - // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / - // placeholder. Optimization is not finished. In particular the dimensions of - // the kernel blocks can be changed as desired. - // - if ((runtime_enabled_paths_ & Path::kSse42) != Path::kNone) { - if (!(HaveBuiltPathForSse42() && DetectCpuSse42())) { - runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kSse42; - // Sanity check. - RUY_DCHECK((runtime_enabled_paths_ & Path::kSse42) == Path::kNone); - } - } - - if ((runtime_enabled_paths_ & Path::kAvx2) != Path::kNone) { - if (!(HaveBuiltPathForAvx2() && DetectCpuAvx2())) { - runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvx2; - // Sanity check. - RUY_DCHECK((runtime_enabled_paths_ & Path::kAvx2) == Path::kNone); - } - } - - if ((runtime_enabled_paths_ & Path::kAvx512) != Path::kNone) { - if (!(HaveBuiltPathForAvx512() && DetectCpuAvx512())) { - runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvx512; - // Sanity check. - RUY_DCHECK((runtime_enabled_paths_ & Path::kAvx512) == Path::kNone); - } - } - - // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / - // placeholder. Optimization is not finished. In particular the dimensions of - // the kernel blocks can be changed as desired. - // - if ((runtime_enabled_paths_ & Path::kAvxVnni) != Path::kNone) { - if (!(HaveBuiltPathForAvxVnni() && DetectCpuAvxVnni())) { - runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvxVnni; - // Sanity check. 
- RUY_DCHECK((runtime_enabled_paths_ & Path::kAvxVnni) == Path::kNone); - } - } -#endif // RUY_PLATFORM(X86) - - // Sanity check. We can't possibly have disabled all paths, as some paths - // are universally available (kReference, kStandardCpp). - RUY_DCHECK_NE(runtime_enabled_paths_, Path::kNone); - return runtime_enabled_paths_; -} - -} // namespace ruy diff --git a/context.h b/context.h deleted file mode 100644 index 6772bed..0000000 --- a/context.h +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_CONTEXT_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_CONTEXT_H_ - -#include -#include -#include - -#include "allocator.h" -#include "path.h" -#include "prepacked_cache.h" -#include "thread_pool.h" -#include "trace.h" -#include "tune.h" - -namespace ruy { - -// The state private to each Ruy thread. -struct PerThreadState { - // Each thread may be running on a different microarchitecture. For example, - // some threads may be on big cores, while others are on little cores. Thus, - // it's best for the tuning to be per-thread. - TuningResolver tuning_resolver; - // Each thread has its own local allocator. - Allocator allocator; -}; - -// A Context holds runtime information used by Ruy. It holds runtime resources -// such as the workers thread pool and the allocator (which holds buffers for -// temporary data), as well as runtime options controlling which Paths are -// enabled (typically based on which instruction sets are detected) and how -// many threads to use. -struct Context final { - Path last_taken_path = Path::kNone; - Tuning explicit_tuning = Tuning::kAuto; - // TODO(benoitjacob) rename that thread_pool. Current name is gemmlowp legacy. - ThreadPool workers_pool; - int max_num_threads = 1; - // State for each thread in the thread pool. Entry 0 is the main thread. 
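As a usage note for the path-resolution logic above: callers can restrict which paths are considered at all, which is how the context test later in this patch pins down its expectations. A minimal sketch, using only the Context members shown in this patch:

    ruy::Context context;
    // Restrict ruy to the portable path, e.g. to compare against optimized kernels:
    context.SetRuntimeEnabledPaths(ruy::Path::kStandardCpp);
    // If SetRuntimeEnabledPaths is never called, the first call below performs
    // CPU detection once and caches the result for the lifetime of the Context.
    ruy::Path enabled = context.GetRuntimeEnabledPaths();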
- std::vector> per_thread_states; - TracingContext tracing; - CachePolicy cache_policy = CachePolicy::kNoCache; - - Allocator* GetMainAllocator() { - if (!main_allocator_) { - main_allocator_.reset(new Allocator); - } - return main_allocator_.get(); - } - - PrepackedCache* GetPrepackedCache() { - if (!prepacked_cache_) { - prepacked_cache_.reset(new PrepackedCache); - } - return prepacked_cache_.get(); - } - - void ClearPrepackedCache() { prepacked_cache_ = nullptr; } - - void EnsureNPerThreadStates(int thread_count) { - while (per_thread_states.size() < static_cast(thread_count)) { - per_thread_states.emplace_back(new PerThreadState); - } - } - - Tuning GetMainThreadTuning() { - EnsureNPerThreadStates(1); - TuningResolver* tuning_resolver = &per_thread_states[0]->tuning_resolver; - tuning_resolver->SetTuning(explicit_tuning); - return tuning_resolver->Resolve(); - } - - template - Path GetPathToTake() { - last_taken_path = - GetMostSignificantPath(CompiledPaths & GetRuntimeEnabledPaths()); - return last_taken_path; - } - - void SetRuntimeEnabledPaths(Path paths); - Path GetRuntimeEnabledPaths(); - - private: - // Allocator for main thread work before invoking the threadpool. - // Our simple Allocator does not allow reserving/allocating more blocks - // while it's already in committed state, so the main thread needs both - // this allocator, and its per-thread allocator. - std::unique_ptr main_allocator_; - std::unique_ptr prepacked_cache_; - Path runtime_enabled_paths_ = Path::kNone; -}; - -} // end namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_CONTEXT_H_ diff --git a/context_test.cc b/context_test.cc deleted file mode 100644 index 2a9c4cd..0000000 --- a/context_test.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "context.h" - -#include "testing/base/public/gunit.h" -#include "path.h" -#include "platform.h" - -namespace ruy { -namespace { - -TEST(ContextTest, EnabledPathsGeneral) { - ruy::Context ruy_context; - const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths(); - const auto ruy_paths_repeat = ruy_context.GetRuntimeEnabledPaths(); - ASSERT_EQ(ruy_paths, ruy_paths_repeat); - EXPECT_NE(ruy_paths, Path::kNone); - EXPECT_EQ(ruy_paths & Path::kReference, Path::kReference); - EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kStandardCpp); -} - -#if RUY_PLATFORM(X86) -TEST(ContextTest, EnabledPathsX86) { - ruy::Context ruy_context; - ruy_context.SetRuntimeEnabledPaths(Path::kSse42 | Path::kAvx2 | - Path::kAvx512 | Path::kAvxVnni); - const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths(); - EXPECT_EQ(ruy_paths & Path::kReference, Path::kNone); - EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kNone); -} -#endif // RUY_PLATFORM(X86) - -#if RUY_PLATFORM(ARM) -TEST(ContextTest, EnabledPathsArm) { - ruy::Context ruy_context; - ruy_context.SetRuntimeEnabledPaths(Path::kNeon | Path::kNeonDotprod); - const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths(); - EXPECT_EQ(ruy_paths & Path::kReference, Path::kNone); - EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kNone); - EXPECT_EQ(ruy_paths & Path::kNeon, Path::kNeon); -} -#endif // RUY_PLATFORM(ARM) - -} // namespace -} // namespace ruy - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/cpu_cache_size.h b/cpu_cache_size.h deleted file mode 100644 index 98d4864..0000000 --- a/cpu_cache_size.h +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_CPU_CACHE_SIZE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_CPU_CACHE_SIZE_H_ - -#include "path.h" -#include "platform.h" - -namespace ruy { - -// LocalDataCacheSize returns a sane default size for each CPU core's local -// data cache, i.e. the largest data cache that is local to that CPU core, not -// shared with other cores. That allows coarse tuning of code that aims for -// most of its memory accesses to hit such a typically fast data cache. -// -// SharedDataCacheSize returns a sane default size of the total data cache -// accessible to each CPU, including any shared cache. -// -// For example, if we design tune this code for a ARM Cortex-A55 with a local L1 -// cache of 32k, a local L2 cache of 128k and a shared L3 cache of 1M, -// LocalDataCacheSize should return 128k and SharedDataCacheSize -// should return 1M. -// -// Ideally these values would be queried at runtime, and we should probably -// do that on x86, but that is hard to do on ARM. 
-#if RUY_PLATFORM(ARM_64) -inline int LocalDataCacheSize() { return 1 << 15; } -inline int SharedDataCacheSize() { return 1 << 19; } -#elif RUY_PLATFORM(ARM_32) -inline int LocalDataCacheSize() { return 1 << 14; } -inline int SharedDataCacheSize() { return 1 << 18; } -#elif RUY_PLATFORM(X86) -inline int LocalDataCacheSize() { return 1 << 17; } -inline int SharedDataCacheSize() { return 1 << 21; } -#else -inline int LocalDataCacheSize() { return 1 << 14; } -inline int SharedDataCacheSize() { return 1 << 18; } -#endif -// Variants taking a Path argument which acts -// as a hint telling whether we're targeting more or less recent/powerful CPUs. -inline int LocalDataCacheSize(Path path) { -#if RUY_PLATFORM(ARM_64) - if (path == Path::kNeonDotprod) { - // At the moment, the smallest CPU with dotprod is probably Cortex-A55 with - // 128k L2 local cache. - return 1 << 17; - } -#else - (void)path; -#endif - return LocalDataCacheSize(); -} -inline int SharedDataCacheSize(Path path) { -#if RUY_PLATFORM(ARM_64) - if (path == Path::kNeonDotprod) { - // At the moment, the smallest CPU with dotprod is probably Cortex-A55 with - // 1M L3 shared cache. - return 1 << 20; - } -#else - (void)path; -#endif - return SharedDataCacheSize(); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_CPU_CACHE_SIZE_H_ diff --git a/detect_arm.cc b/detect_arm.cc deleted file mode 100644 index 3d39360..0000000 --- a/detect_arm.cc +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/* Detection of dotprod instructions on ARM. - * The current Linux-specific code relies on sufficiently new Linux kernels: - * At least Linux 4.15 in general; on Android, at least Linux 4.14.111 thanks to - * a late backport. This was backported just before the Android 10 release, so - * this is leaving out pre-release Android 10 builds as well as earlier Android - * versions. - * - * It is possible to detect instructions in other ways that don't rely on - * an OS-provided feature identification mechanism: - * - * (A) We used to have a SIGILL-handler-based method that worked at least - * on Linux. Its downsides were (1) crashes on a few devices where - * signal handler installation didn't work as intended; (2) additional - * complexity to generalize to other Unix-ish operating systems including - * iOS; (3) source code complexity and fragility of anything installing - * and restoring signal handlers; (4) confusing behavior under a debugger. - * - * (B) We also experimented with a fork-ing approach where a subprocess - * tries the instruction. Compared to (A), this is much simpler and more - * reliable and portable, but also much higher latency on Android where - * an uncaught signal typically causes a 100 ms latency. 
- * - * Should there be interest in either technique again in the future, - * code implementing both (A) and (B) can be found in earlier revisions of this - * file - in actual code for (A) and in a comment for (B). - */ - -#include "detect_arm.h" - -#if defined __linux__ && defined __aarch64__ -#include -#endif - -namespace ruy { - -namespace { - -#if defined __linux__ && defined __aarch64__ -bool DetectDotprodByLinuxAuxvMethod() { - // This is the value of HWCAP_ASIMDDP in sufficiently recent Linux headers, - // however we need to support building against older headers for the time - // being. - const int kLocalHwcapAsimddp = 1 << 20; - return getauxval(AT_HWCAP) & kLocalHwcapAsimddp; -} -#endif - -} // namespace - -bool DetectDotprod() { -#if defined __linux__ && defined __aarch64__ - return DetectDotprodByLinuxAuxvMethod(); -#endif - - return false; -} - -} // namespace ruy diff --git a/detect_arm.h b/detect_arm.h deleted file mode 100644 index e843a68..0000000 --- a/detect_arm.h +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Temporary dotprod-detection code until we can rely on getauxval. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_DETECT_ARM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_DETECT_ARM_H_ - -namespace ruy { - -// On A64, returns true if the dotprod extension is present. -// On other architectures, returns false unconditionally. -bool DetectDotprod(); - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_DETECT_ARM_H_ diff --git a/detect_x86.cc b/detect_x86.cc deleted file mode 100644 index 7477ea3..0000000 --- a/detect_x86.cc +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "detect_x86.h" - -#include - -#if RUY_PLATFORM(X86) && RUY_PLATFORM(X86_ENHANCEMENTS) -#include // IWYU pragma: keep - -#endif - -namespace ruy { -#if RUY_PLATFORM(X86) && RUY_PLATFORM(X86_ENHANCEMENTS) - -namespace { - -// See Intel docs, such as http://goo.gl/c6IkGX. 
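Once toolchain headers can be relied upon (the comments above note that they currently cannot), the same check reduces to the standard getauxval pattern. A sketch, assuming <asm/hwcap.h> provides HWCAP_ASIMDDP as on recent aarch64 Linux headers:

    #include <asm/hwcap.h>  // HWCAP_ASIMDDP (sufficiently new headers only)
    #include <sys/auxv.h>   // getauxval, AT_HWCAP

    bool DetectDotprodWithHwcapHeader() {
      return (getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0;
    }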
-inline void RunCpuid(std::uint32_t eax, std::uint32_t ecx, - std::uint32_t abcd[4]) { - std::uint32_t ebx, edx; -#if defined(__i386__) && defined(__PIC__) - /* in case of PIC under 32-bit EBX cannot be clobbered */ - asm volatile("movl %%ebx, %%edi \n\t cpuid \n\t xchgl %%ebx, %%edi" - : "=D"(ebx), -#else - asm volatile("cpuid" - : "+b"(ebx), -#endif - "+a"(eax), "+c"(ecx), "=d"(edx)); - abcd[0] = eax; - abcd[1] = ebx; - abcd[2] = ecx; - abcd[3] = edx; -} - -} // namespace - -bool DetectCpuSse42() { - std::uint32_t abcd[4]; - - constexpr std::uint32_t kEcxSse42 = 1u << 20; - RunCpuid(1, 0, abcd); - const bool has_sse4_2_base = (abcd[2] & kEcxSse42) == kEcxSse42; - -#ifdef RUY_ENABLE_AMD_CPUID_CHECKS - constexpr std::uint32_t kEcxAbm = 1u << 5; - RunCpuid(0x80000001, 0, abcd); - const bool has_extras = (abcd[2] & kEcxAbm) == kEcxAbm; -#else - constexpr std::uint32_t kEcxPopcnt = 1u << 23; - RunCpuid(1, 0, abcd); - const bool has_extras = (abcd[2] & kEcxPopcnt) == kEcxPopcnt; -#endif - - return has_sse4_2_base && has_extras; -} - -bool DetectCpuAvx2() { - constexpr std::uint32_t kEbxAvx2 = 1u << 5; - constexpr std::uint32_t kEcxFma = 1u << 12; - - std::uint32_t abcd[4]; - - RunCpuid(7, 0, abcd); - const bool has_avx2 = (abcd[1] & kEbxAvx2) == kEbxAvx2; - RunCpuid(1, 0, abcd); - const bool has_fma = (abcd[2] & kEcxFma) == kEcxFma; - - return has_avx2 && has_fma; -} - -bool DetectCpuAvx512() { - constexpr std::uint32_t kEbxAvx512F = 1u << 16; - constexpr std::uint32_t kEbxAvx512Dq = 1u << 17; - constexpr std::uint32_t kEbxAvx512Cd = 1u << 28; - constexpr std::uint32_t kEbxAvx512Bw = 1u << 30; - constexpr std::uint32_t kEbxAvx512Vl = 1u << 31; - - constexpr std::uint32_t kEbxAvx512Mask = - kEbxAvx512F | kEbxAvx512Dq | kEbxAvx512Cd | kEbxAvx512Bw | kEbxAvx512Vl; - std::uint32_t abcd[4]; - RunCpuid(7, 0, abcd); - - return (abcd[1] & kEbxAvx512Mask) == kEbxAvx512Mask; -} - -#endif -} // namespace ruy diff --git a/detect_x86.h b/detect_x86.h deleted file mode 100644 index d330d05..0000000 --- a/detect_x86.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_DETECT_X86_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_DETECT_X86_H_ - -#include "platform.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -#if RUY_PLATFORM(X86_ENHANCEMENTS) - -// This also checks ABM support, which implies LZCNT and POPCNT. -bool DetectCpuSse42(); -bool DetectCpuAvx2(); -bool DetectCpuAvx512(); -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// TODO(b/146646451): Introduce and activate. 
-inline bool DetectCpuAvxVnni() { return false; } - -#else // RUY_PLATFORM(X86_ENHANCEMENTS) - -inline bool DetectCpuSse42() { return false; } -inline bool DetectCpuAvx2() { return false; } -inline bool DetectCpuAvx512() { return false; } -inline bool DetectCpuAvxVnni() { return false; } - -#endif // !RUY_PLATFORM(X86_ENHANCEMENTS) -#endif // RUY_PLATFORM(X86) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_DETECT_X86_H_ diff --git a/dispatch.h b/dispatch.h deleted file mode 100644 index 3b9c8b2..0000000 --- a/dispatch.h +++ /dev/null @@ -1,482 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements the translation between Ruy's entry point (ruy::Mul) and -// the internal implementation of matrix multiplication. -// -// The primary elements of this dispatch are: -// - pick suitable gemm kernel and packing routines for the user-specified -// CompiledPaths based on the current CPU. -// - decide on the structure of the packed matrices needed by the internal -// implementation (see pack.h for more information on packing). -// - translate the Mul operation into TrMul (see trmul.h for why that is -// useful). This is done by changing the matrix Layout -- no matrix data is -// actually moved. -// -// This file is also factored to serve as a building block for the advanced API -// as well. -// -// This file also performs some checking of invariants to catch user errors. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_DISPATCH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_DISPATCH_H_ - -#include -#include -#include // IWYU pragma: keep -#include - -#include "check_macros.h" -#include "common.h" -#include "context.h" -#include "internal_matrix.h" -#include "kernel.h" -#include "kernel_common.h" -#include "matrix.h" -#include "opt_set.h" -#include "pack.h" -#include "pack_common.h" -#include "path.h" -#include "profiler/instrumentation.h" -#include "side_pair.h" -#include "size_util.h" -#include "spec.h" -#include "trmul.h" -#include "trmul_params.h" - -namespace ruy { - -// If the Spec's LayoutSupport covers only some special cases, -// this function enforces that the matrix multiplication at hand falls into -// that special case. 
-template -void EnforceLayoutSupport(const Layout& lhs_layout, const Layout& rhs_layout, - const Layout& dst_layout) { - if (Spec::kLayoutSupport == LayoutSupport::kRCC) { - RUY_DCHECK(IsRowMajor(lhs_layout)); - RUY_DCHECK(IsColMajor(rhs_layout)); - RUY_DCHECK(IsColMajor(dst_layout)); - } -} - -template -bool IsSymmetricZeroPoint(Scalar zero_point) { - return zero_point == SymmetricZeroPoint(); -} - -template -void CheckZeroPoint(Scalar zero_point) { - if (std::is_floating_point::value || - Spec::kZeroPointSupport == ZeroPointSupport::kSymmetric) { - RUY_DCHECK(IsSymmetricZeroPoint(zero_point)); - } -} - -template -void EnforceZeroPointSupport(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point, - DstScalar dst_zero_point) { - // If the Spec's ZeroPointSupport covers only some special cases, - // this function enforces that the matrix multiplication at hand falls into - // that special case. - CheckZeroPoint(lhs_zero_point); - CheckZeroPoint(rhs_zero_point); - CheckZeroPoint(dst_zero_point); - - // Guard against the case when both LHS and RHS zero_point's are equal to - // the minimum representable value. In that case, padding with zero_point - // values will generate the bad case for fast int8 kernels on NEON - // (pre-dotprod) which attempt to multiply-accumulate two pairs of int8 - // into a int16: this is safe except in the bad case -128*-128 + -128*-128. - // See b/131609283. This only affects the kNeon path but we ban this for all - // paths in order for ruy to have the same supported parameter space - // on all paths. - RUY_DCHECK(lhs_zero_point != std::numeric_limits::lowest() || - rhs_zero_point != std::numeric_limits::lowest()); -} - -template -void EnforceDstSpecSupport(const Spec& spec, DstScalar dst_zero_point) { - static_assert(std::is_same::value, ""); - if (!std::is_same::value) return; - - // If user is looking for the raw accumulator, zero_point and all the other - // dequantize fields don't make sense and should not be set. - RUY_DCHECK_EQ(dst_zero_point, 0); - RUY_DCHECK_EQ(spec.clamp_max, std::numeric_limits::max()); - RUY_DCHECK_EQ(spec.clamp_min, std::numeric_limits::min()); - RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0); - RUY_DCHECK_EQ(spec.multiplier_exponent, 0); - RUY_DCHECK_EQ(spec.multiplier_fixedpoint_perchannel, nullptr); - RUY_DCHECK_EQ(spec.multiplier_exponent_perchannel, nullptr); -} - -inline bool IsColMajorTrMul(const TrMulParams& params) { - return IsColMajor(params.src[Side::kLhs].layout) && - IsColMajor(params.src[Side::kRhs].layout) && - IsColMajor(params.dst.layout); -} - -inline void CreatePackedLayout(const Layout& src, const Type& scalar, - const KernelLayout& kernel_layout, - PackedLayout* packed) { - packed->order = Order::kColMajor; - packed->rows = round_up_pot(src.rows, kernel_layout.rows); - packed->cols = round_up_pot(src.cols, kernel_layout.cols); - packed->kernel = kernel_layout; - int inner_size = packed->rows; - if (RUY_OPT_ENABLED(RUY_OPT_AVOID_ALIASING)) { - packed->stride = - (inner_size * scalar.size) % 1024 ? inner_size : inner_size + 64; - } else { - packed->stride = inner_size; - } -} - -template -void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout, - TrMulParams* params) { - // Ruy always uses 32-bit signed accumulators for quantized - // matrix multiplication, so we would like to always use std::int32_t - // unconditionally for SumsType. - // However, for floating point types, we still need a reasonable type here to - // avoid tripping assertions elsewhere in the code. 
- using SumsType = - typename std::conditional::value, Scalar, - std::int32_t>::type; - - const DMatrix& src = params->src[side]; - PMatrix* packed = ¶ms->packed[side]; - packed->data_type = Type::Create(); - packed->sums_type = Type::Create(); - CreatePackedLayout(src.layout, packed->data_type, kernel_layout, - &packed->layout); - packed->zero_point = Pack(src.zero_point); -} - -template -void PopulateTrMulParams(TrMulParams* params) { - static_assert((ThePath & Path::kReference) == Path::kNone, - "Path::kReference should not do TrMul"); - // The optimized code paths don't handle the full generality of Ruy's API. - // Fall back to Path::kStandardCpp if necessary. - bool fallback_to_standard_cpp = false; - if (ThePath != Path::kStandardCpp) { - // The optimized code paths currently only handle the case of all matrices - // being column major. - if (!IsColMajorTrMul(*params)) { - fallback_to_standard_cpp = true; - } - } - - if (fallback_to_standard_cpp) { - PopulateTrMulParams(params); - return; - } - - using PackedLhsScalar = PackedType; - using PackedRhsScalar = PackedType; - using Kernel = - Kernel; - using LhsKernelLayout = typename Kernel::LhsLayout; - using RhsKernelLayout = typename Kernel::RhsLayout; - - params->path = ThePath; - - params->local_data_cache_size = Spec::local_data_cache_size(); - params->shared_data_cache_size = Spec::shared_data_cache_size(); - - CreatePackedMatrix( - Side::kLhs, ToKernelLayout(), params); - CreatePackedMatrix( - Side::kRhs, ToKernelLayout(), params); - params->run_pack[Side::kLhs] = - &RunPack; - params->run_pack[Side::kRhs] = - &RunPack; - params->run_kernel = - &RunKernel; - - return; -} - -// PopulateTrMulParamsAllCompiledPaths calls into one of multiple -// instantiations of PopulateTrMulParams. For each bit that is set in -// CompiledPaths, it statically instantiates PopulateTrMulParams with a Path -// corresponding to that single bit. The call to PopulateTrMulParams is -// guarded by a runtime check that it is in fact the dynamically selected path. -// -// PopulateTrMulParamsAllCompiledPaths is implemented with template -// metaprogramming by mutual recursion between PathSearchCountdown and -// PathSearchCompiledPaths. -// -// PopulateTrMulParamsAllCompiledPaths is logically implementing the following -// computation: -// -// template -// void PopulateTrMulParamsAllCompiledPaths(Path the_path, -// TrMulParams* params) { -// for (int bit = 8 * sizeof(Path) - 1; bit != -1; bit--) { // [1] -// Path current_path = static_cast(1 << bit); -// if ((CompiledPaths & current_path) != Path::kNone) { // [2] -// if (current_path == the_path) { // [3] -// PopulateTrMulParams(the_path, params); -// return; -// } -// } -// } -// } -// -// -// -// [1] - Done by the main definition of PathSearchCountdown. The `bit--` is -// done in the recursion of PathSearchOnlyCompiledPaths. -// [2] - Done by PathSearchOnlyCompiledPaths's partial template -// specialization on InCompiledPaths. This is the check which necessitates -// doing the whole computation at C++ compile time. -// [3] - Done by the `if` in the main definition of -// PathSearchOnlyCompiledPaths. -// -// The template metaprogramming is necessary because: -// - In `PopulateTrMulParams`, current_path must be a C++ -// compile-time constant. -// - PopulateTrMulParamsAllCompiledPaths must not instantiate -// inner loops for paths that are not in CompiledPaths, since that can result in -// bogus instantiations which cause a compile time failure. 
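To make the mutual recursion concrete, here is a small self-contained sketch of the same compile-time search pattern, using an invented ToyPath enum and a trivial DoWork() in place of ruy's Path and PopulateTrMulParams. It illustrates the technique only and is not ruy's actual code.

#include <cstdio>

// Toy stand-ins; all names here are invented for this sketch.
enum class ToyPath : int { kNone = 0, kA = 1, kB = 2, kC = 4 };
constexpr ToyPath operator|(ToyPath x, ToyPath y) {
  return static_cast<ToyPath>(static_cast<int>(x) | static_cast<int>(y));
}
constexpr ToyPath operator&(ToyPath x, ToyPath y) {
  return static_cast<ToyPath>(static_cast<int>(x) & static_cast<int>(y));
}

template <ToyPath ThePath>
void DoWork() {
  std::printf("dispatched to path %d\n", static_cast<int>(ThePath));
}

template <ToyPath CompiledPaths, int BitNumber>
struct Countdown;

// Primary case: the current bit IS in CompiledPaths, so compare it against
// the runtime-selected path and dispatch on a match.
template <ToyPath CompiledPaths, bool InCompiledPaths, int BitNumber>
struct OnlyCompiledPaths {
  static constexpr ToyPath kCurrent = static_cast<ToyPath>(1 << BitNumber);
  static void Search(ToyPath the_path) {
    if (kCurrent == the_path) {
      DoWork<kCurrent>();
      return;
    }
    Countdown<CompiledPaths, BitNumber - 1>::Search(the_path);
  }
};

// Skip case: the current bit is NOT in CompiledPaths, so DoWork is never
// instantiated for it.
template <ToyPath CompiledPaths, int BitNumber>
struct OnlyCompiledPaths<CompiledPaths, false, BitNumber> {
  static void Search(ToyPath the_path) {
    Countdown<CompiledPaths, BitNumber - 1>::Search(the_path);
  }
};

template <ToyPath CompiledPaths, int BitNumber>
struct Countdown {
  static constexpr ToyPath kCurrent = static_cast<ToyPath>(1 << BitNumber);
  static void Search(ToyPath the_path) {
    OnlyCompiledPaths<CompiledPaths,
                      (CompiledPaths & kCurrent) != ToyPath::kNone,
                      BitNumber>::Search(the_path);
  }
};

// Termination: no compiled path matched; real code would assert here.
template <ToyPath CompiledPaths>
struct Countdown<CompiledPaths, -1> {
  static void Search(ToyPath) {}
};

int main() {
  // Pretend only kA and kC were compiled in; at runtime, ask for kC.
  Countdown<ToyPath::kA | ToyPath::kC, 2>::Search(ToyPath::kC);
}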
-template -struct PathSearchCountdown; - -template -struct PathSearchOnlyCompiledPaths { - static constexpr Path kCurrentPath = static_cast(1 << BitNumber); - static void Search(Path the_path, TrMulParams* params) { - if (kCurrentPath == the_path) { - PopulateTrMulParams( - params); - return; - } - PathSearchCountdown::Search(the_path, params); - } -}; - -// Skip this iteration if CompiledPaths doesn't contain the specified path. -template -struct PathSearchOnlyCompiledPaths { - static void Search(Path the_path, TrMulParams* params) { - PathSearchCountdown::Search(the_path, params); - } -}; - -template -struct PathSearchCountdown { - static constexpr Path kCurrentPath = static_cast(1 << BitNumber); - static void Search(Path the_path, TrMulParams* params) { - PathSearchOnlyCompiledPaths< - CompiledPaths, (CompiledPaths & kCurrentPath) != Path::kNone, BitNumber, - LhsScalar, RhsScalar, DstScalar, Spec>::Search(the_path, params); - } -}; - -// Termination of the countdown. If the counter reaches -1, then we haven't -// found the specified path. -template -struct PathSearchCountdown { - static void Search(Path the_path, TrMulParams* params) { RUY_DCHECK(false); } -}; - -template -void PopulateTrMulParamsAllCompiledPaths(Path the_path, TrMulParams* params) { - return PathSearchCountdown::Search(the_path, - params); -} - -template -void CreateTrMulParams(const Matrix& lhs, - const Matrix& rhs, const Spec& spec, - Context* context, Matrix* dst, Path the_path, - TrMulParams* params) { - // Fill in the fields we already know. - params->src[Side::kLhs] = ToDMatrix(lhs); - params->src[Side::kRhs] = ToDMatrix(rhs); - params->dst = ToDMatrix(*dst); - params->spec = ToVoidPtr(&spec); - - // Create inner loops and packed matrices based on the Path. - PopulateTrMulParamsAllCompiledPaths(the_path, params); -} - -template -void ReferenceMul(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Matrix* dst) { - profiler::ScopeLabel label("ReferenceMul"); - for (int i = 0; i < lhs.layout.rows; i++) { - for (int j = 0; j < rhs.layout.cols; j++) { - using AccumScalar = typename Spec::AccumScalar; - AccumScalar accum = 0; - for (int k = 0; k < lhs.layout.cols; k++) { - AccumScalar lhs_val = Element(lhs, i, k); - AccumScalar rhs_val = Element(rhs, k, j); - accum += (lhs_val - lhs.zero_point) * (rhs_val - rhs.zero_point); - } - if (spec.bias) { - accum += spec.bias[i]; - } - ApplyMultiplier(spec, i, &accum); - accum += dst->zero_point; - accum = std::min(accum, spec.clamp_max); - accum = std::max(accum, spec.clamp_min); - *ElementPtr(dst, i, j) = static_cast(accum); - } - } -} - -// Compile-time dispatch to ReferenceMul. This allows us to statically ensure -// that there is no call to ReferenceMul in the user's binary. -template -struct CompileTimeEnabledReferenceMul { - template - static void Run(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Matrix* dst) { - ReferenceMul(lhs, rhs, spec, dst); - } -}; - -// When this partial specialization is chosen, it ensures that ReferenceMul -// is never compiled. -template <> -struct CompileTimeEnabledReferenceMul { - template - static void Run(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Matrix* dst) { - RUY_DCHECK(false); - } -}; - -inline void HandlePrepackedCaching(TrMulParams* params, - const SidePair& cacheable, - Context* context) { - if (context->cache_policy == CachePolicy::kNoCache) { - return; - } - - if (context->cache_policy == CachePolicy::kCacheLHSOnNarrowMul) { - // TODO(b/149304278) Cache on dst.cols <= selected kernel width. 
- if (!cacheable[Side::kLhs] || params->dst.layout.cols > 4) { - return; - } - PrepackedCache* prepacked_cache = context->GetPrepackedCache(); - auto cache_key = std::make_pair(reinterpret_cast(params->run_kernel), - params->src[Side::kLhs].data); - auto it = prepacked_cache->FindAndUpdate(cache_key); - if (it != prepacked_cache->cend()) { - params->packed[Side::kLhs].data = it->second.first.data; - params->packed[Side::kLhs].sums = it->second.first.sums; - params->is_prepacked[Side::kLhs] = true; - return; - } - - // Allocate the prepacked matrix. - PrepackedMatrix prepacked_lhs; - prepacked_lhs.data_size = DataSize(params->packed[Side::kLhs]); - prepacked_lhs.sums_size = SumsSize(params->packed[Side::kLhs]); - prepacked_cache->AllocatePrepackedMatrix(&prepacked_lhs); - params->packed[Side::kLhs].data = prepacked_lhs.data; - params->packed[Side::kLhs].sums = prepacked_lhs.sums; - params->is_prepacked[Side::kLhs] = true; - Tuning tuning = context->GetMainThreadTuning(); - params->RunPack(Side::kLhs, tuning, 0, - params->packed[Side::kLhs].layout.cols); - prepacked_cache->Insert(cache_key, prepacked_lhs); - return; - } -} - -template -void DispatchMul(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Context* context, Matrix* dst) { - static_assert(CompiledPaths != Path::kNone, "Must compile at least one Path"); - static_assert((CompiledPaths & ~kAllPaths) == Path::kNone, - "CompiledPaths must be a subset of ruy::kAllPaths"); - - profiler::ScopeLabel mul_label("Mul"); - profiler::ScopeLabel shape_specific_label("matmul shape: %dx%dx%d", - lhs.layout.rows, lhs.layout.cols, - rhs.layout.cols); - - EnforceLayoutSupport(lhs.layout, rhs.layout, dst->layout); - EnforceZeroPointSupport(lhs.zero_point, rhs.zero_point, - dst->zero_point); - EnforceDstSpecSupport(spec, dst->zero_point); - - // This should be a constant, for a given machine and CompiledPaths. - // There is a back door to override it for testing, but in production it will - // always be the "best" Path. I.e. the one with the newest SIMD instructions - // available on the present machine, and avoiding Path::kReference unless - // no other path is compiled. - // - // Unfortunately, it is not a *static* constant, since it depends on runtime - // detection of the available SIMD instructions. - Path the_path = context->GetPathToTake(); - - // Production code should probably never execute Path::kReference. - // Path::kReference implements a Mul, not a TrMul like the rest of Ruy, so if - // that's what we need to do, then get it out of the way before going down the - // TrMul path. - if (the_path == Path::kReference) { - constexpr bool ReferenceMulIsEnabled = - (CompiledPaths & Path::kReference) != Path::kNone; - CompileTimeEnabledReferenceMul::Run(lhs, rhs, spec, - dst); - return; - } - - // As described in the comment at the top of this file, Ruy internally - // converts Mul into TrMul. We handle that here. - // - // This is Ruy's main code path. 
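A minimal standalone illustration (not ruy code) of why this conversion costs nothing: reinterpreting a buffer's storage order is already a transpose, so "transposing" the LHS is purely a Layout edit and no element is copied.

#include <cstdio>

// The same 4-element buffer viewed row-major is the transpose of the
// buffer viewed column-major. This is why Mul(LHS, RHS) can become
// TrMul(transposed LHS, RHS) by flipping LHS's storage order and swapping
// its rows/cols fields, leaving the data pointer untouched.
int main() {
  const float buf[4] = {1, 2, 3, 4};
  // Row-major 2x2 view: element (r, c) = buf[r * 2 + c].
  std::printf("row-major view: [[%g %g] [%g %g]]\n",
              buf[0 * 2 + 0], buf[0 * 2 + 1], buf[1 * 2 + 0], buf[1 * 2 + 1]);
  // Column-major 2x2 view: element (r, c) = buf[c * 2 + r].
  // This prints the transpose of the matrix above.
  std::printf("col-major view: [[%g %g] [%g %g]]\n",
              buf[2 * 0 + 0], buf[2 * 1 + 0], buf[2 * 0 + 1], buf[2 * 1 + 1]);
}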
- constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; - Matrix transposed_lhs(lhs); - Transpose(&transposed_lhs); - TrMulParams params; - CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, - the_path, ¶ms); - SidePair cacheable(lhs.cacheable, rhs.cacheable); - HandlePrepackedCaching(¶ms, cacheable, context); - TrMul(¶ms, context); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_DISPATCH_H_ diff --git a/example.cc b/example.cc deleted file mode 100644 index ce3fd81..0000000 --- a/example.cc +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "ruy.h" - -void ExampleMulFloat(ruy::Context *context) { - const float lhs_data[] = {1, 2, 3, 4}; - const float rhs_data[] = {1, 2, 3, 4}; - float dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - - ruy::BasicSpec spec; - ruy::Mul(lhs, rhs, spec, context, &dst); - - std::cout << "Example Mul, float:\n"; - std::cout << "LHS:\n" << lhs; - std::cout << "RHS:\n" << rhs; - std::cout << "Result:\n" << dst << "\n"; -} - -void ExampleMulFloatWithBiasAddAndClamp(ruy::Context *context) { - const float lhs_data[] = {1, 2, 3, 4}; - const float rhs_data[] = {1, 2, 3, 4}; - const float bias_data[] = {1, 0}; - float dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - - ruy::BasicSpec spec; - spec.bias = bias_data; - spec.clamp_min = 0; - spec.clamp_max = 15; - ruy::Mul(lhs, rhs, spec, context, &dst); - - std::cout << "Example Mul, float with bias addition and clamp:\n"; - std::cout << "LHS:\n" << lhs; - std::cout << "RHS:\n" << rhs; - std::cout << "Result:\n" << dst << "\n"; -} - -void ExampleMulUint8AsymmetricQuantized(ruy::Context *context) { - const std::uint8_t lhs_data[] = {124, 125, 126, 127}; - const std::uint8_t rhs_data[] = {129, 130, 131, 132}; - std::uint8_t dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - lhs.zero_point = 125; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - rhs.zero_point = 132; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - dst.zero_point = 129; - - ruy::BasicSpec spec; - spec.multiplier_fixedpoint = 1 << 30; - - spec.multiplier_exponent = 0; - 
ruy::Mul(lhs, rhs, spec, context, &dst); - - std::cout << "Example Mul, uint8 quantized with asymmetric zero points:\n"; - std::cout << "LHS:\n" << lhs; - std::cout << "RHS:\n" << rhs; - std::cout << "Result:\n" << dst << "\n"; -} -void ExampleMulInt8PerChannelQuantized(ruy::Context *context) { - const std::int8_t lhs_data[] = {1, 2, 3, 4}; - const std::int8_t rhs_data[] = {1, 2, 3, 4}; - const std::int32_t multiplier_data[] = {3 << 28, 5 << 28}; - const int exponent_data[] = {1, -2}; - std::int8_t dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - - ruy::BasicSpec spec; - spec.multiplier_fixedpoint_perchannel = multiplier_data; - spec.multiplier_exponent_perchannel = exponent_data; - ruy::Mul(lhs, rhs, spec, context, &dst); - - std::cout << "Example Mul, int8 quantized with per-channel multipliers\n"; - std::cout << "LHS:\n" << lhs; - std::cout << "RHS:\n" << rhs; - std::cout << "Result:\n" << dst << "\n"; -} - -int main() { - ruy::Context context; - ExampleMulFloat(&context); - ExampleMulFloatWithBiasAddAndClamp(&context); - ExampleMulUint8AsymmetricQuantized(&context); - ExampleMulInt8PerChannelQuantized(&context); -} diff --git a/example_advanced.cc b/example_advanced.cc deleted file mode 100644 index 90a6473..0000000 --- a/example_advanced.cc +++ /dev/null @@ -1,83 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include "ruy_advanced.h" - -// Simple allocator for allocating pre-packed matrices. -class SimpleAllocator { - public: - void* AllocateBytes(std::size_t num_bytes) { - char* p = new char[num_bytes]; - buffers_.emplace_back(p); - return static_cast(p); - } - - private: - std::vector> buffers_; -}; - -void ExamplePrepack(ruy::Context* context) { - const float lhs_data[] = {1, 2, 3, 4}; - const float rhs_data[] = {1, 2, 3, 4}; - float dst_data[4]; - - // Set up the matrix layouts and spec. - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); - ruy::BasicSpec spec; - - SimpleAllocator allocator; - auto alloc_fn = [&allocator](std::size_t num_bytes) -> void* { - return allocator.AllocateBytes(num_bytes); - }; - - // In this example, we pre-pack only the RHS, but either will work. - // Note that we only need to set the data pointer for the matrix we are - // pre-packing. 
- ruy::PrepackedMatrix prepacked_rhs; - rhs.data = rhs_data; - ruy::PrePackForMul(lhs, rhs, spec, context, &dst, - /*prepacked_lhs=*/nullptr, &prepacked_rhs, - alloc_fn); - - // No data will be read from the RHS input matrix when using a pre-packed RHS. - rhs.data = nullptr; - lhs.data = lhs_data; - dst.data = dst_data; - ruy::MulWithPrepacked(lhs, rhs, spec, context, &dst, - /*prepacked_lhs=*/nullptr, - &prepacked_rhs); - rhs.data = rhs_data; - - // Print out the results. - std::cout << "Example Mul with pre-packing RHS, float:\n"; - std::cout << "LHS:\n" << lhs; - std::cout << "RHS:\n" << rhs; - std::cout << "Result:\n" << dst << "\n"; -} - -int main() { - ruy::Context context; - ExamplePrepack(&context); -} diff --git a/have_built_path_for.h b/have_built_path_for.h deleted file mode 100644 index 98c6af5..0000000 --- a/have_built_path_for.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_HAVE_BUILT_PATH_FOR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_HAVE_BUILT_PATH_FOR_H_ - -#include "platform.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -bool HaveBuiltPathForSse42(); -bool HaveBuiltPathForAvx2(); -bool HaveBuiltPathForAvx512(); -bool HaveBuiltPathForAvxVnni(); -#endif // RUY_PLATFORM(X86) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_HAVE_BUILT_PATH_FOR_H_ diff --git a/have_built_path_for_avx2.cc b/have_built_path_for_avx2.cc deleted file mode 100644 index 33d1b1c..0000000 --- a/have_built_path_for_avx2.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "have_built_path_for.h" -#include "opt_set.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// IMPORTANT: -// These patterns must match those in the pack and kernel cc files. 
-#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -bool HaveBuiltPathForAvx2() { return false; } - -#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -bool HaveBuiltPathForAvx2() { return true; } - -#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#endif // RUY_PLATFORM(X86) - -} // namespace ruy diff --git a/have_built_path_for_avx512.cc b/have_built_path_for_avx512.cc deleted file mode 100644 index 35c4095..0000000 --- a/have_built_path_for_avx512.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "have_built_path_for.h" -#include "opt_set.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// IMPORTANT: -// These patterns must match those in the pack and kernel cc files. -#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -bool HaveBuiltPathForAvx512() { return false; } - -#else // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -bool HaveBuiltPathForAvx512() { return true; } - -#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#endif // RUY_PLATFORM(X86) - -} // namespace ruy diff --git a/have_built_path_for_avxvnni.cc b/have_built_path_for_avxvnni.cc deleted file mode 100644 index 5c642a3..0000000 --- a/have_built_path_for_avxvnni.cc +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "have_built_path_for.h" -#include "opt_set.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// IMPORTANT: -// These patterns must match those in the pack and kernel cc files. -#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -bool HaveBuiltPathForAvxVnni() { return false; } - -#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -bool HaveBuiltPathForAvxVnni() { return true; } - -#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#endif // RUY_PLATFORM(X86) - -} // namespace ruy diff --git a/have_built_path_for_sse42.cc b/have_built_path_for_sse42.cc deleted file mode 100644 index 04c7e6b..0000000 --- a/have_built_path_for_sse42.cc +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2019 Google LLC. 
All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "have_built_path_for.h" -#include "opt_set.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// IMPORTANT: -// These patterns must match those in the pack and kernel cc files. -#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -bool HaveBuiltPathForSse42() { return false; } - -#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -bool HaveBuiltPathForSse42() { return true; } - -#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#endif // RUY_PLATFORM(X86) - -} // namespace ruy diff --git a/internal_matrix.h b/internal_matrix.h deleted file mode 100644 index 586fa8b..0000000 --- a/internal_matrix.h +++ /dev/null @@ -1,388 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Internal types and helpers for matrices. -// -// Ruy has a couple slightly different notions of matrices, besides the -// Matrix class that we expose to the user-facing API. -// -// TODO(silvasean): Put parts of this architecture description somewhere more -// prominent. -// -// The 4 main matrix types are: -// - Matrix: This is a user-facing type on Ruy's external API boundary. It is -// also used internally. -// - DMatrix: This is a type-erased version of Matrix. "D" = "dynamic". -// - PMatrix: This represents a packed matrix, which requires tracking kernel -// layout and row/column sums for quantization. It is type-erased. -// - PackedMatrix: This is a statically typed variant of PMatrix for -// convenience inside typed routines. -// -// Note that Matrix is *not* implemented in terms of the internal types. It -// is an independent, simple, and user-facing type. -// -// The use of type-erasure might seem surprising for a library like Ruy with a -// heavily-templated entry point, but it is motivated by the desire for most of -// Ruy's "middle-end" to be non-templated. Ruy can be thought of as having 3 -// main parts: -// - "front-end" (dispatch.h) - this is the highly templated ruy::Mul entry -// point, along with routines that select RunKernel and RunPack implementations -// statically based on those template parameters. 
-// - "back-end" (kernel.h, pack.h)- this consists of the implementations of -// RunKernel and RunPack, often in assembly code, which are the building blocks -// that Ruy calls to perform matrix multiplication. These are templated so that -// only the requested types/Path's are actually emitted by the compiler. -// - "middle-end" (trmul.h) - this is the part of Ruy that orchestrates the -// calls to the "back-end" optimized building blocks. This layer has to deal -// with issues like cache locality and low-overhead multi-threading. -// -// There is a desire for the "middle-end" to be non-templated in order to -// simplify the implementation and reduce code-size. We type-erase when going -// from the "front-end" to the "middle-end", and un-type-erase going from the -// "middle-end" to the "back-end". The un-type-erasure is possible because the -// "front-end" is responsible for instantiating the needed "back-end" templates, -// and thus the static type information is still present. -// -// Each layer of Ruy uses matrix types: -// - "front-end": Matrix -// - "middle-end": DMatrix, PMatrix -// - "back-end": Matrix, PackedMatrix -// -// The use of separate types for packed matrices is not essential, but makes it -// obvious at a glance whether a matrix is a packed matrix or not. We would -// reconsider this decision if there was significant duplication between packed -// and unpacked matrices, but that doesn't seem to be the case at the moment. -// -// Another goal is to keep the user-facing Matrix as simple and -// understandable as possible. Ideally, a user should be able to read the struct -// definition for Matrix and see a very simple definition with no internal -// details like sums and kernel block layout. -// -// To present another structured view of our various matrix types, here's a -// table: -// Plain matrices Packed matrices -// +---------------------------------- -// Templated | Matrix PackedMatrix -// Type-erased | DMatrix PMatrix -// -// -// There is 1 additional matrix type not mentioned above, due to its low -// importance: -// - PrepackedMatrix: This is a user-facing version of PMatrix. It has the bare -// minimum of fields needed for representing the raw data and sums buffers of a -// packed matrix for the "advanced" explicit pre-packing API. This type plays no -// role in Ruy's internals and can generally by ignored. The only reason it -// exists is so that PMatrix is not exposed to users -- we prefer to keep the -// internal matrix types hidden, even from "advanced" users. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_INTERNAL_MATRIX_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_INTERNAL_MATRIX_H_ - -#include -#include -#include -#include - -#include "check_macros.h" -#include "common.h" -#include "matrix.h" -#include "size_util.h" - -namespace ruy { - -// KernelLayout describes small-scale block structure in a packed matrix layout. -// It's a runtime (as opposed to compile-time-constant) version of the -// FixedKernelLayout struct used to declare kernel layouts. -// -// This is is sometimes known as "tiling" in other contexts. -// -// For example, consider a packed matrix in column-major format with a -// column-major KernelLayout. The matrix logically has a shape of -// `[cols, rows]`. However, the matrix is laid out as though it were a 4D array -// of shape `[cols / kcols, rows / krows, kcols, krows]`. -// -// Note that in the case of kcols=1, krows=1, this degenerates to -// `[cols, rows, 1, 1]` which is equivalent to having no small-scale block -// structure. 
-struct KernelLayout { - Order order = Order::kColMajor; - std::uint8_t rows = 1; - std::uint8_t cols = 1; -}; - -// A packed matrix has a small-scale block structure that is not present in in -// the input matrices. This block structure is necessary for the kernels to -// process data efficiently. -// -// This struct is very similar to Layout, but has the extra KernelLayout field. -struct PackedLayout { - std::int32_t rows = 0; - std::int32_t cols = 0; - // Stride is the offset between two adjacent matrix elements - // in the non-contiguous direction. - std::int32_t stride = 0; - Order order = Order::kColMajor; - // Small scale layout shuffling, potentially departing from - // linear row-major or column-major storage. See KernelLayout. - KernelLayout kernel; -}; - -// Dynamic representation for a type. -// -// The most important field in this struct is the size, which Ruy uses to know -// how much memory to allocate without having to be templated on a type. -// Signed-ness and floating-point-ness are mainly present as debugging checks. -// -// Note: Ruy does not use this struct to to dynamically dispatch between -// different typed implementations. As described in the comment at the top of -// this file, Ruy's "front-end", which is templated, instantiates all the -// necessary "back-end" routines with complete static knowledge of all the -// types. -struct Type { - template - static Type Create() { - Type ret; - ret.is_signed = std::is_signed::value; - ret.is_floating_point = std::is_floating_point::value; - ret.size = sizeof(T); - return ret; - } - - template - void AssertIs() const { - RUY_DCHECK_EQ(is_signed, Create().is_signed); - RUY_DCHECK_EQ(is_floating_point, Create().is_floating_point); - RUY_DCHECK_EQ(size, Create().size); - } - - bool is_signed = false; - bool is_floating_point = false; - std::uint8_t size = 0; -}; - -// Type-erased matrix. -struct DMatrix { - Type data_type; - void* data = nullptr; - Layout layout; - std::int32_t zero_point = 0; -}; - -// Type-erased packed matrix. -struct PMatrix { - Type data_type; - void* data = nullptr; - Type sums_type; - void* sums = nullptr; - PackedLayout layout; - std::int32_t zero_point = 0; -}; - -// Convenient typed helper for packed matrices. -template -struct PackedMatrix { - // The row/column sums needed for quantized matrix multiplication when - // the opposite operand of the multiplication uses a non-symmetric zero - // point. - // This member is only relevant for packed matrices. - // Additionally, Ruy always uses 32-bit signed accumulators for quantized - // matrix multiplication. - // For floating point types, there is no quantization, so this pointer - // will always be null. We still need code referencing it to compile - // though, even if it is always branched around. Hence we use Scalar* - // itself as the type in that case. 
- using SumsType = - typename std::conditional::value, Scalar, - std::int32_t>::type; - - Scalar* data = nullptr; - SumsType* sums = nullptr; - PackedLayout layout; - std::int32_t zero_point = 0; -}; - -template -DMatrix ToDMatrix(const Matrix& matrix) { - DMatrix ret; - ret.data_type = Type::Create(); - ret.data = ToVoidPtr(matrix.data.get()); - ret.layout = matrix.layout; - ret.zero_point = matrix.zero_point; - return ret; -} - -template -Matrix ToMatrix(const DMatrix& dmatrix) { - dmatrix.data_type.AssertIs(); - Matrix ret; - ret.data = static_cast(dmatrix.data); - ret.layout = dmatrix.layout; - ret.zero_point = dmatrix.zero_point; - return ret; -} - -template -PackedMatrix ToPackedMatrix(const PMatrix& pmatrix) { - using SumsType = typename PackedMatrix::SumsType; - pmatrix.data_type.AssertIs(); - pmatrix.sums_type.AssertIs(); - PackedMatrix ret; - ret.data = static_cast(pmatrix.data); - ret.sums = static_cast(pmatrix.sums); - ret.layout = pmatrix.layout; - ret.zero_point = pmatrix.zero_point; - return ret; -} - -// Helpers for Layout / PackedLayout. - -inline bool IsPacked(const Layout& layout) { - if (layout.order == Order::kColMajor) { - return layout.stride == layout.rows; - } else { - return layout.stride == layout.cols; - } -} - -inline bool IsRowMajor(const Layout& layout) { - return layout.order == Order::kRowMajor; -} - -template -inline bool IsColMajor(const LayoutOrPackedLayout& layout) { - return layout.order == Order::kColMajor; -} - -template -inline int FlatSize(const LayoutOrPackedLayout& layout) { - const int outerdim = - layout.order == Order::kColMajor ? layout.cols : layout.rows; - return layout.stride * outerdim; -} - -// TODO(b/130417400) add a unit test -inline int Offset(const Layout& layout, int row, int col) { - // TODO(benoitjacob) - should check this but this make the _slow tests take - // 5x longer. Find a mitigation like in Eigen with an 'internal' variant - // bypassing the check? - // RUY_DCHECK_GE(row, 0); - // RUY_DCHECK_GE(col, 0); - // RUY_DCHECK_LT(row, layout.rows); - // RUY_DCHECK_LT(col, layout.cols); - int row_stride = layout.order == Order::kColMajor ? 1 : layout.stride; - int col_stride = layout.order == Order::kRowMajor ? 1 : layout.stride; - return row * row_stride + col * col_stride; -} - -// TODO(b/130417400) add a unit test -inline int Offset(const PackedLayout& layout, int row, int col) { - RUY_DCHECK(is_pot(layout.kernel.rows)); - RUY_DCHECK(is_pot(layout.kernel.cols)); - int row_outer = row & ~(layout.kernel.rows - 1); - int col_outer = col & ~(layout.kernel.cols - 1); - int row_stride_outer = - layout.order == Order::kColMajor ? layout.kernel.cols : layout.stride; - int col_stride_outer = - layout.order == Order::kRowMajor ? layout.kernel.rows : layout.stride; - int offset_outer = - row_outer * row_stride_outer + col_outer * col_stride_outer; - int row_inner = row - row_outer; - int col_inner = col - col_outer; - int row_stride_inner = - layout.kernel.order == Order::kColMajor ? 1 : layout.kernel.cols; - int col_stride_inner = - layout.kernel.order == Order::kRowMajor ? 1 : layout.kernel.rows; - int offset_inner = - row_inner * row_stride_inner + col_inner * col_stride_inner; - return offset_outer + offset_inner; -} - -// Helpers for Matrix. 
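Given the TODO above about a missing unit test, here is the kind of tiny standalone sanity check one might write for Offset(const Layout&, ...); all dimensions are invented for illustration.

#include <cassert>

// Expected offsets for element (row, col) = (2, 1) of a 3x4 matrix:
// row-major uses stride as the row stride, column-major as the column
// stride, matching the row_stride/col_stride selection in Offset above.
int main() {
  const int row = 2, col = 1;
  // Row-major, fully packed: stride == cols == 4, so offset = row*4 + col.
  assert(row * 4 + col == 9);
  // Column-major, fully packed: stride == rows == 3, so offset = col*3 + row.
  assert(col * 3 + row == 5);
  // Column-major with a padded stride of 5 (padded strides arise from the
  // anti-aliasing logic in CreatePackedLayout, though with other numbers).
  assert(col * 5 + row == 7);
  return 0;
}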
- -template -const Scalar* ElementPtr(const Matrix& mat, int row, int col) { - return mat.data.get() + Offset(mat.layout, row, col); -} - -template -Scalar* ElementPtr(Matrix* mat, int row, int col) { - return mat->data.get() + Offset(mat->layout, row, col); -} - -template -Scalar Element(const Matrix& mat, int row, int col) { - return *ElementPtr(mat, row, col); -} - -// Helpers for PackedMatrix. -// Duplicated from Matrix, but the duplication seems acceptable. - -template -const Scalar* ElementPtr(const PackedMatrix& mat, int row, int col) { - return mat.data + Offset(mat.layout, row, col); -} - -template -Scalar* ElementPtr(PackedMatrix* mat, int row, int col) { - return mat->data + Offset(mat->layout, row, col); -} - -template -Scalar Element(const PackedMatrix& mat, int row, int col) { - return *ElementPtr(mat, row, col); -} - -// Helpers for PMatrix. - -inline std::size_t DataSize(const PMatrix& packed) { - return FlatSize(packed.layout) * packed.data_type.size; -} - -inline std::size_t SumsSize(const PMatrix& packed) { - // Packed matrices are only relevant for Ruy's TrMul implementations. For - // TrMul, the number of sums is always equal to the number of columns. - return packed.layout.cols * packed.sums_type.size; -} - -// Transpose helpers. - -inline void Transpose(Order* order) { - *order = *order == Order::kColMajor ? Order::kRowMajor : Order::kColMajor; -} - -inline void Transpose(Layout* layout) { - Transpose(&layout->order); - std::swap(layout->rows, layout->cols); -} - -template -inline void Transpose(Matrix* matrix) { - Transpose(&matrix->layout); -} - -// Helpers for KernelLayout. - -template -KernelLayout ToKernelLayout() { - KernelLayout ret; - ret.order = FixedKernelLayout::kOrder; - ret.rows = FixedKernelLayout::kRows; - ret.cols = FixedKernelLayout::kCols; - return ret; -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_INTERNAL_MATRIX_H_ diff --git a/kernel.h b/kernel.h deleted file mode 100644 index d41d26c..0000000 --- a/kernel.h +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_H_ - -#include "platform.h" - -// IWYU pragma: begin_exports -#if RUY_PLATFORM(NEON) -#include "kernel_arm.h" -#elif RUY_PLATFORM(X86) -#include "kernel_x86.h" -#else -#include "kernel_common.h" -#endif -// IWYU pragma: end_exports - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_H_ diff --git a/kernel_arm.h b/kernel_arm.h deleted file mode 100644 index 480c41f..0000000 --- a/kernel_arm.h +++ /dev/null @@ -1,211 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_ARM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_ARM_H_ - -#include -#include - -#include "common.h" -#include "internal_matrix.h" -#include "kernel_common.h" -#include "matrix.h" -#include "opt_set.h" -#include "path.h" -#include "platform.h" -#include "profiler/instrumentation.h" -#include "side_pair.h" -#include "size_util.h" -#include "spec.h" -#include "tune.h" - -namespace ruy { - -#if RUY_PLATFORM(NEON) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_64) -void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params); -void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 4>& params); -#elif RUY_PLATFORM(NEON_32) -void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params); -void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 2>& params); -#endif -void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params); -void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params); -void Kernel8bitNeonDotprodOutOfOrder1Col(const KernelParams8bit<8, 8>& params); -void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params); - -#if RUY_PLATFORM(NEON_64) -template -struct Kernel> { - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - Tuning tuning = Tuning::kAuto; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - if (dst->layout.cols == 1) { - Kernel8bitNeonOutOfOrder1Col(params); - return; - } - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - Kernel8bitNeonInOrder(params); - } else { - Kernel8bitNeonOutOfOrder(params); - } - } -}; -#endif - -#if RUY_PLATFORM(NEON_32) -template -struct Kernel> { - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - Tuning tuning = Tuning::kAuto; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - if (dst->layout.cols == 1) { - Kernel8bitNeonOutOfOrder1Col(params); - return; - } - Kernel8bitNeonOutOfOrder(params); - } -}; -#endif - -#if RUY_PLATFORM(NEON_64) -template -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); 
- if (dst->layout.cols == 1) { - Kernel8bitNeonDotprodOutOfOrder1Col(params); - } else if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - Kernel8bitNeonDotprodInOrder(params); - } else { - Kernel8bitNeonDotprodOutOfOrder(params); - } - } -}; -#endif - -void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params); -void KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params); -void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params); -void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params); - -#if RUY_PLATFORM(NEON_64) -// A Float kernel for ARM64 Neon. -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - KernelFloatNeonInOrder(params); - } else { - KernelFloatNeonOutOfOrder(params); - } - } -}; -#endif - -#if RUY_PLATFORM(NEON_32) -// A Float kernel for ARM32 Neon. -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat<8, 4> params; - - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - - KernelFloat32NeonOutOfOrder(params); - } -}; -#endif - -// While the dotprod NEON extension does not concern floating-point arithmetic, -// its presence allows us to distinguish, in the in-order tuning case, between -// A53 and A55r1. TODO: should this be folded into tuning? -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - using Base = - Kernel>; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - KernelFloatNeonDotprodInOrder(params); - } else { - KernelFloatNeonOutOfOrder(params); - } - } -}; - -#endif // RUY_PLATFORM(NEON) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_ARM_H_ diff --git a/kernel_arm32.cc b/kernel_arm32.cc deleted file mode 100644 index 8d7e55d..0000000 --- a/kernel_arm32.cc +++ /dev/null @@ -1,2499 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "kernel.h" -#include "opt_set.h" -#include "platform.h" -#include "profiler/instrumentation.h" - -namespace ruy { - -#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#define RUY_ASM_LABEL_STORE_UINT8 91 -#define RUY_ASM_LABEL_STORE_INT8 92 -#define RUY_ASM_LABEL_STORE_INT16 93 -#define RUY_ASM_LABEL_STORE_INT32 94 -#define RUY_ASM_LABEL_AFTER_STORE 99 - -#define RUY_OFFSET_LHS_BASE_PTR 0 -#define RUY_OFFSET_RHS_BASE_PTR 4 -#define RUY_OFFSET_DST_BASE_PTR 8 -#define RUY_OFFSET_BIAS 12 -#define RUY_OFFSET_START_ROW 16 -#define RUY_OFFSET_START_COL 20 -#define RUY_OFFSET_LAST_ROW 24 -#define RUY_OFFSET_LAST_COL 28 -#define RUY_OFFSET_DST_ROWS 32 -#define RUY_OFFSET_DST_COLS 36 -#define RUY_OFFSET_LHS_STRIDE 40 -#define RUY_OFFSET_RHS_STRIDE 44 -#define RUY_OFFSET_DST_STRIDE 48 -#define RUY_OFFSET_DEPTH 52 -#define RUY_OFFSET_CLAMP_MIN 56 -#define RUY_OFFSET_CLAMP_MAX 60 -#define RUY_OFFSET_FLAGS 64 - -#define RUY_STACK_OFFSET_SIZE 96 -#define RUY_STACK_OFFSET_DST_COL_PTR 0 -#define RUY_STACK_OFFSET_DST_PTR 16 -#define RUY_STACK_OFFSET_ROW 32 -#define RUY_STACK_OFFSET_COL 48 -#define RUY_STACK_OFFSET_LHS_COL_PTR 64 -#define RUY_STACK_OFFSET_RHS_COL_PTR 80 - -template -void CheckOffsetsInKernelParamsFloat32(const Params&) { - static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); - static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, ""); - static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, ""); - static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); - static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); - static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, ""); - static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); - static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); - static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, ""); - static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); - static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); - static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); - static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); - static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); - static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); - static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); -} - -// Float kernel for ARM32 out-of-order cores. -// Just like Float 64 version, except accumulate in to 8x4 block to only -// use 16 128-bit NEON registers. This is a "first pass" kernel and not -// tuned. It is meant to run on out-of-order CPUs like the Krait 400 or A9. -void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params) { - CheckOffsetsInKernelParamsFloat32(params); - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - const float* lhs_ptr = params.lhs_base_ptr; - const float* rhs_ptr = params.rhs_base_ptr; - // In ARM32 NEON, there are 16 128-bit "q" registers. These registers are - // each composed of two 64-bit "d" registers. The asm kernel below has the - // following NEON register allocation: - // Registers q3 -- q10 are accumulators. During accumulation, - // q0 -- q2 (d0 -- d5) are used to load data from LHS and RHS. 
q0 and q1 - // are used to load a 8x1 block of LHS, and q2 is used to load a 1x4 block - // of RHS, like this: - - // Register layout in "q" registers: - // RHS 1x4 block - // /--------------------------\ - // |q2.s[0] ... q2.s[3] | - // \--------------------------/ - // LHS 8x1 block - // /---------------------\ /--------------------- \ - // | q0.s[0] | | q3.s[0] ... q9.s[0] | - // | ... | | ... ... | - // | q0.s[3] | | q3.s[3] q9.s[3] | - // | q1.s[0] | | q4.s[0] q10.s[0] | - // | ... | | ... ... ... | - // | q1.s[3] | | q4.s[3] .. q10.s[3] | - // \---------------------/ \--------------------------/ - // accumulators 8x4 block - // q11, q14, q15 currently unused. q12 and q13 are used to load - // parameters used for the post-accumulation part of the kernel. - // For completeness, here is the register layout in "d" registers: - // RHS 1x4 block - // /--------------------------\ - // |d4[0] ... d5[1] | - // \--------------------------/ - // LHS 8x1 block - // /---------------------\ /--------------------------\ - // | d0[0] | | d6[0] ... d18[0] | - // | ... | | ... ... | - // | d1[1] | | d7[1] d19[1] | - // | d2[0] | | d8[0] d20[0] | - // | ... | | ... ... ... | - // | d3[1] | | d9[1] ... d21[1] | - // \---------------------/ \--------------------------/ - // accumulators 8x4 block - asm volatile( -#define RUY_MAKE_ZERO(reg) "vmov.f32 " #reg ", #0.0\n" - - // clang-format off - - // Load the first 32 bytes of LHS and RHS data. - // Load q0, q1 - "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" - "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - // Load q2 - "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" - "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - // Clear accumulators. - RUY_MAKE_ZERO(q3) - RUY_MAKE_ZERO(q4) - RUY_MAKE_ZERO(q5) - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - - // r1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 1. - "mov r1, #1\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. 
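      // For reference, a minimal scalar sketch (not ruy code; the function and
      // parameter names are illustrative assumptions) of what one iteration of
      // the main loop below computes for a full 8x4 destination block:
      // accumulate over depth, add the per-row bias, then clamp.
      #include <algorithm>

      // lhs: 8 packed values per depth level; rhs: 4 per depth level.
      inline void Float8x4BlockReference(const float* lhs, const float* rhs,
                                         const float* bias, int depth,
                                         float clamp_min, float clamp_max,
                                         float dst[8][4]) {
        float acc[8][4] = {};
        for (int d = 0; d < depth; ++d) {
          for (int r = 0; r < 8; ++r) {
            for (int c = 0; c < 4; ++c) {
              acc[r][c] += lhs[d * 8 + r] * rhs[d * 4 + c];
            }
          }
        }
        for (int r = 0; r < 8; ++r) {
          for (int c = 0; c < 4; ++c) {
            dst[r][c] = std::min(std::max(acc[r][c] + bias[r], clamp_min),
                                 clamp_max);
          }
        }
      }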
- "1:\n" - - // Accumulation loop - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - "cmp r1, r2\n" - "beq 79f\n" - - "2:\n" - - "vmla.f32 q3, q0, d4[0]\n" - "vmla.f32 q5, q0, d4[1]\n" - "vmla.f32 q7, q0, d5[0]\n" - "vmla.f32 q9, q0, d5[1]\n" - "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS - - "vmla.f32 q4, q1, d4[0]\n" - "vmla.f32 q6, q1, d4[1]\n" - "vmla.f32 q8, q1, d5[0]\n" - "vmla.f32 q10, q1, d5[1]\n" - "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - "add r1, r1, #1\n" - "cmp r1, r2\n" - - "blt 2b\n" - - "79:\n" - - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last level of depth, for which the LHS - // and RHS data is already loaded. - - "vmla.f32 q3, q0, d4[0]\n" - "vmla.f32 q5, q0, d4[1]\n" - "vmla.f32 q7, q0, d5[0]\n" - "vmla.f32 q9, q0, d5[1]\n" - - "vmla.f32 q4, q1, d4[0]\n" - "vmla.f32 q6, q1, d4[1]\n" - "vmla.f32 q8, q1, d5[0]\n" - "vmla.f32 q10, q1, d5[1]\n" - - // End of accumulation. The registers q3 -- q10 contain the final - // float32 accumulator values of the current 8x8 destination block. - // We now have to compute the final values from these accumulators - // and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r1, r3\n" // Have we finished the last row? - - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "add r4, r4, r1, lsl #3\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "b 5f\n" - "4:\n" // Finished last row... - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - // Go back to first row - "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. 
- "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "add r10, r10, r1, lsl #2\n" - "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "mov %[lhs_ptr], r4\n" - "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "mov %[rhs_ptr], r5\n" - - // Load some parameters needed for the end work on current block. - "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "add r5, r1, r8, lsl #2\n" - - "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "it ne\n" - "movne r1, r5\n" - - // Load 8 bias values. - "vld1.32 {d24, d25, d26, d27}, [r1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into q0 -- q2, as we don't need q0 -- q2 anymore - // in the rest of the work on the current block. - // Load q0, q1 - "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - // Load q2 - "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "vadd.f32 q3, q3, q12\n" - "vadd.f32 q4, q4, q13\n" - "vadd.f32 q5, q5, q12\n" - "vadd.f32 q6, q6, q13\n" - "vadd.f32 q7, q7, q12\n" - "vadd.f32 q8, q8, q13\n" - "vadd.f32 q9, q9, q12\n" - "vadd.f32 q10, q10, q13\n" - - // Load the clamp_min, clamp_max bounds - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.32 q12, r2\n" // clamp_min - "vdup.32 q13, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.f32 q3, q3, q12\n" - "vmax.f32 q4, q4, q12\n" - "vmax.f32 q5, q5, q12\n" - "vmax.f32 q6, q6, q12\n" - "vmax.f32 q7, q7, q12\n" - "vmax.f32 q8, q8, q12\n" - "vmax.f32 q9, q9, q12\n" - "vmax.f32 q10, q10, q12\n" - - // Apply the clamp_max bound - "vmin.f32 q3, q3, q13\n" - "vmin.f32 q4, q4, q13\n" - "vmin.f32 q5, q5, q13\n" - "vmin.f32 q6, q6, q13\n" - "vmin.f32 q7, q7, q13\n" - "vmin.f32 q8, q8, q13\n" - "vmin.f32 q9, q9, q13\n" - "vmin.f32 q10, q10, q13\n" - - // Compute how much of the 8x4 block of destination values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x4, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #8\n" - "mov r5, #4\n" - "cmp r1, #8\n" - // Compute r1 = how many rows of the 8x4 block fit - "it gt\n" - "movgt r1, r3\n" - "cmp r2, #4\n" - // Compute r2 = how many cols of the 8x4 block fit - "it gt\n" - "movgt r2, r5\n" - - // Test if r1==8 && r2 == 4, i.e. if all of the 8x4 block fits. - "cmp r1, r3\n" - "it eq\n" - "cmpeq r2, r5\n" - // Yes, all of the 8x4 block fits, go to fast path. 
- "beq 30f\n" - // Not all of the 8x4 block fits. - // Set (r3 address, r4 stride) to write to dst_tmp_buf - "mov r3, %[dst_tmp_buf]\n" - "mov r4, #32\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x4 block fits. - // Set (r3 address, r4 stride) to write directly to destination matrix. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r5\n" - "31:\n" - - // Write our float values to the destination described by - // (r3 address, r4 stride) - "vst1.32 {d6, d7, d8, d9}, [r3]\n" - "add r3, r3, r4\n" - RUY_MAKE_ZERO(q3) - RUY_MAKE_ZERO(q4) - "vst1.32 {d10, d11, d12, d13}, [r3]\n" - "add r3, r3, r4\n" - RUY_MAKE_ZERO(q5) - RUY_MAKE_ZERO(q6) - "vst1.32 {d14, d15, d16, d17}, [r3]\n" - "add r3, r3, r4\n" - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - "vst1.32 {d18, d19, d20, d21}, [r3]\n" - "add r3, r3, r4\n" - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - - // If all of the 8x4 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "mov r3, %[dst_tmp_buf]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r6, #0\n" - "50:\n" - "mov r5, #0\n" - "51:\n" - "ldr r10, [r3, r5, lsl #2]\n" - "str r10, [r4, r5, lsl #2]\n" - "add r5, r5, #1\n" - "cmp r5, r1\n" - "blt 51b\n" - "add r6, r6, #1\n" - "add r3, r3, #32\n" - "add r4, r4, r8\n" - // r2 = how many cols of the 8x4 block fit - "cmp r6, r2\n" - "blt 50b\n" - "41:\n" - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #32\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - // Reload some params --- we had used r3, r5, r10 for a few other things - // since the last time we had loaded them. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r8, r3\n" - - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add r8, r8, #8\n" - // Store new value of row - "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "b 21f\n" - "20:\n" - // Was already at end row. - // Move back to first row. - "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - // Move to the next column. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "add r4, r4, #4\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Increment dst_col_ptr by 4 * dst_stride (i.e. 
4 columns) - "add r1, r1, r8, lsl #2\n" - // Store dst_col_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Store dst_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" - - // r1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 1. - "mov r1, #1\n" - - "ble 1b\n" - - // Restore stack pointer. - "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - // clang-format on - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) - : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) - // Clobber list must specify q registers (and not their constituent - // d registers). There is a (currently unexplained) slowdown if - // d registers are listed in the clobbers list. - : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", - "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", - "q9", "q10", "q12", "q13"); -} - -#undef RUY_MAKE_ZERO -#undef RUY_STACK_OFFSET_SIZE -#undef RUY_STACK_OFFSET_DST_COL_PTR -#undef RUY_STACK_OFFSET_DST_PTR -#undef RUY_STACK_OFFSET_ROW -#undef RUY_STACK_OFFSET_COL -#undef RUY_STACK_OFFSET_LHS_COL_PTR -#undef RUY_STACK_OFFSET_RHS_COL_PTR - -#undef RUY_OFFSET_LHS_BASE_PTR -#undef RUY_OFFSET_RHS_BASE_PTR -#undef RUY_OFFSET_DST_BASE_PTR -#undef RUY_OFFSET_BIAS -#undef RUY_OFFSET_START_ROW -#undef RUY_OFFSET_START_COL -#undef RUY_OFFSET_LAST_ROW -#undef RUY_OFFSET_LAST_COL -#undef RUY_OFFSET_DST_ROWS -#undef RUY_OFFSET_DST_COLS -#undef RUY_OFFSET_LHS_STRIDE -#undef RUY_OFFSET_RHS_STRIDE -#undef RUY_OFFSET_DST_STRIDE -#undef RUY_OFFSET_DEPTH -#undef RUY_OFFSET_CLAMP_MIN -#undef RUY_OFFSET_CLAMP_MAX -#undef RUY_OFFSET_FLAGS - -#define RUY_OFFSET_BIAS 0 -#define RUY_OFFSET_LHS_SUMS 4 -#define RUY_OFFSET_RHS_SUMS 8 -#define RUY_OFFSET_LHS_BASE_PTR 12 -#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 16 -#define RUY_OFFSET_MULTIPLIER_EXPONENT 20 -#define RUY_OFFSET_RHS_BASE_PTR 24 -#define RUY_OFFSET_DST_BASE_PTR 28 -#define RUY_OFFSET_LHS_ZERO_POINT 32 -#define RUY_OFFSET_RHS_ZERO_POINT 36 -#define RUY_OFFSET_DST_ZERO_POINT 40 -#define RUY_OFFSET_PROD_ZP_DEPTH 44 -#define RUY_OFFSET_START_ROW 48 -#define RUY_OFFSET_START_COL 52 -#define RUY_OFFSET_LAST_ROW 56 -#define RUY_OFFSET_LAST_COL 60 -#define RUY_OFFSET_DST_ROWS 64 -#define RUY_OFFSET_DST_COLS 68 -#define RUY_OFFSET_LHS_STRIDE 72 -#define RUY_OFFSET_RHS_STRIDE 76 -#define RUY_OFFSET_DST_STRIDE 80 -#define RUY_OFFSET_DEPTH 84 -#define RUY_OFFSET_CLAMP_MIN 88 -#define RUY_OFFSET_CLAMP_MAX 92 -#define RUY_OFFSET_FLAGS 96 -#define RUY_OFFSET_DST_TYPE_ID 97 - -#define RUY_STACK_OFFSET_SIZE 96 -#define RUY_STACK_OFFSET_DST_COL_PTR 0 -#define RUY_STACK_OFFSET_DST_PTR 16 -#define RUY_STACK_OFFSET_ROW 32 -#define RUY_STACK_OFFSET_COL 48 -#define RUY_STACK_OFFSET_LHS_COL_PTR 64 -#define RUY_STACK_OFFSET_RHS_COL_PTR 80 - -template -void CheckOffsetsInKernelParams8bit(const Params&) { - static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT, - ""); - static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT, - ""); - static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT, - ""); - static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH, - ""); - static_assert(offsetof(Params, multiplier_fixedpoint) 
== - RUY_OFFSET_MULTIPLIER_FIXEDPOINT, - ""); - static_assert( - offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT, - ""); - static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); - static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); - static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); - static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, ""); - static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, ""); - static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); - static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); - static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); - static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); - static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); - static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); - static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); - static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); - static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); -} - -// Fast-int8 kernel, ported from ARM 64 version. -// Relevant target CPUs for this kernel include Krait 400 and A9, -// since these are 32-bit, out-of-order CPUs. -void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params) { - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - - // The asm kernel below has the following NEON register allocation: - // - // q6 - q13 are 128-bit (4x32b) accumulators. - // During accumulation, d0 -- d7 are used to load int8 data from LHS and - // d8 -- d11 from RHS: - // int8 RHS 16x2 block - // /-----------------------------\ - // |d8.b[0-7] ..... d10.b[0-7]| - // | ... ... | - // |d9.b[0-7] ..... d11.b[0-7]| - // \-----------------------------/ - // int8 LHS 4x16 block - // /------------------------\ /-----------------------------\ - // |d0.b[0-7] ... d1.b[0-7] | | q6 ..... q10 | - // |d2.b[0-7] ... d3.b[0-7] | | q7 ..... q11 | - // (Reload d0, d1, d2, d3) - // |d0.b[0-7] ... d1.b[0-7] | | q8 ..... q12 | - // |d2.b[0-7] ... d3.b[0-7] | | q9 ..... q13 | - // \------------------------/ \-----------------------------/ - // 128-bit accumulators 4x2 block - // - // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING - // optimization for this kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n" - - // clang-format off - - // Load the first 64 bytes of LHS and RHS data. - "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" - // Clear accumulators. 
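      // Note: the accumulator clears below are interleaved with the initial
      // loads and with spilling the params to the stack; the net effect is
      // simply clearing all accumulators before the main loop, the
      // interleaving just lets independent instructions overlap in the
      // pipeline.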
- RUY_MAKE_ZERO(q6) - "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" - RUY_MAKE_ZERO(q11) - "vld1.8 {d10, d11}, [%[rhs_ptr]]!\n" - - "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - RUY_MAKE_ZERO(q12) - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - RUY_MAKE_ZERO(q13) - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - RUY_MAKE_ZERO(q14) - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" - RUY_MAKE_ZERO(q15) - "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - - - // r1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov r1, #16\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // r1 is how many levels of depth we have already loaded - // data for, r10 is the total depth. - "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - "cmp r1, r10\n" - "beq 79f\n" - - "2:\n" - - // Mult, mult-acc in to q14, q15, q2, q3 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q2, d0, d10\n" - - "vmull.s8 q15, d2, d8\n" - "vmull.s8 q3, d2, d10\n" - - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q2, d1, d11\n" - "vmlal.s8 q15, d3, d9\n" - "vmlal.s8 q3, d3, d11\n" - "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS - - // Then pairwise accumulate in to q6, q7, q10, q11 - "vpadal.s16 q6, q14\n" - "vpadal.s16 q7, q15\n" - "vpadal.s16 q10, q2\n" - "vpadal.s16 q11, q3\n" - - // Mult, mult-acc in to q14, q15, q2, q3 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q2, d0, d10\n" - - "vmull.s8 q15, d2, d8\n" - "vmull.s8 q3, d2, d10\n" - - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q2, d1, d11\n" - "vmlal.s8 q15, d3, d9\n" - "vmlal.s8 q3, d3, d11\n" - "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS - - // Then pairwise accumulate in to q8, q9, q12, q13 - "vpadal.s16 q8, q14\n" - "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" - "vpadal.s16 q9, q15\n" - "vpadal.s16 q12, q2\n" - "vpadal.s16 q13, q3\n" - - // Prefetch the next 64 bytes of LHS and RHS data. - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - // Each iteration of this loop advances by 16 levels of depth. 
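      // For reference, a scalar sketch (illustrative names, not ruy code) of
      // what one 16-deep step of the loop above contributes: a 4x16 int8 LHS
      // block times a 16x2 int8 RHS block accumulated into 4x2 int32 results.
      // The asm keeps each of the 8 results spread across the four int32
      // lanes of a q register (q6..q13) and only reduces them with vpadd
      // after the depth loop.
      #include <cstdint>

      inline void Accumulate4x2Block(const std::int8_t* lhs,  // 4 rows x 16 depth, row-major
                                     const std::int8_t* rhs,  // 2 cols x 16 depth, col-major
                                     std::int32_t acc[4][2]) {
        for (int r = 0; r < 4; ++r) {
          for (int c = 0; c < 2; ++c) {
            for (int d = 0; d < 16; ++d) {
              acc[r][c] += static_cast<std::int32_t>(lhs[r * 16 + d]) *
                           static_cast<std::int32_t>(rhs[c * 16 + d]);
            }
          }
        }
      }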
- "add r1, r1, #16\n" - - // Loop termination condition - "cmp r1, r10\n" - - "blt 2b\n" - - "79:\n" - - // Mult, mult-acc in to q14, q15, q2, q3 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q2, d0, d10\n" - - "vmull.s8 q15, d2, d8\n" - "vmull.s8 q3, d2, d10\n" - - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q2, d1, d11\n" - "vmlal.s8 q15, d3, d9\n" - "vmlal.s8 q3, d3, d11\n" - "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS - - // Then pairwise accumulate in to q6, q7, q10, q11 - "vpadal.s16 q6, q14\n" - "vpadal.s16 q7, q15\n" - "vpadal.s16 q10, q2\n" - "vpadal.s16 q11, q3\n" - - // Mult, mult-acc in to q14, q15, q2, q3 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q2, d0, d10\n" - - "vmull.s8 q15, d2, d8\n" - "vmull.s8 q3, d2, d10\n" - - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q2, d1, d11\n" - "vmlal.s8 q15, d3, d9\n" - "vmlal.s8 q3, d3, d11\n" - - // Then pairwise accumulate in to q8, q9, q12, q13 - "vpadal.s16 q8, q14\n" - "vpadal.s16 q9, q15\n" - "vpadal.s16 q12, q2\n" - "vpadal.s16 q13, q3\n" - - - // All accumulation over depth done. q6 - q13 contain the 4x32b - // accumulators for the 4x2 final matrix. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 4x2 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // q6-q13 now contain 4 x 32b - "vpadd.i32 d0, d12, d13\n" - "vpadd.i32 d1, d14, d15\n" - "vpadd.i32 d2, d16, d17\n" - "vpadd.i32 d3, d18, d19\n" - "vpadd.i32 d4, d20, d21\n" - "vpadd.i32 d5, d22, d23\n" - "vpadd.i32 d6, d24, d25\n" - "vpadd.i32 d7, d26, d27\n" - - // d0-d7 each contain 2 x 32b accumulators. - // Need to add pairwise to get 1 x 32b for each of the 4x2 entries - // of destination, (Four 'd' registers total) - "vpadd.i32 d28, d0, d1\n" - "vpadd.i32 d29, d2, d3\n" - "vpadd.i32 d30, d4, d5\n" - "vpadd.i32 d31, d6, d7\n" - - //Now d28 - d31 have the 1 x 32b accumulators for the 4x2 entries - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r1, r3\n" // Have we finished the last row? - - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "add r4, r4, r1, lsl #2\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "b 5f\n" - "4:\n" // Finished last row... - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - // Go back to first row - "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. 
The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "add r10, r10, r1, lsl #1\n" - "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "mov %[lhs_ptr], r4\n" - "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "mov %[rhs_ptr], r5\n" - - // Now we load: bias data, LHS sums data, RHS sums data. - - // First, load the base pointers from the params. - "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "add r5, r1, r8, lsl #2\n" - - "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "it ne\n" - "movne r1, r5\n" - - // Load 4 bias values. - "vld1.32 {d24, d25}, [r1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - // Add to the bias values the product - // (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in - // https://arxiv.org/pdf/1712.05877.pdf - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "vdup.32 q9, r3\n" - "vadd.i32 q12, q12, q9\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
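      // For reference, a scalar sketch (assumed names, not ruy code) of the
      // corrections applied here and just below, following equation (7) of
      // https://arxiv.org/pdf/1712.05877.pdf. The asm folds the NZ1Z2 term
      // into the bias vector before adding it; the two sums terms are only
      // subtracted when the corresponding HAS_*_SUMS flag is set.
      #include <cstdint>

      inline std::int32_t ApplyZeroPointCorrections(
          std::int32_t acc,            // raw int32 dot product for one entry
          std::int32_t bias,           // bias for this destination row
          std::int32_t depth, std::int32_t lhs_zero_point,
          std::int32_t rhs_zero_point, std::int32_t lhs_sum_for_row,
          std::int32_t rhs_sum_for_col) {
        acc += bias;
        acc += depth * lhs_zero_point * rhs_zero_point;  // prod_zp_depth (NZ1Z2)
        acc -= lhs_zero_point * rhs_sum_for_col;         // rhs_sums * lhs_zero_point
        acc -= rhs_zero_point * lhs_sum_for_row;         // lhs_sums * rhs_zero_point
        return acc;
      }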
- "vadd.i32 q14, q14, q12\n" - "vadd.i32 q15, q15, q12\n" - - // LHS/RHS zero points - // Has RHS sums - "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - // Offset by current col * number of bytes per value - "add r3, r3, r4, lsl #2\n" - "vld1.32 { d12 }, [r3]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "vdup.32 q10, r5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "vmls.i32 q14, q10, d12[0]\n" - "vmls.i32 q15, q10, d12[1]\n" - "401:\n" - - // Has LHS sums - "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - // Offset by current row * number of bytes per value - "add r2, r2, r4, lsl #2\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - - // Load 4 lhs_sums values. - "vld1.32 {d22, d23}, [r2]\n" - "vdup.32 d13, r5\n" // rhs_zero_point - - // Compute lhs_sums * rhs_zero_point. - "vmul.i32 q11, q11, d13[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "vsub.s32 q14, q14, q11\n" - "vsub.s32 q15, q15, q11\n" - - // If the destination is int32, it means the user asks for the raw - // accumulators, no need for us to downquantize the value. - "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "add r5, r1, r4, lsl #2\n" - "it ne\n" - "movne r1, r5\n" - - "vld1.32 {q10}, [r1]\n" - - RUY_MAKE_ZERO(q8) - "vmax.s32 q12, q10, q8\n" - - "vshl.s32 q14, q14, q12\n" - "vshl.s32 q15, q15, q12\n" - - "vmin.s32 q12, q10, q8\n" - - // Load fixed point part of the multiplier - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - // r6 has flags, r4 has row - "add r5, r1, r4, lsl #2\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "it ne\n" - "movne r1, r5\n" - "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint - - // Apply the fixed-point part of the multiplier. - "vqrdmulh.s32 q14, q14, q10\n" - "vqrdmulh.s32 q15, q15, q10\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. 
They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "vand q8, q14, q12\n" - "vand q9, q15, q12\n" - "vshr.s32 q8, q8, #31\n" - "vshr.s32 q9, q9, #31\n" - "vqadd.s32 q14, q14, q8\n" - "vqadd.s34 q15, q15, q9\n" - -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "vrshl.s32 q14, q14, q12\n" - "vrshl.s32 q15, q15, q12\n" - - "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - // Store uint8 values: - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in q14. - "vqmovn.s32 d28, q14\n" - "vqmovn.s32 d29, q15\n" - - // At this point, d12 -- d26, d30, d31 aren't used anymore for the - // current block, so we can start clearing these accumulators for the - // next block (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q15) - - // Load the destination zero point into each of the 8 16-bit slots - // in a q register. - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.16 q13, r4\n" // dst_zero_point - - // Add the destination zero point - "vadd.i16 q14, q14, q13\n" - - // Cast-and-saturate from int16 to uint8 - // Now all 8 1-byte values are in d30. - "vqmovun.s16 d30, q14\n" - - // Load the clamp_min, clamp_max bounds - "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.8 d28, r2\n" // clamp_min - "vdup.8 d29, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.u8 d30, d30, d28\n" - // Apply the clamp_max bound - "vmin.u8 d30, d30, d29\n" - - // Compute how much of the 4x2 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x2 blocks along the boundaries that do - // not fit entirely. 
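      // For reference, a scalar sketch (assumed helper names, not ruy code)
      // of the uint8 down-quantization path above: apply the fixed-point
      // multiplier and exponent with rounding, add the destination zero
      // point, clamp, and narrow to 8 bits. The rounding shift here breaks
      // ties upward, matching the native vrshl.s32 behaviour; the
      // non-native-rounding path instead breaks ties away from zero.
      #include <algorithm>
      #include <cstdint>
      #include <limits>

      // Scalar equivalent of vqrdmulh.s32.
      inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                            std::int32_t b) {
        if (a == b && a == std::numeric_limits<std::int32_t>::min()) {
          return std::numeric_limits<std::int32_t>::max();
        }
        const std::int64_t ab = static_cast<std::int64_t>(a) * b;
        const std::int64_t nudge = ab >= 0 ? (1ll << 30) : (1 - (1ll << 30));
        return static_cast<std::int32_t>((ab + nudge) / (1ll << 31));
      }

      inline std::uint8_t DownQuantizeToUint8(
          std::int32_t acc, std::int32_t multiplier_fixedpoint,
          int multiplier_exponent, std::int32_t dst_zero_point,
          std::int32_t clamp_min, std::int32_t clamp_max) {
        if (multiplier_exponent > 0) {
          acc <<= multiplier_exponent;  // positive exponent: plain left shift
        }
        acc = SaturatingRoundingDoublingHighMul(acc, multiplier_fixedpoint);
        if (multiplier_exponent < 0) {
          // Rounding right shift, ties upward (what vrshl.s32 does with a
          // negative shift amount).
          const int shift = -multiplier_exponent;
          acc = (acc + (std::int32_t{1} << (shift - 1))) >> shift;
        }
        acc += dst_zero_point;
        acc = std::min(std::max(acc, clamp_min), clamp_max);
        return static_cast<std::uint8_t>(acc);
      }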
- - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - "cmp r2, #2\n" - // Compute r2 = how many cols of the 4x2 block fit - "it gt\n" - "movgt r2, r5\n" - - // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. - "cmp r1, r3\n" - "it eq\n" - "cmpeq r2, r5\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x2 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x2 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.8 {d30}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "mov r6, #0\n" - "50:\n" - "mov r8, #0\n" - "51:\n" - "ldrb r10, [r3, r8]\n" - "strb r10, [r4, r8]\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "add r6, r6, #1\n" - "add r3, r3, #4\n" - "add r4, r4, r5\n" - "cmp r6, r2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x2 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #1\n" - - "vst1.32 {d30[0]}, [r3]\n" - "add r4, r4, r5\n" - "mov r3, r4\n" - "vst1.32 {d30[1]}, [r3]\n" - - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #4\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q14) - RUY_MAKE_ZERO(q15) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - // Store int8 values: - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in q14. - "vqmovn.s32 d28, q14\n" - "vqmovn.s32 d29, q15\n" - - // At this point, d12 -- d26, d30, d31 aren't used anymore for the - // current block, so we can start clearing these accumulators for the - // next block (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q15) - - // Load the destination zero point into each of the 8 16-bit slots - // in a q register. - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.16 q13, r4\n" // dst_zero_point - - // Add the destination zero point - "vadd.i16 q14, q14, q13\n" - - // Cast-and-saturate from int16 to int8 - // Now all 8 1-byte values are in d30. - "vqmovn.s16 d30, q14\n" - - // Load the clamp_min, clamp_max bounds - "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.8 d28, r2\n" // clamp_min - "vdup.8 d29, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.s8 d30, d30, d28\n" - // Apply the clamp_max bound - "vmin.s8 d30, d30, d29\n" - - // Compute how much of the 4x2 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x2 blocks along the boundaries that do - // not fit entirely. 
- - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - "cmp r2, #2\n" - // Compute r2 = how many cols of the 4x2 block fit - "it gt\n" - "movgt r2, r5\n" - - // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. - "cmp r1, r3\n" - "it eq\n" - "cmpeq r2, r5\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x2 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x2 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.8 {d30}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "mov r6, #0\n" - "50:\n" - "mov r8, #0\n" - "51:\n" - "ldrb r10, [r3, r8]\n" - "strb r10, [r4, r8]\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "add r6, r6, #1\n" - "add r3, r3, #4\n" - "add r4, r4, r5\n" - "cmp r6, r2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x2 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #1\n" - - "vst1.32 {d30[0]}, [r3]\n" - "add r4, r4, r5\n" - "mov r3, r4\n" - "vst1.32 {d30[1]}, [r3]\n" - - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #4\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q14) - RUY_MAKE_ZERO(q15) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Load the destination zero point into each of the 4 32-bit slots - // in a q register. - "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.32 q13, r4\n" // dst_zero_point - // Add the destination zero point - "vadd.s32 q14, q14, q13\n" - "vadd.s32 q15, q15, q13\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in q14. - "vqmovn.s32 d28, q14\n" - "vqmovn.s32 d29, q15\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q15) - - // Load the clamp_min, clamp_max bounds - "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.16 q12, r2\n" // clamp_min - "vdup.16 q13, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.s16 q14, q14, q12\n" - // Apply the clamp_max bound - "vmin.s16 q14, q14, q13\n" - - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - - // Compute how much of the 4x2 block of destination 16-bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x2 blocks along the boundaries that do - // not fit entirely. 
- - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - "cmp r2, #2\n" - // Compute r2 = how many cols of the 4x2 block fit - "it gt\n" - "movgt r2, r5\n" - - // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. - "cmp r1, r3\n" - "it eq\n" - "cmpeq r2, r5\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x2 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x2 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.16 {q14}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "mov r6, #0\n" - "50:\n" - "mov r8, #0\n" - "51:\n" - // Shift of offset register for half-word loads not allowed in A32, - // so we shift, load/store, then shift back r8. - "lsl r8, r8, #1\n" - "ldrh r10, [r3, r8]\n" - "strh r10, [r4, r8]\n" - "lsr r8, r8, #1\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "add r6, r6, #1\n" - "add r3, r3, #8\n" - "add r4, r4, r5\n" - "cmp r6, r2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x2 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #2\n" - - "vst1.16 {d28[0]}, [r3], r6\n" - "add r4, r4, r5\n" - "vst1.16 {d28[1]}, [r3], r6\n" - "vst1.16 {d28[2]}, [r3], r6\n" - "vst1.16 {d28[3]}, [r3], r6\n" - "mov r3, r4\n" - "vst1.16 {d29[0]}, [r3], r6\n" - "vst1.16 {d29[1]}, [r3], r6\n" - "vst1.16 {d29[2]}, [r3], r6\n" - "vst1.16 {d29[3]}, [r3], r6\n" - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #8\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q14) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // At this point, v20 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - // Clear accumulators. - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - - // Compute how much of the 4x2 block of destination 32 bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x4 blocks along the boundaries that do - // not fit entirely. 
- - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - "cmp r2, #2\n" - // Compute r2 = how many cols of the 4x2 block fit - "it gt\n" - "movgt r2, r5\n" - - // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. - "cmp r1, r3\n" - "it eq\n" - "cmpeq r2, r5\n" - // Yes, all of the 4x2 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x2 block fits. - // Set (r3 address, r4 stride) to write to dst_tmp_buf - "mov r3, %[dst_tmp_buf]\n" - "mov r4, #16\n" - "b 31f\n" - - "30:\n" - // Yes, all of the 4x2 block fits. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // r3 address, r4 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r5\n" - - "31:\n" - - "vst1.32 {d28, d29}, [r3]\n" - "add r3, r3, r4\n" - "vst1.32 {d30, d31}, [r3]\n" - - // If all of the 4x2 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 4x2 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "mov r3, %[dst_tmp_buf]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r6, #0\n" - "50:\n" - "mov r5, #0\n" - "51:\n" - "ldr r10, [r3, r5, lsl #2]\n" - "str r10, [r4, r5, lsl #2]\n" - "add r5, r5, #1\n" - "cmp r5, r1\n" - "blt 51b\n" - "add r6, r6, #1\n" - "add r3, r3, #16\n" - "add r4, r4, r8\n" - // r2 = how many cols of the 8x4 block fit - "cmp r6, r2\n" - "blt 50b\n" - - "41:\n" - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #16\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r8, r3\n" - - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add r8, r8, #4\n" - // Store new value of row - "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "b 21f\n" - "20:\n" - // Was already at end row. - // Move back to first row. - "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - // Move to the next column. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "add r4, r4, #2\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Increment dst_col_ptr by 2 * dst_stride (i.e. 
2 columns) - "add r1, r1, r8, lsl #1\n" - // Store dst_col_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Store dst_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov r1, #16\n" - - "ble 1b\n" - - // Restore stack pointer. - "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - // clang-format on - - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) - : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) - : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", - // Clobber list must specify q registers (and not their constituent - // d registers). There is a (currently unexplained) slowdown if - // d registers are listed in the clobbers list. - "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", - "q9", "q10", "q12", "q13", "q14", "q15"); -} - -// Fast-int8 true "GEMV" kernel (RHS has 1 column). We assume the RHS -// is still packed as if it has two columns -void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 2>& params) { - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - - // The asm kernel below has the following NEON register allocation: - // - // q6 - q13 are 128-bit (4x32b) accumulators. - // During accumulation, d0 -- d7 are used to load int8 data from LHS and - // d8 -- d11 from RHS: - // int8 RHS 16x1 block - // /------------\ - // | d8.b[0] | - // | ... | - // | d8.b[7] | - // | d9.b[0] | - // | ... | - // | d9.b[7] | - // \------------/ - // int8 LHS 4x16 block - // /-----------------------------------------\ /------------\ - // |d0.b[0] ... d0.b[7] d1.b[0] ... d1.b[7] | | q6 | - // |d2.b[0] ... d2.b[7] d3.b[0] ... d3.b[7] | | q7 | - // |d4.b[0] ... d4.b[7] d5.b[0] ... d5.b[7] | | q8 | - // |d6.b[0] ... d6.b[7] d7.b[0] ... d7.b[7] | | q9 | - // \-----------------------------------------/ \------------/ - // 128-bit accumulators 4x1 block - // - // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING - // optimization for this kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n" - - // clang-format off - - // Load the first 64 bytes of LHS and RHS data. - "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" - "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" - "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" - "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" - "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" - // Skip the other column and advance the pointer. 
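      // Note: the packed RHS block keeps the layout of the 2-column kernel
      // (two 16-byte column slices per 16-deep step) even though only one
      // column is used here, so after reading the 16 bytes of column 0 the
      // pointer skips the 16 bytes that would hold column 1.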
- "add %[rhs_ptr], %[rhs_ptr], #16\n" - - "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" - "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" - "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - - // Clear accumulators. - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q14) - RUY_MAKE_ZERO(q15) - - // r1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov r1, #16\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // r1 is how many levels of depth we have already loaded - // data for, r10 is the total depth. - "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - "cmp r1, r10\n" - "beq 79f\n" - - "2:\n" - - // Mult, mult-acc in to q14, q15 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q15, d2, d8\n" - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q15, d3, d9\n" - - // Then pairwise accumulate in to q6, q7 - "vpadal.s16 q6, q14\n" - "vpadal.s16 q7, q15\n" - - // Mult, mult-acc in to q14, q15 - "vmull.s8 q14, d4, d8\n" - "vmull.s8 q15, d6, d8\n" - "vmlal.s8 q14, d5, d9\n" - "vmlal.s8 q15, d7, d9\n" - - // Then pairwise accumulate in to q8, q9 - "vpadal.s16 q8, q14\n" - "vpadal.s16 q9, q15\n" - - - // Load the next 64 bytes of LHS and RHS data. - "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" - "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" - "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" - "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" - // Skip the other column and advance the pointer. - "add %[rhs_ptr], %[rhs_ptr], #16\n" - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - // Each iteration of this loop advances by 16 levels of depth. - "add r1, r1, #16\n" - - // Loop termination condition - "cmp r1, r10\n" - - "blt 2b\n" - - "79:\n" - - // Mult, mult-acc in to q14, q15 - "vmull.s8 q14, d0, d8\n" - "vmull.s8 q15, d2, d8\n" - "vmlal.s8 q14, d1, d9\n" - "vmlal.s8 q15, d3, d9\n" - - // Then pairwise accumulate in to q6, q7 - "vpadal.s16 q6, q14\n" - "vpadal.s16 q7, q15\n" - - // Mult, mult-acc in to q14, q15 - "vmull.s8 q14, d4, d8\n" - "vmull.s8 q15, d6, d8\n" - "vmlal.s8 q14, d5, d9\n" - "vmlal.s8 q15, d7, d9\n" - - // Then pairwise accumulate in to q8, q9 - "vpadal.s16 q8, q14\n" - "vpadal.s16 q9, q15\n" - - // All accumulation over depth done. q6 - q9 contain the 4x32b - // accumulators for the 4x1 final matrix. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 4x2 block. 
We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // q6-q9 now contain 4 x 32b - "vpadd.i32 d0, d12, d13\n" - "vpadd.i32 d1, d14, d15\n" - "vpadd.i32 d2, d16, d17\n" - "vpadd.i32 d3, d18, d19\n" - - // d0-d4 each contain 2 x 32b accumulators. - // Need to add pairwise to get 1 x 32b for each of the 4x1 entries - // of destination, (Four 'd' registers total) - "vpadd.i32 d28, d0, d1\n" - "vpadd.i32 d29, d2, d3\n" - - // Now d28,d29 have the 1 x 32b accumulators for the 4x1 entries. - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r1, r3\n" // Have we finished the last row? - - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "add r4, r4, r1, lsl #2\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "b 5f\n" - "4:\n" // Finished last row... - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - // Go back to first row - "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - - "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "add r10, r10, r1, lsl #1\n" - "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" - "mov %[lhs_ptr], r4\n" - "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" - "mov %[rhs_ptr], r5\n" - - // Now we load: bias data, LHS sums data, RHS sums data. - - // First, load the base pointers from the params. - "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "add r5, r1, r8, lsl #2\n" - - "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "it ne\n" - "movne r1, r5\n" - - // Load 4 bias values. 
- "vld1.32 {d24, d25}, [r1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" - "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" - "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" - "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" - RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") - "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" - // Skip the other column and advance the pointer. - "add %[rhs_ptr], %[rhs_ptr], #16\n" - RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") - - // Add to the bias values the product - // (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in - // https://arxiv.org/pdf/1712.05877.pdf - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "vdup.32 q9, r3\n" - "vadd.i32 q12, q12, q9\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "vadd.i32 q14, q14, q12\n" - - // LHS/RHS zero points - // Has RHS sums - "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - // Offset by current col * number of bytes per value - "add r3, r3, r4, lsl #2\n" - "vld1.32 { d12 }, [r3]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "vdup.32 q10, r5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "vmls.i32 q14, q10, d12[0]\n" - "401:\n" - - // Has LHS sums - "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - // Offset by current row * number of bytes per value - "add r2, r2, r4, lsl #2\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - - // Load 4 lhs_sums values. - "vld1.32 {d22, d23}, [r2]\n" - "vdup.32 d13, r5\n" // rhs_zero_point - - // Compute lhs_sums * rhs_zero_point. - "vmul.i32 q11, q11, d13[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "vsub.s32 q14, q14, q11\n" - - // If the destination is int32, it means the user asks for the raw - // accumulators, no need for us to downquantize the value. - "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. 
- "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "add r5, r1, r4, lsl #2\n" - "it ne\n" - "movne r1, r5\n" - - "vld1.32 {q10}, [r1]\n" - - RUY_MAKE_ZERO(q8) - "vmax.s32 q12, q10, q8\n" - - "vshl.s32 q14, q14, q12\n" - - "vmin.s32 q12, q10, q8\n" - - // Load fixed point part of the multiplier - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - // r6 has flags, r4 has row - "add r5, r1, r4, lsl #2\n" - "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "it ne\n" - "movne r1, r5\n" - "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint - - // Apply the fixed-point part of the multiplier. - "vqrdmulh.s32 q14, q14, q10\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "vand q8, q14, q12\n" - "vshr.s32 q8, q8, #31\n" - "vqadd.s32 q14, q14, q8\n" - -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "vrshl.s32 q14, q14, q12\n" - - "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - // Store uint8 values: - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in d28. - "vqmovn.s32 d28, q14\n" - - // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the - // current block, so we can start clearing these accumulators for the - // next block (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q15) - - // Load the destination zero point into each of the 8 16-bit slots - // in a q register. 
- "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.16 q13, r4\n" // dst_zero_point - - // Add the destination zero point - "vadd.i16 q14, q14, q13\n" - - // Cast-and-saturate from int16 to uint8 - "vqmovun.s16 d30, q14\n" - // At this point, we only need 4 8-bit values in the lower half - // of d30. - - - // Load the clamp_min, clamp_max bounds - "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.8 d28, r2\n" // clamp_min - "vdup.8 d29, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.u8 d30, d30, d28\n" - // Apply the clamp_max bound - "vmin.u8 d30, d30, d29\n" - - // Compute how much of the 4x1 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x1, there are some 4x1 blocks along the boundaries that do - // not fit entirely. - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x1 block fit - "it gt\n" - "movgt r1, r3\n" - - // Test if r1==4, i.e. if all of the 4x1 block fits. - "cmp r1, r3\n" - - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x1 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.8 {d30}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "50:\n" - "mov r8, #0\n" - "51:\n" - "ldrb r10, [r3, r8]\n" - "strb r10, [r4, r8]\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x1 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #1\n" - - "vst1.8 {d30[0]}, [r3], r6\n" - "vst1.8 {d30[1]}, [r3], r6\n" - "vst1.8 {d30[2]}, [r3], r6\n" - "vst1.8 {d30[3]}, [r3], r6\n" - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #4\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q14) - RUY_MAKE_ZERO(q15) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - // Store int8 values: - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in d28. - "vqmovn.s32 d28, q14\n" - - // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the - // current block, so we can start clearing these accumulators for the - // next block (next iteration of the main loop). - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q15) - - // Load the destination zero point into each of the 8 16-bit slots - // in a q register. 
- "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.16 q13, r4\n" // dst_zero_point - - // Add the destination zero point - "vadd.i16 q14, q14, q13\n" - - // Cast-and-saturate from int16 to int8 - "vqmovn.s16 d30, q14\n" - // At this point, we only need 4 8-bit values in the lower half - // of d30. - - // Load the clamp_min, clamp_max bounds - "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.8 d28, r2\n" // clamp_min - "vdup.8 d29, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.s8 d30, d30, d28\n" - // Apply the clamp_max bound - "vmin.s8 d30, d30, d29\n" - - // Compute how much of the 4x1 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x2 blocks along the boundaries that do - // not fit entirely. - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - // Test if r1==4 i.e. if all of the 4x1 block fits. - "cmp r1, r3\n" - - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x2 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x2 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.8 {d30}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "50:\n" - "mov r8, #0\n" - "51:\n" - "ldrb r10, [r3, r8]\n" - "strb r10, [r4, r8]\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x1 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #1\n" - - "vst1.8 {d30[0]}, [r3], r6\n" - "vst1.8 {d30[1]}, [r3], r6\n" - "vst1.8 {d30[2]}, [r3], r6\n" - "vst1.8 {d30[3]}, [r3], r6\n" - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #4\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q13) - RUY_MAKE_ZERO(q14) - RUY_MAKE_ZERO(q15) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Load the destination zero point into each of the 4 32-bit slots - // in a q register. - "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "vdup.32 q13, r4\n" // dst_zero_point - // Add the destination zero point - "vadd.s32 q14, q14, q13\n" - //"vadd.s32 q15, q15, q13\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in d28. - "vqmovn.s32 d28, q14\n" - - // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). 
- RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q15) - - // Load the clamp_min, clamp_max bounds - "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "vdup.16 d24, r2\n" // clamp_min - "vdup.16 d26, r3\n" // clamp_max - - // Apply the clamp_min bound - "vmax.s16 d28, d28, d24\n" - // Apply the clamp_max bound - "vmin.s16 d28, d28, d26\n" - - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - - // Compute how much of the 4x1 block of destination 16-bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x1, there are some 4x1 blocks along the boundaries that do - // not fit entirely. - - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x1 block fit - "it gt\n" - "movgt r1, r3\n" - - // Test if r1==4, i.e. if all of the 4x1 block fits. - "cmp r1, r3\n" - - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x1 block fits. - // Store to dst_tmp_buf - // Set r3 address to write to dst_tmp_buf. - "mov r3, %[dst_tmp_buf]\n" - "vst1.16 {d28}, [r3]\n" - - // Slow loop copying from dst_tmp_buf to dst. - "50:\n" - "mov r8, #0\n" - "51:\n" - // Shift of offset register for half-word loads not allowed in A32, - // so we shift, load/store, then shift back r8. - "lsl r8, r8, #1\n" - "ldrh r10, [r3, r8]\n" - "strh r10, [r4, r8]\n" - "lsr r8, r8, #1\n" - "add r8, r8, #1\n" - "cmp r8, r1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x1 block fits. - // r3 address, r5 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r3\n" - "mov r6, #2\n" - - "vst1.16 {d28[0]}, [r3], r6\n" - "vst1.16 {d28[1]}, [r3], r6\n" - "vst1.16 {d28[2]}, [r3], r6\n" - "vst1.16 {d28[3]}, [r3], r6\n" - "31:\n" - - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #8\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q14) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // At this point, v20 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - // Clear accumulators. - RUY_MAKE_ZERO(q6) - RUY_MAKE_ZERO(q7) - RUY_MAKE_ZERO(q8) - RUY_MAKE_ZERO(q9) - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - RUY_MAKE_ZERO(q12) - RUY_MAKE_ZERO(q13) - - // Compute how much of the 4x1 block of destination 32 bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x2, there are some 4x4 blocks along the boundaries that do - // not fit entirely. 
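All of the store paths in this kernel share the boundary-handling pattern referenced above: the block is always fully computed in registers, and when it overhangs the edge of the destination it is first dumped to params.dst_tmp_buf and then copied element by element, only as far as it fits. A rough scalar model of that pattern (template and parameter names are illustrative; the staging copy stands in for the whole-vector store that the asm is forced to use, since NEON registers cannot be partially stored cheaply):

#include <algorithm>
#include <cstring>

// Stores a fully computed kRows x kCols block; rows_left/cols_left say how
// many destination rows/cols remain from the block's origin.
template <typename Scalar, int kRows, int kCols>
void StoreBlock(const Scalar* block,  // kRows * kCols values, column-major
                Scalar* dst_col0, int dst_stride_bytes,
                int rows_left, int cols_left, Scalar* tmp_buf) {
  const int rows = std::min(rows_left, kRows);
  const int cols = std::min(cols_left, kCols);
  const Scalar* src = block;
  if (rows != kRows || cols != kCols) {
    // Slow path: stage the whole block (the asm can only store full
    // vectors), then copy out just the part that fits.
    std::memcpy(tmp_buf, block, sizeof(Scalar) * kRows * kCols);
    src = tmp_buf;
  }
  for (int c = 0; c < cols; ++c) {
    Scalar* dst_col = reinterpret_cast<Scalar*>(
        reinterpret_cast<char*>(dst_col0) + c * dst_stride_bytes);
    std::memcpy(dst_col, src + c * kRows, sizeof(Scalar) * rows);
  }
}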
- - "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "sub r1, r1, r8\n" - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "sub r2, r2, r4\n" - "mov r3, #4\n" - "mov r5, #2\n" - "cmp r1, #4\n" - // Compute r1 = how many rows of the 4x2 block fit - "it gt\n" - "movgt r1, r3\n" - - // Test if r1==4, i.e. if all of the 4x1 block fits. - "cmp r1, r3\n" - - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x1 block fits. - // Set (r3 address, r4 stride) to write to dst_tmp_buf - "mov r3, %[dst_tmp_buf]\n" - "mov r4, #16\n" - "b 31f\n" - - "30:\n" - // Yes, all of the 4x1 block fits. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - // r3 address, r4 stride - "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "mov r4, r5\n" - - "31:\n" - - "vst1.32 {d28, d29}, [r3]\n" - - // If all of the 4x1 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 4x1 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "mov r3, %[dst_tmp_buf]\n" - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "50:\n" - "mov r5, #0\n" - "51:\n" - "ldr r10, [r3, r5, lsl #2]\n" - "str r10, [r4, r5, lsl #2]\n" - "add r5, r5, #1\n" - "cmp r5, r1\n" - "blt 51b\n" - - "41:\n" - // Load dst_ptr, increment, and write back. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "add r4, r4, #16\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - - RUY_MAKE_ZERO(q10) - RUY_MAKE_ZERO(q11) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - "cmp r8, r3\n" - - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add r8, r8, #4\n" - // Store new value of row - "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - - "b 21f\n" - "20:\n" - // Was already at end row. - // Move back to first row. - "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" - // Move to the next column. - "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "add r4, r4, #2\n" - "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - - "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Increment dst_col_ptr by dst_stride (i.e. 1 column) - "add r1, r1, r8\n" - // Store dst_col_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" - // Store dst_ptr - "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? 
- "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" - "cmp r8, r4\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov r1, #16\n" - - "ble 1b\n" - - // Restore stack pointer. - "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" - - // clang-format on - - : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) - : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) - : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", - // Clobber list must specify q registers (and not their constituent - // d registers). There is a (currently unexplained) slowdown if - // d registers are listed in the clobbers list. - "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", - "q9", "q10", "q12", "q13", "q14", "q15"); -} - -#undef RUY_OFFSET_BIAS -#undef RUY_OFFSET_LHS_SUMS -#undef RUY_OFFSET_RHS_SUMS -#undef RUY_OFFSET_LHS_BASE_PTR -#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT -#undef RUY_OFFSET_MULTIPLIER_EXPONENT -#undef RUY_OFFSET_RHS_BASE_PTR -#undef RUY_OFFSET_DST_BASE_PTR -#undef RUY_OFFSET_LHS_ZERO_POINT -#undef RUY_OFFSET_RHS_ZERO_POINT -#undef RUY_OFFSET_DST_ZERO_POINT -#undef RUY_OFFSET_PROD_ZP_DEPTH -#undef RUY_OFFSET_START_ROW -#undef RUY_OFFSET_START_COL -#undef RUY_OFFSET_LAST_ROW -#undef RUY_OFFSET_LAST_COL -#undef RUY_OFFSET_DST_ROWS -#undef RUY_OFFSET_DST_COLS -#undef RUY_OFFSET_LHS_STRIDE -#undef RUY_OFFSET_RHS_STRIDE -#undef RUY_OFFSET_DST_STRIDE -#undef RUY_OFFSET_DEPTH -#undef RUY_OFFSET_CLAMP_MIN -#undef RUY_OFFSET_CLAMP_MAX -#undef RUY_OFFSET_FLAGS -#undef RUY_OFFSET_DST_TYPE_ID - -#undef RUY_STACK_OFFSET_SIZE -#undef RUY_STACK_OFFSET_DST_COL_PTR -#undef RUY_STACK_OFFSET_DST_PTR -#undef RUY_STACK_OFFSET_ROW -#undef RUY_STACK_OFFSET_COL -#undef RUY_STACK_OFFSET_LHS_COL_PTR -#undef RUY_STACK_OFFSET_RHS_COL_PTR - -#endif // RUY_PLATFORM(NEON_32) && (RUY_OPT_ENABLED(RUY_OPT_ASM) -} // namespace ruy diff --git a/kernel_arm64.cc b/kernel_arm64.cc deleted file mode 100644 index 52381fd..0000000 --- a/kernel_arm64.cc +++ /dev/null @@ -1,7835 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include - -#include "common.h" -#include "kernel.h" -#include "opt_set.h" -#include "platform.h" -#include "profiler/instrumentation.h" - -namespace ruy { - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#define RUY_ASM_LABEL_STORE_UINT8 91 -#define RUY_ASM_LABEL_STORE_INT8 92 -#define RUY_ASM_LABEL_STORE_INT16 93 -#define RUY_ASM_LABEL_STORE_INT32 94 -#define RUY_ASM_LABEL_AFTER_STORE 99 - -#define RUY_OFFSET_BIAS 0 -#define RUY_OFFSET_LHS_SUMS 8 -#define RUY_OFFSET_RHS_SUMS 16 -#define RUY_OFFSET_LHS_BASE_PTR 24 -#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 32 -#define RUY_OFFSET_MULTIPLIER_EXPONENT 40 -#define RUY_OFFSET_RHS_BASE_PTR 48 -#define RUY_OFFSET_DST_BASE_PTR 56 -#define RUY_OFFSET_LHS_ZERO_POINT 64 -#define RUY_OFFSET_RHS_ZERO_POINT 68 -#define RUY_OFFSET_DST_ZERO_POINT 72 -#define RUY_OFFSET_PROD_ZP_DEPTH 76 -#define RUY_OFFSET_START_ROW 80 -#define RUY_OFFSET_START_COL 84 -#define RUY_OFFSET_LAST_ROW 88 -#define RUY_OFFSET_LAST_COL 92 -#define RUY_OFFSET_DST_ROWS 96 -#define RUY_OFFSET_DST_COLS 100 -#define RUY_OFFSET_LHS_STRIDE 104 -#define RUY_OFFSET_RHS_STRIDE 108 -#define RUY_OFFSET_DST_STRIDE 112 -#define RUY_OFFSET_DEPTH 116 -#define RUY_OFFSET_CLAMP_MIN 120 -#define RUY_OFFSET_CLAMP_MAX 124 -#define RUY_OFFSET_FLAGS 128 - -template -void CheckOffsetsInKernelParams8bit(const Params&) { - static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT, - ""); - static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT, - ""); - static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT, - ""); - static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH, - ""); - static_assert(offsetof(Params, multiplier_fixedpoint) == - RUY_OFFSET_MULTIPLIER_FIXEDPOINT, - ""); - static_assert( - offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT, - ""); - static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); - static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); - static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); - static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, ""); - static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, ""); - static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); - static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); - static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); - static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); - static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); - static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); - static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); - static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); - static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); -} - -// Fast-int8-trick kernel, similar to this production gemmlowp kernel: -// NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L2296 -// -// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75, -// since these are 64-bit, out-of-order and without dotprod support. 
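The RUY_OFFSET_* constants above exist because the inline asm reads KernelParams8bit fields through explicit byte offsets (ldr ..., [%[params], #offset]); CheckOffsetsInKernelParams8bit pins each constant to the real struct layout at compile time, and it is a function template so the checks are instantiated against the concrete KernelParams8bit type the kernel is actually called with. The same pattern in a self-contained form (struct and macro names below are made up for illustration, and the offsets assume a 64-bit target):

#include <cstddef>
#include <cstdint>

#define MY_OFFSET_DEPTH 0
#define MY_OFFSET_DST_PTR 8

struct MyKernelParams {
  std::int64_t depth;  // read by hand-written asm at MY_OFFSET_DEPTH
  void* dst_ptr;       // read by hand-written asm at MY_OFFSET_DST_PTR
};

// Fails to compile if the struct layout ever drifts away from the offsets
// hard-coded in the asm.
static_assert(offsetof(MyKernelParams, depth) == MY_OFFSET_DEPTH, "");
static_assert(offsetof(MyKernelParams, dst_ptr) == MY_OFFSET_DST_PTR, "");

With that in place, the kernel below can safely hard-code the offsets in its ldr instructions.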
-void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params) { - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are int32 accumulators. - // During accumulation, v0 -- v3 are used to load int8 data from LHS and - // v4 -- v7 from RHS: - // - // int8 RHS 16x4 block - // /-----------------------------------------\ - // |v4.b[0] ... v7.b[0] | - // | ... ... | - // |v4.b[15] ... v7.b[15] | - // \-----------------------------------------/ - // int8 LHS 4x16 block - // /---------------------\ /-----------------------------------------\ - // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s | - // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s | - // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s | - // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s | - // \---------------------/ \-----------------------------------------/ - // int32 accumulators 4x4 block - // - // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING - // optimization for this kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 64 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov w1, #16\n" - - // Perform the first few multiply-adds on the data that we have already - // loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - "smull v12.8h, v0.8b, v5.8b\n" - "smull v13.8h, v1.8b, v5.8b\n" - "smull v14.8h, v2.8b, v5.8b\n" - "smull v15.8h, v3.8b, v5.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. 
This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Reminder - w1 is how many levels of depth we have already loaded - // data for, w12 is the total depth. - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - - // Some multiplications and 16-bit accumulation were already done above, - // so we start right away in the middle. - "sadalp v16.4s, v8.8h\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "smull v8.8h, v0.8b, v6.8b\n" - "sadalp v17.4s, v9.8h\n" - "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" - "smull v9.8h, v1.8b, v6.8b\n" - "sadalp v18.4s, v10.8h\n" - "smull v10.8h, v2.8b, v6.8b\n" - "sadalp v19.4s, v11.8h\n" - "smull v11.8h, v3.8b, v6.8b\n" - "sadalp v20.4s, v12.8h\n" - "smull v12.8h, v0.8b, v7.8b\n" - "sadalp v21.4s, v13.8h\n" - "smull v13.8h, v1.8b, v7.8b\n" - "sadalp v22.4s, v14.8h\n" - "smull v14.8h, v2.8b, v7.8b\n" - "sadalp v23.4s, v15.8h\n" - "smull v15.8h, v3.8b, v7.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v6.16b\n" - "smlal2 v9.8h, v1.16b, v6.16b\n" - "smlal2 v10.8h, v2.16b, v6.16b\n" - "smlal2 v11.8h, v3.16b, v6.16b\n" - - "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" - - "smlal2 v12.8h, v0.16b, v7.16b\n" - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "smlal2 v13.8h, v1.16b, v7.16b\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "smlal2 v14.8h, v2.16b, v7.16b\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "smlal2 v15.8h, v3.16b, v7.16b\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - - "sadalp v24.4s, v8.8h\n" - "smull v8.8h, v0.8b, v4.8b\n" - "sadalp v25.4s, v9.8h\n" - "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" - "smull v9.8h, v1.8b, v4.8b\n" - "sadalp v26.4s, v10.8h\n" - "smull v10.8h, v2.8b, v4.8b\n" - "sadalp v27.4s, v11.8h\n" - "smull v11.8h, v3.8b, v4.8b\n" - "sadalp v28.4s, v12.8h\n" - "smull v12.8h, v0.8b, v5.8b\n" - "sadalp v29.4s, v13.8h\n" - "smull v13.8h, v1.8b, v5.8b\n" - "sadalp v30.4s, v14.8h\n" - "smull v14.8h, v2.8b, v5.8b\n" - "sadalp v31.4s, v15.8h\n" - "smull v15.8h, v3.8b, v5.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - - - // Each iteration of this loop advances by 16 levels of depth. 
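For reference, what one accumulator lane computes across the smull/smlal2/sadalp sequence above is: two int8 products summed in 16 bits, then widened and added into a 32-bit accumulator. The 16-bit step has just enough headroom provided at least one factor of each product stays within [-127, 127] (the usual arrangement for symmetrically quantized weights): |a*b + c*d| <= 2 * 127 * 128 = 32512 < 32768. A scalar sketch of one lane (names are illustrative, not ruy APIs):

#include <cstdint>

std::int32_t DotProductAccumTwoWithin16Bits(const std::int8_t* lhs,
                                            const std::int8_t* rhs,
                                            int depth /* multiple of 2 */) {
  std::int32_t acc = 0;
  for (int k = 0; k < depth; k += 2) {
    // Sum of two int8*int8 products, held in 16 bits (smull + smlal2).
    // Under the headroom assumption above this cannot overflow int16.
    std::int16_t p = static_cast<std::int16_t>(lhs[k] * rhs[k] +
                                               lhs[k + 1] * rhs[k + 1]);
    acc += p;  // sadalp: widen to 32 bits and accumulate.
  }
  return acc;
}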
- "add w1, w1, #16\n" - - // Loop termination condition - "cmp w1, w12\n" - - "blt 2b\n" - - "79:\n" - - "sadalp v16.4s, v8.8h\n" - "smull v8.8h, v0.8b, v6.8b\n" - "sadalp v17.4s, v9.8h\n" - "smull v9.8h, v1.8b, v6.8b\n" - "sadalp v18.4s, v10.8h\n" - "smull v10.8h, v2.8b, v6.8b\n" - "sadalp v19.4s, v11.8h\n" - "smull v11.8h, v3.8b, v6.8b\n" - "sadalp v20.4s, v12.8h\n" - "smull v12.8h, v0.8b, v7.8b\n" - "sadalp v21.4s, v13.8h\n" - "smull v13.8h, v1.8b, v7.8b\n" - "sadalp v22.4s, v14.8h\n" - "smull v14.8h, v2.8b, v7.8b\n" - "sadalp v23.4s, v15.8h\n" - "smull v15.8h, v3.8b, v7.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v6.16b\n" - "smlal2 v9.8h, v1.16b, v6.16b\n" - "smlal2 v10.8h, v2.16b, v6.16b\n" - "smlal2 v11.8h, v3.16b, v6.16b\n" - - "smlal2 v12.8h, v0.16b, v7.16b\n" - "smlal2 v13.8h, v1.16b, v7.16b\n" - "smlal2 v14.8h, v2.16b, v7.16b\n" - "smlal2 v15.8h, v3.16b, v7.16b\n" - - "sadalp v24.4s, v8.8h\n" - "sadalp v25.4s, v9.8h\n" - "sadalp v26.4s, v10.8h\n" - "sadalp v27.4s, v11.8h\n" - "sadalp v28.4s, v12.8h\n" - "sadalp v29.4s, v13.8h\n" - "sadalp v30.4s, v14.8h\n" - "sadalp v31.4s, v15.8h\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 4x4 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 4x4 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Reduce 32bit accumulators horizontally. - "addp v16.4s, v16.4s, v17.4s\n" - "addp v18.4s, v18.4s, v19.4s\n" - "addp v20.4s, v20.4s, v21.4s\n" - "addp v22.4s, v22.4s, v23.4s\n" - "addp v24.4s, v24.4s, v25.4s\n" - "addp v26.4s, v26.4s, v27.4s\n" - "addp v28.4s, v28.4s, v29.4s\n" - "addp v30.4s, v30.4s, v31.4s\n" - - // Reduce 32bit accumulators horizontally, second pass - // (each pass adds pairwise. we need to add 4-wise). - "addp v16.4s, v16.4s, v18.4s\n" - "addp v17.4s, v20.4s, v22.4s\n" - "addp v18.4s, v24.4s, v26.4s\n" - "addp v19.4s, v28.4s, v30.4s\n" - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. 
The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint - - // Now we load: bias data, LHS sums data, RHS sums data. - - // First, load the base pointers from the params. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - "add x5, x1, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 4 bias values. - "ld1 {v14.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "add v16.4s, v16.4s, v14.4s\n" - "add v17.4s, v17.4s, v14.4s\n" - "add v18.4s, v18.4s, v14.4s\n" - "add v19.4s, v19.4s, v14.4s\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ld1 {v14.4s}, [x3]\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "mls v17.4s, v10.4s, v14.s[1]\n" - "mls v18.4s, v10.4s, v14.s[2]\n" - "mls v19.4s, v10.4s, v14.s[3]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - // Load 4 lhs_sums values. - "ld1 {v11.4s}, [x2]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Compute lhs_sums * rhs_zero_point. 
- "mul v11.4s, v11.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - "sub v17.4s, v17.4s, v11.4s\n" - "sub v18.4s, v18.4s, v11.4s\n" - "sub v19.4s, v19.4s, v11.4s\n" - - // If the destination is int32, it means the user asks for the raw - // accumulators, no need for us to downquantize the value. - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ld1 {v14.4s}, [x1]\n" - - "smax v12.4s, v14.4s, v8.4s\n" - - "sshl v16.4s, v16.4s, v12.4s\n" - "sshl v17.4s, v17.4s, v12.4s\n" - "sshl v18.4s, v18.4s, v12.4s\n" - "sshl v19.4s, v19.4s, v12.4s\n" - - "smin v12.4s, v14.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - "sqrdmulh v16.4s, v16.4s, v15.4s\n" - "sqrdmulh v17.4s, v17.4s, v15.4s\n" - "sqrdmulh v18.4s, v18.4s, v15.4s\n" - "sqrdmulh v19.4s, v19.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "and v8.16b, v16.16b, v12.16b\n" - "and v9.16b, v17.16b, v12.16b\n" - "and v14.16b, v18.16b, v12.16b\n" - "and v15.16b, v19.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" - "sqadd v17.4s, v17.4s, v9.4s\n" - "sqadd v18.4s, v18.4s, v14.4s\n" - "sqadd v19.4s, v19.4s, v15.4s\n" -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. 
- "srshl v16.4s, v16.4s, v12.4s\n" - "srshl v17.4s, v17.4s, v12.4s\n" - "srshl v18.4s, v18.4s, v12.4s\n" - "srshl v19.4s, v19.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8 - "sqxtun v16.8b, v16.8h\n" - "sqxtun2 v16.16b, v17.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #4\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. 
- RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[4], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[5], [x3], #1\n" - "st1 {v16.b}[6], [x3], #1\n" - "st1 {v16.b}[7], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[8], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[9], [x3], #1\n" - "st1 {v16.b}[10], [x3], #1\n" - "st1 {v16.b}[11], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[12], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[13], [x3], #1\n" - "st1 {v16.b}[14], [x3], #1\n" - "st1 {v16.b}[15], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - - // Cast-and-saturate from int16 to int8 - "sqxtn v16.8b, v16.8h\n" - "sqxtn2 v16.16b, v17.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. 
- "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #4\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[4], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[5], [x3], #1\n" - "st1 {v16.b}[6], [x3], #1\n" - "st1 {v16.b}[7], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[8], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[9], [x3], #1\n" - "st1 {v16.b}[10], [x3], #1\n" - "st1 {v16.b}[11], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[12], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[13], [x3], #1\n" - "st1 {v16.b}[14], [x3], #1\n" - "st1 {v16.b}[15], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.4h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - "saddw v17.4s, v17.4s, v14.4h\n" - "saddw v18.4s, v18.4s, v14.4h\n" - "saddw v19.4s, v19.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Load the clamp_min, clamp_max bounds - "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.8h, w2\n" // clamp_min - "dup v15.8h, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - "smax v17.8h, v17.8h, v14.8h\n" - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - "smin v17.8h, v17.8h, v15.8h\n" - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. 
- // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - "str q17, [%[dst_tmp_buf], #16]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.h}[0], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v16.h}[1], [x3], #2\n" - "st1 {v16.h}[2], [x3], #2\n" - "st1 {v16.h}[3], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.h}[4], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v16.h}[5], [x3], #2\n" - "st1 {v16.h}[6], [x3], #2\n" - "st1 {v16.h}[7], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.h}[0], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v17.h}[1], [x3], #2\n" - "st1 {v17.h}[2], [x3], #2\n" - "st1 {v17.h}[3], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.h}[4], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v17.h}[5], [x3], #2\n" - "st1 {v17.h}[6], [x3], #2\n" - "st1 {v17.h}[7], [x3], #2\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #8\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // At this point, v20 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - "str q17, [%[dst_tmp_buf], #16]\n" - "str q18, [%[dst_tmp_buf], #32]\n" - "str q19, [%[dst_tmp_buf], #48]\n" - // Slow loop copying from dst_tmp_buf to dst. 
- "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #16\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v16.s}[1], [x3], #4\n" - "st1 {v16.s}[2], [x3], #4\n" - "st1 {v16.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v17.s}[1], [x3], #4\n" - "st1 {v17.s}[2], [x3], #4\n" - "st1 {v17.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v18.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v18.s}[1], [x3], #4\n" - "st1 {v18.s}[2], [x3], #4\n" - "st1 {v18.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v19.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v19.s}[1], [x3], #4\n" - "st1 {v19.s}[2], [x3], #4\n" - "st1 {v19.s}[3], [x3], #4\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #16\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - "smull v12.8h, v0.8b, v5.8b\n" - "smull v13.8h, v1.8b, v5.8b\n" - "smull v14.8h, v2.8b, v5.8b\n" - "smull v15.8h, v3.8b, v5.8b\n" - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #4\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #4\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. 
- "mov w1, #16\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} - -// Similar to existing Kernel8bitNeonOutOfOrder but specialized for the case of -// RHS cols == 1. -// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75, -// since these are 64-bit, out-of-order and without dotprod support. -void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 4>& params) { - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v19 are int32 accumulators. - // During accumulation, v0 -- v3 are used to load int8 data from LHS and - // v4 from RHS: - // - // int8 RHS 16x1 block - // /-----------\ - // |v4.b[0] | - // | ... | - // |v4.b[15] | - // \-----------/ - // int8 LHS 4x16 block - // /---------------------\ /-----------\ - // |v0.b[0] ... v0.b[15] | |v16.4s | - // |v1.b[0] ... v1.b[15] | |v17.4s | - // |v2.b[0] ... v2.b[15] | |v18.4s | - // |v3.b[0] ... v3.b[15] | |v19.4s | - // \---------------------/ \-----------/ - // int32 accumulators 4x1 block - // - // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING - // optimization for this kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 64 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "add %[rhs_ptr], %[rhs_ptr], #48\n" - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. 
- "mov w1, #16\n" - - // Perform the first few multiply-adds on the data that we have already - // loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Reminder - w1 is how many levels of depth we have already loaded - // data for, w12 is the total depth. - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - - // Some multiplications and 16-bit accumulation were already done above, - // so we start right away in the middle. - "sadalp v16.4s, v8.8h\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "add %[rhs_ptr], %[rhs_ptr], #48\n" - "sadalp v17.4s, v9.8h\n" - "sadalp v18.4s, v10.8h\n" - "sadalp v19.4s, v11.8h\n" - - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - - // Each iteration of this loop advances by 16 levels of depth. - "add w1, w1, #16\n" - - // Loop termination condition - "cmp w1, w12\n" - - "blt 2b\n" - - "79:\n" - - "sadalp v16.4s, v8.8h\n" - "sadalp v17.4s, v9.8h\n" - "sadalp v18.4s, v10.8h\n" - "sadalp v19.4s, v11.8h\n" - - // End of accumulation. The registers v16 -- v19 contain the final - // int32 accumulator values of the current 4x1 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 4x1 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Reduce 32bit accumulators horizontally. - "addp v16.4s, v16.4s, v17.4s\n" - "addp v18.4s, v18.4s, v19.4s\n" - - // Reduce 32bit accumulators horizontally, second pass - // (each pass adds pairwise. we need to add 4-wise). - "addp v16.4s, v16.4s, v18.4s\n" - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. 
- "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - // (still multiply column stride by 4 due to packing) - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint - - // Now we load: bias data, LHS sums data, RHS sums data. - - // First, load the base pointers from the params. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - "add x5, x1, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 4 bias values. - "ld1 {v14.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - "add %[rhs_ptr], %[rhs_ptr], #48\n" - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
- // (all four 32-bit accumulators are in v16 at this point) - "add v16.4s, v16.4s, v14.4s\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ld1 {v14.4s}, [x3]\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - // Load 4 lhs_sums values. - "ld1 {v11.4s}, [x2]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Compute lhs_sums * rhs_zero_point. - "mul v11.4s, v11.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - - // If the destination is int32, it means the user asks for the raw - // accumulators, no need for us to downquantize the value. - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ld1 {v14.4s}, [x1]\n" - - "smax v12.4s, v14.4s, v8.4s\n" - - "sshl v16.4s, v16.4s, v12.4s\n" - - "smin v12.4s, v14.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - "sqrdmulh v16.4s, v16.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. 
-#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "and v8.16b, v16.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "srshl v16.4s, v16.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this instruction, all data is in lower half (64-bits) of v16 - "sqxtn v16.4h, v16.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8 - // Now all data is in the first 32-bits of v16 - "sqxtun v16.8b, v16.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - - // Compute how much of the 4x1 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x1, there are some 4x1 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x1 block fit - "csel w1, w1, w3, le\n" - - // Test if w1==4, i.e. if all of the 4x1 block fits. - "cmp w1, w3\n" - - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x1 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x1 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - // After this, all values for output are in the lower half (64 bits) of v16. 
- "sqxtn v16.4h, v16.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - - // Cast-and-saturate from int16 to int8 - "sqxtn v16.8b, v16.8h\n" - // At this point, we only need 4 lowest 8-bit values in v16. - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x1 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - - // Test if w1==4, i.e. if all of the 4x1 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.4h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - // After this instruction, all data is in lower half of v16. - "sqxtn v16.4h, v16.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - // Load the clamp_min, clamp_max bounds - "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.8h, w2\n" // clamp_min - "dup v15.8h, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. 
- "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.h}[0], [x3], #2\n" - "st1 {v16.h}[1], [x3], #2\n" - "st1 {v16.h}[2], [x3], #2\n" - "st1 {v16.h}[3], [x3], #2\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #8\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - - // Test if w1==4 i.e. if all of the 4x1 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.s}[0], [x3], #4\n" - "st1 {v16.s}[1], [x3], #4\n" - "st1 {v16.s}[2], [x3], #4\n" - "st1 {v16.s}[3], [x3], #4\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #16\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. 
- "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #4\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #4\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov w1, #16\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19"); -} - -// Variant of the above Kernel8bitNeonOutOfOrder, tuned for in-order CPUs. -// Specifically here, the relevant in-order CPUs are ARM Cortex-A53 and -// the original Cortex-A55, since these are 64-bit and do not support dotprod. -// -// While this kernel does not have a direct equivalent in gemmlowp, it was -// developed based on insights that David Mansell at ARM shared with their -// contribution of gemmlowp kernels tuned for Cortex-A53, with very helpful -// comments. Specifically, see this comment about tuning for Cortex-A53: -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215 -void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params) { - profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are int32 accumulators. - // During accumulation, v0 -- v3 are used to load int8 data from LHS and - // v4 -- v7 from RHS: - // - // int8 RHS 16x4 block - // /-----------------------------------------\ - // |v4.b[0] ... v7.b[0] | - // | ... ... | - // |v4.b[15] ... v7.b[15] | - // \-----------------------------------------/ - // int8 LHS 4x16 block - // /---------------------\ /-----------------------------------------\ - // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s | - // |v1.b[0] ... v1.b[15] | |v17.4s ... 
v29.4s | - // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s | - // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s | - // \---------------------/ \-----------------------------------------/ - // int32 accumulators 4x4 block - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - RUY_MAKE_ZERO(v16) - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - RUY_MAKE_ZERO(v17) - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - RUY_MAKE_ZERO(v18) - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - RUY_MAKE_ZERO(v19) - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - RUY_MAKE_ZERO(v20) - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - RUY_MAKE_ZERO(v21) - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - RUY_MAKE_ZERO(v22) - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - RUY_MAKE_ZERO(v23) - - // Load the first 64 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v24) - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v25) - "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v26) - "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v27) - "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v28) - "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v29) - "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v30) - "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v31) - - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov w1, #16\n" - - // Perform the first few multiply-adds on the data that we have already - // loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - "smull v11.8h, v3.8b, v4.8b\n" - "smull v12.8h, v0.8b, v5.8b\n" - "smull v13.8h, v1.8b, v5.8b\n" - "smull v14.8h, v2.8b, v5.8b\n" - "smull v15.8h, v3.8b, v5.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Reminder - w1 is how many levels of depth we have already loaded - // data for, w12 is the total depth. - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - - // Some multiplications and 16-bit accumulation were already done above, - // so we start right away in the middle. - "sadalp v16.4s, v8.8h\n" - "ldr d4, [%[rhs_ptr], #0]\n" - "smull v8.8h, v0.8b, v6.8b\n" - "ldr x7, [%[rhs_ptr], #8]\n" - "sadalp v17.4s, v9.8h\n" - "ldr d5, [%[rhs_ptr], #16]\n" - "smull v9.8h, v1.8b, v6.8b\n" - "ldr x8, [%[rhs_ptr], #24]\n" - "sadalp v18.4s, v10.8h\n" - "smull v10.8h, v2.8b, v6.8b\n" - "sadalp v19.4s, v11.8h\n" - "add %[lhs_ptr], %[lhs_ptr], #64\n" - "smull v11.8h, v3.8b, v6.8b\n" - "add %[rhs_ptr], %[rhs_ptr], #64\n" - "sadalp v20.4s, v12.8h\n" - // Each iteration of this loop advances by 16 levels of depth. 
- "add w1, w1, #16\n" - "smull v12.8h, v0.8b, v7.8b\n" - // Loop termination condition - "cmp w1, w12\n" - "sadalp v21.4s, v13.8h\n" - "ldr x3, [%[lhs_ptr], #-56]\n" - "smull v13.8h, v1.8b, v7.8b\n" - "ldr x4, [%[lhs_ptr], #-40]\n" - "sadalp v22.4s, v14.8h\n" - "ldr x5, [%[lhs_ptr], #-24]\n" - "smull v14.8h, v2.8b, v7.8b\n" - "ldr x6, [%[lhs_ptr], #-8]\n" - "sadalp v23.4s, v15.8h\n" - "smull v15.8h, v3.8b, v7.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v6.16b\n" - "smlal2 v9.8h, v1.16b, v6.16b\n" - "smlal2 v10.8h, v2.16b, v6.16b\n" - "ldr x9, [%[rhs_ptr], #-24]\n" - "smlal2 v11.8h, v3.16b, v6.16b\n" - "ldr d6, [%[rhs_ptr], #-32]\n" - "smlal2 v12.8h, v0.16b, v7.16b\n" - "ldr d0, [%[lhs_ptr], #-64]\n" - "smlal2 v13.8h, v1.16b, v7.16b\n" - "ldr d1, [%[lhs_ptr], #-48]\n" - "smlal2 v14.8h, v2.16b, v7.16b\n" - "ins v4.d[1], x7\n" - "smlal2 v15.8h, v3.16b, v7.16b\n" - "ins v5.d[1], x8\n" - - "ldr d2, [%[lhs_ptr], #-32]\n" - "ins v0.d[1], x3\n" - "sadalp v24.4s, v8.8h\n" - "ldr d3, [%[lhs_ptr], #-16]\n" - "ins v1.d[1], x4\n" - "smull v8.8h, v0.8b, v4.8b\n" - "ins v2.d[1], x5\n" - "sadalp v25.4s, v9.8h\n" - "ins v3.d[1], x6\n" - "smull v9.8h, v1.8b, v4.8b\n" - "ldr d7, [%[rhs_ptr], #-16]\n" - "sadalp v26.4s, v10.8h\n" - "ldr x10, [%[rhs_ptr], #-8]\n" - "smull v10.8h, v2.8b, v4.8b\n" - "sadalp v27.4s, v11.8h\n" - "smull v11.8h, v3.8b, v4.8b\n" - "sadalp v28.4s, v12.8h\n" - "smull v12.8h, v0.8b, v5.8b\n" - "sadalp v29.4s, v13.8h\n" - "smull v13.8h, v1.8b, v5.8b\n" - "sadalp v30.4s, v14.8h\n" - "smull v14.8h, v2.8b, v5.8b\n" - "sadalp v31.4s, v15.8h\n" - "smull v15.8h, v3.8b, v5.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "ins v6.d[1], x9\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "ins v7.d[1], x10\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - "blt 2b\n" - - "79:\n" - - "sadalp v16.4s, v8.8h\n" - "smull v8.8h, v0.8b, v6.8b\n" - "sadalp v17.4s, v9.8h\n" - "smull v9.8h, v1.8b, v6.8b\n" - "sadalp v18.4s, v10.8h\n" - "smull v10.8h, v2.8b, v6.8b\n" - "sadalp v19.4s, v11.8h\n" - "smull v11.8h, v3.8b, v6.8b\n" - "sadalp v20.4s, v12.8h\n" - "smull v12.8h, v0.8b, v7.8b\n" - "sadalp v21.4s, v13.8h\n" - "smull v13.8h, v1.8b, v7.8b\n" - "sadalp v22.4s, v14.8h\n" - "smull v14.8h, v2.8b, v7.8b\n" - "sadalp v23.4s, v15.8h\n" - "smull v15.8h, v3.8b, v7.8b\n" - - // Multiply-accumulate second-half, again into the same - // 16bit local accumulator registers. This is where we - // take advantage of having int8 instead of uint8 and therefore - // being able to accumulate two products into int16. 
- "smlal2 v8.8h, v0.16b, v6.16b\n" - "smlal2 v9.8h, v1.16b, v6.16b\n" - "smlal2 v10.8h, v2.16b, v6.16b\n" - "smlal2 v11.8h, v3.16b, v6.16b\n" - - "smlal2 v12.8h, v0.16b, v7.16b\n" - "smlal2 v13.8h, v1.16b, v7.16b\n" - "smlal2 v14.8h, v2.16b, v7.16b\n" - "smlal2 v15.8h, v3.16b, v7.16b\n" - - "sadalp v24.4s, v8.8h\n" - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "sadalp v25.4s, v9.8h\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "sadalp v26.4s, v10.8h\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "sadalp v27.4s, v11.8h\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "sadalp v28.4s, v12.8h\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "sadalp v29.4s, v13.8h\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "sadalp v30.4s, v14.8h\n" - "sadalp v31.4s, v15.8h\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 4x4 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 4x4 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Reduce 32bit accumulators horizontally. - "addp v16.4s, v16.4s, v17.4s\n" - "addp v18.4s, v18.4s, v19.4s\n" - "addp v20.4s, v20.4s, v21.4s\n" - "addp v22.4s, v22.4s, v23.4s\n" - "addp v24.4s, v24.4s, v25.4s\n" - "addp v26.4s, v26.4s, v27.4s\n" - "addp v28.4s, v28.4s, v29.4s\n" - "addp v30.4s, v30.4s, v31.4s\n" - - // Reduce 32bit accumulators horizontally, second pass - // (each pass adds pairwise. we need to add 4-wise). - "addp v16.4s, v16.4s, v18.4s\n" - "addp v17.4s, v20.4s, v22.4s\n" - "addp v18.4s, v24.4s, v26.4s\n" - "addp v19.4s, v28.4s, v30.4s\n" - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. 
- "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint - - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - "add x5, x1, %x[row], lsl #2\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 4 bias values. - "ld1 {v14.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - "ldr d0, [%[lhs_ptr], #0]\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "add v16.4s, v16.4s, v14.4s\n" - "ldr d1, [%[lhs_ptr], #16]\n" - "add v17.4s, v17.4s, v14.4s\n" - "ldr d2, [%[lhs_ptr], #32]\n" - "add v18.4s, v18.4s, v14.4s\n" - "ldr d3, [%[lhs_ptr], #48]\n" - "add v19.4s, v19.4s, v14.4s\n" - "ldr d4, [%[rhs_ptr], #0]\n" - "ldr d5, [%[rhs_ptr], #16]\n" - "ldr d6, [%[rhs_ptr], #32]\n" - "ldr d7, [%[rhs_ptr], #48]\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ld1 {v14.4s}, [x3]\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "mls v17.4s, v10.4s, v14.s[1]\n" - "mls v18.4s, v10.4s, v14.s[2]\n" - "mls v19.4s, v10.4s, v14.s[3]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - // Load 4 lhs_sums values. - "ld1 {v11.4s}, [x2]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Compute lhs_sums * rhs_zero_point. - "mul v11.4s, v11.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - "sub v17.4s, v17.4s, v11.4s\n" - "sub v18.4s, v18.4s, v11.4s\n" - "sub v19.4s, v19.4s, v11.4s\n" - - // If the destination is int32, it means the user asks for the raw - // accumulators, no need for us to downquantize the value. - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. 
- - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ld1 {v14.4s}, [x1]\n" - - "smax v12.4s, v14.4s, v8.4s\n" - "ldr x1, [%[lhs_ptr], #8]\n" - - "sshl v16.4s, v16.4s, v12.4s\n" - "ldr x2, [%[lhs_ptr], #24]\n" - "sshl v17.4s, v17.4s, v12.4s\n" - "ldr x3, [%[lhs_ptr], #40]\n" - "sshl v18.4s, v18.4s, v12.4s\n" - "ldr x4, [%[lhs_ptr], #56]\n" - "sshl v19.4s, v19.4s, v12.4s\n" - - "smin v12.4s, v14.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - "ins v0.d[1], x1\n" - "ldr x1, [%[rhs_ptr], #8]\n" - "sqrdmulh v16.4s, v16.4s, v15.4s\n" - "ins v1.d[1], x2\n" - "ldr x2, [%[rhs_ptr], #24]\n" - "sqrdmulh v17.4s, v17.4s, v15.4s\n" - "ins v2.d[1], x3\n" - "ldr x3, [%[rhs_ptr], #40]\n" - "sqrdmulh v18.4s, v18.4s, v15.4s\n" - "ins v3.d[1], x4\n" - "ldr x4, [%[rhs_ptr], #56]\n" - "sqrdmulh v19.4s, v19.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "and v8.16b, v16.16b, v12.16b\n" - "and v9.16b, v17.16b, v12.16b\n" - "and v14.16b, v18.16b, v12.16b\n" - "and v15.16b, v19.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" - "sqadd v17.4s, v17.4s, v9.4s\n" - "sqadd v18.4s, v18.4s, v14.4s\n" - "sqadd v19.4s, v19.4s, v15.4s\n" -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. 
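Taken together with the destination-zero-point addition, clamping and saturating casts in the store paths that follow, the down-quantization just described corresponds roughly to the scalar sketch below. The helper names mirror the well-known gemmlowp-style fixed-point reference routines rather than quoting ruy's own code, and the sketch works entirely in 32 bits, whereas the asm adds the zero point and clamps on already-narrowed 16-bit/8-bit lanes (so extreme out-of-range values may saturate slightly differently).

#include <algorithm>
#include <cstdint>
#include <limits>

// Scalar counterpart of sqrdmulh: rounding doubling high multiply.
inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                      std::int32_t b) {
  const bool overflow =
      a == b && a == std::numeric_limits<std::int32_t>::min();
  const std::int64_t ab = static_cast<std::int64_t>(a) * b;
  const std::int64_t nudge = ab >= 0 ? (1LL << 30) : (1 - (1LL << 30));
  const std::int32_t high =
      static_cast<std::int32_t>((ab + nudge) / (1LL << 31));
  return overflow ? std::numeric_limits<std::int32_t>::max() : high;
}

// Rounding right shift, breaking ties away from zero as in Appendix B of
// https://arxiv.org/pdf/1712.05877.pdf -- i.e. the non-native-rounding path
// with the and/sshr/sqadd fixup above.
inline std::int32_t RoundingDivideByPOT(std::int32_t x, int exponent) {
  const std::int32_t mask = static_cast<std::int32_t>((1LL << exponent) - 1);
  const std::int32_t remainder = x & mask;
  const std::int32_t threshold = (mask >> 1) + ((x < 0) ? 1 : 0);
  return (x >> exponent) + ((remainder > threshold) ? 1 : 0);
}

// One corrected int32 accumulator down to a clamped uint8 destination value.
inline std::uint8_t RequantizeToUint8(std::int32_t acc,
                                      std::int32_t multiplier_fixedpoint,
                                      int multiplier_exponent,
                                      std::int32_t dst_zero_point,
                                      std::int32_t clamp_min,
                                      std::int32_t clamp_max) {
  if (multiplier_exponent > 0) {  // sshl by max(exponent, 0)
    acc *= (1 << multiplier_exponent);
  }
  acc = SaturatingRoundingDoublingHighMul(acc, multiplier_fixedpoint);  // sqrdmulh
  if (multiplier_exponent < 0) {  // srshl by min(exponent, 0)
    acc = RoundingDivideByPOT(acc, -multiplier_exponent);
  }
  acc += dst_zero_point;                                 // add of v13.h[4]
  acc = std::min(std::max(acc, clamp_min), clamp_max);   // umax/umin bounds
  return static_cast<std::uint8_t>(acc);                 // sqxtn + sqxtun
}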
- "srshl v16.4s, v16.4s, v12.4s\n" - "srshl v17.4s, v17.4s, v12.4s\n" - "srshl v18.4s, v18.4s, v12.4s\n" - "srshl v19.4s, v19.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - "ins v4.d[1], x1\n" - "sqxtn v16.4h, v16.4s\n" - "ins v5.d[1], x2\n" - "sqxtn2 v16.8h, v17.4s\n" - "ins v6.d[1], x3\n" - "sqxtn v17.4h, v18.4s\n" - "ins v7.d[1], x4\n" - RUY_MAKE_ZERO(v18) - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v19) - - // Add the destination zero point - "add %[lhs_ptr], %[lhs_ptr], #64\n" - "dup v14.8h, v13.h[4]\n" - RUY_MAKE_ZERO(v20) - "add %[rhs_ptr], %[rhs_ptr], #64\n" - "add v16.8h, v16.8h, v14.8h\n" - RUY_MAKE_ZERO(v21) - "add v17.8h, v17.8h, v14.8h\n" - RUY_MAKE_ZERO(v22) - - // Cast-and-saturate from int16 to uint8 - "sqxtun v16.8b, v16.8h\n" - RUY_MAKE_ZERO(v23) - "sqxtun2 v16.16b, v17.8h\n" - RUY_MAKE_ZERO(v24) - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - RUY_MAKE_ZERO(v25) - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - RUY_MAKE_ZERO(v26) - "dup v14.16b, w2\n" // clamp_min - RUY_MAKE_ZERO(v27) - "dup v15.16b, w3\n" // clamp_max - RUY_MAKE_ZERO(v28) - - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - RUY_MAKE_ZERO(v29) - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - RUY_MAKE_ZERO(v30) - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - RUY_MAKE_ZERO(v31) - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #4\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. 
- RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[4], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[5], [x3], #1\n" - "st1 {v16.b}[6], [x3], #1\n" - "st1 {v16.b}[7], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[8], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[9], [x3], #1\n" - "st1 {v16.b}[10], [x3], #1\n" - "st1 {v16.b}[11], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[12], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[13], [x3], #1\n" - "st1 {v16.b}[14], [x3], #1\n" - "st1 {v16.b}[15], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - "ins v4.d[1], x1\n" - "sqxtn v16.4h, v16.4s\n" - "ins v5.d[1], x2\n" - "sqxtn2 v16.8h, v17.4s\n" - "ins v6.d[1], x3\n" - "sqxtn v17.4h, v18.4s\n" - "ins v7.d[1], x4\n" - RUY_MAKE_ZERO(v18) - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v19) - - // Add the destination zero point - "add %[lhs_ptr], %[lhs_ptr], #64\n" - "dup v14.8h, v13.h[4]\n" - RUY_MAKE_ZERO(v20) - "add %[rhs_ptr], %[rhs_ptr], #64\n" - "add v16.8h, v16.8h, v14.8h\n" - RUY_MAKE_ZERO(v21) - "add v17.8h, v17.8h, v14.8h\n" - RUY_MAKE_ZERO(v22) - - // Cast-and-saturate from int16 to uint8 - "sqxtn v16.8b, v16.8h\n" - RUY_MAKE_ZERO(v23) - "sqxtn2 v16.16b, v17.8h\n" - RUY_MAKE_ZERO(v24) - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - RUY_MAKE_ZERO(v25) - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - RUY_MAKE_ZERO(v26) - "dup v14.16b, w2\n" // clamp_min - RUY_MAKE_ZERO(v27) - "dup v15.16b, w3\n" // clamp_max - RUY_MAKE_ZERO(v28) - - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - RUY_MAKE_ZERO(v29) - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - RUY_MAKE_ZERO(v30) - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - RUY_MAKE_ZERO(v31) - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "st1 {v16.16b}, [%[dst_tmp_buf]]\n" - // Slow loop copying from dst_tmp_buf to dst. 
- "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #4\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[0], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[1], [x3], #1\n" - "st1 {v16.b}[2], [x3], #1\n" - "st1 {v16.b}[3], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[4], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[5], [x3], #1\n" - "st1 {v16.b}[6], [x3], #1\n" - "st1 {v16.b}[7], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[8], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[9], [x3], #1\n" - "st1 {v16.b}[10], [x3], #1\n" - "st1 {v16.b}[11], [x3], #1\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.b}[12], [x3], #1\n" - "add x4, x4, x11\n" - "st1 {v16.b}[13], [x3], #1\n" - "st1 {v16.b}[14], [x3], #1\n" - "st1 {v16.b}[15], [x3], #1\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #4\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.4h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - "saddw v17.4s, v17.4s, v14.4h\n" - "saddw v18.4s, v18.4s, v14.4h\n" - "saddw v19.4s, v19.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - "ins v4.d[1], x1\n" - "sqxtn v16.4h, v16.4s\n" - "ins v5.d[1], x2\n" - "sqxtn2 v16.8h, v17.4s\n" - "ins v6.d[1], x3\n" - "sqxtn v17.4h, v18.4s\n" - "ins v7.d[1], x4\n" - RUY_MAKE_ZERO(v18) - "sqxtn2 v17.8h, v19.4s\n" - - // At this point, v18 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - RUY_MAKE_ZERO(v19) - - "add %[lhs_ptr], %[lhs_ptr], #64\n" - RUY_MAKE_ZERO(v20) - "add %[rhs_ptr], %[rhs_ptr], #64\n" - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - - // Load the clamp_min, clamp_max bounds - "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - RUY_MAKE_ZERO(v25) - "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - RUY_MAKE_ZERO(v26) - "dup v14.8h, w2\n" // clamp_min - RUY_MAKE_ZERO(v27) - "dup v15.8h, w3\n" // clamp_max - RUY_MAKE_ZERO(v28) - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - "smax v17.8h, v17.8h, v14.8h\n" - RUY_MAKE_ZERO(v29) - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - "smin v17.8h, v17.8h, v15.8h\n" - RUY_MAKE_ZERO(v30) - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - RUY_MAKE_ZERO(v31) - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. 
- "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - "str q17, [%[dst_tmp_buf], #16]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.h}[0], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v16.h}[1], [x3], #2\n" - "st1 {v16.h}[2], [x3], #2\n" - "st1 {v16.h}[3], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.h}[4], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v16.h}[5], [x3], #2\n" - "st1 {v16.h}[6], [x3], #2\n" - "st1 {v16.h}[7], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.h}[0], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v17.h}[1], [x3], #2\n" - "st1 {v17.h}[2], [x3], #2\n" - "st1 {v17.h}[3], [x3], #2\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.h}[4], [x3], #2\n" - "add x4, x4, x11\n" - "st1 {v17.h}[5], [x3], #2\n" - "st1 {v17.h}[6], [x3], #2\n" - "st1 {v17.h}[7], [x3], #2\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #8\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - "ldr x1, [%[lhs_ptr], #8]\n" - "ldr x2, [%[lhs_ptr], #24]\n" - "ldr x3, [%[lhs_ptr], #40]\n" - "ldr x4, [%[lhs_ptr], #56]\n" - - "ins v0.d[1], x1\n" - "ldr x1, [%[rhs_ptr], #8]\n" - "ins v1.d[1], x2\n" - "ldr x2, [%[rhs_ptr], #24]\n" - "ins v2.d[1], x3\n" - "ldr x3, [%[rhs_ptr], #40]\n" - "ins v3.d[1], x4\n" - "ldr x4, [%[rhs_ptr], #56]\n" - "ins v4.d[1], x1\n" - "ins v5.d[1], x2\n" - "ins v6.d[1], x3\n" - "ins v7.d[1], x4\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // At this point, v20 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). - - RUY_MAKE_ZERO(v20) - "add %[lhs_ptr], %[lhs_ptr], #64\n" - RUY_MAKE_ZERO(v21) - "add %[rhs_ptr], %[rhs_ptr], #64\n" - RUY_MAKE_ZERO(v22) - - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - - // Compute how much of the 4x4 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 4x4, there are some 4x4 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - RUY_MAKE_ZERO(v31) - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #4\n" - "cmp w1, #4\n" - // Compute w1 = how many rows of the 4x4 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #4\n" - // Compute w2 = how many cols of the 4x4 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. 
- "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - "mov x4, %[dst_ptr]\n" - // Yes, all of the 4x4 block fits, go to fast path. - "beq 30f\n" - // Not all of the 4x4 block fits. - // Store to dst_tmp_buf - "str q16, [%[dst_tmp_buf], #0]\n" - "str q17, [%[dst_tmp_buf], #16]\n" - "str q18, [%[dst_tmp_buf], #32]\n" - "str q19, [%[dst_tmp_buf], #48]\n" - // Slow loop copying from dst_tmp_buf to dst. - "mov x3, %[dst_tmp_buf]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #16\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "b 31f\n" - "30:\n" - // Yes, all of the 4x4 block fits. - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v16.s}[1], [x3], #4\n" - "st1 {v16.s}[2], [x3], #4\n" - "st1 {v16.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v17.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v17.s}[1], [x3], #4\n" - "st1 {v17.s}[2], [x3], #4\n" - "st1 {v17.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v18.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v18.s}[1], [x3], #4\n" - "st1 {v18.s}[2], [x3], #4\n" - "st1 {v18.s}[3], [x3], #4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v19.s}[0], [x3], #4\n" - "add x4, x4, x11\n" - "st1 {v19.s}[1], [x3], #4\n" - "st1 {v19.s}[2], [x3], #4\n" - "st1 {v19.s}[3], [x3], #4\n" - "31:\n" - - "add %[dst_ptr], %[dst_ptr], #16\n" - - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - "smull v8.8h, v0.8b, v4.8b\n" - "smull v9.8h, v1.8b, v4.8b\n" - "smull v10.8h, v2.8b, v4.8b\n" - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "smull v11.8h, v3.8b, v4.8b\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "smull v12.8h, v0.8b, v5.8b\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "smull v13.8h, v1.8b, v5.8b\n" - "smull v14.8h, v2.8b, v5.8b\n" - "smull v15.8h, v3.8b, v5.8b\n" - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "smlal2 v8.8h, v0.16b, v4.16b\n" - "smlal2 v9.8h, v1.16b, v4.16b\n" - "smlal2 v10.8h, v2.16b, v4.16b\n" - "smlal2 v11.8h, v3.16b, v4.16b\n" - "smlal2 v12.8h, v0.16b, v5.16b\n" - "smlal2 v13.8h, v1.16b, v5.16b\n" - "smlal2 v14.8h, v2.16b, v5.16b\n" - "smlal2 v15.8h, v3.16b, v5.16b\n" - - - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #4\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #4\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. 
Corresponding to the initial ld1 instructions - // above, this is currently 16. - "mov w1, #16\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} - -// Kernel taking advantage of the optional dotprod instruction. -// This is very similar to (and directly inspired by) this gemmlowp kernel -// which was contributed by David Mansell at ARM: -// NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3391 -// -// Besides the ruy-ification, the main difference here is that we use a 8x8 -// instead of 12x8 width, so as to stick to power-of-two widths. This slightly -// narrower kernel layout is still wide enough to achieve high performance -// although we haven't actually performed a real comparison to know exactly -// how this compares to ARM's aforementioned kernel. -// -// Relevant target CPUs for this kernel include ARM Cortex-A76, -// since these are 64-bit, out-of-order and with dotprod support. -void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label( - "Kernel (kNeonDotprod, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are int32 accumulators. - // During accumulation, v0 -- v15 are used to load int8 data from LHS and - // RHS. At least v0 and v1 are used to load a 8x4 block of LHS, and v2 and - // v3 are used to load a 4x8 block of RHS, like this: - // - // int8 RHS 4x8 block - // /-----------------------------------------\ - // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| - // | ... ... | - // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| - // \-----------------------------------------/ - // int8 LHS 8x4 block - // /---------------------\ /-----------------------------------------\ - // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| - // | ... ... | | ... ... | - // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]| - // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| - // | ... ... | | ... ... | - // |v1.b[12] ... v1.b[15]| |v17.s[3] ...
v31.s[3]| - // \---------------------/ \-----------------------------------------/ - // int32 accumulators 8x8 block - // - // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step - // is repeated 4 times, using 4x more registers for LHS and RHS, so that - // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. - // - // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are - // unused, and v8 -- v15 are used for loading parameters used for the - // post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. - "mov w1, #4\n" - - // Perform the first few multiply-adds on the data that we have already - // loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Optional, maximally-streaming, partial-unrolling (4x unrolled) - // optimization of the kernel inner loop (over depth). For more - // comments, see the non-unrolled loop below after the #endif. 
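Each sdot instruction in the loops below accumulates, into one 4-lane int32 accumulator register, four 4-deep dot products taken from a 16-byte LHS register and one 4-byte lane of an RHS register, so a full group of 16 sdot instructions advances the whole 8x8 block by 4 levels of depth. The ".word" lines are pre-encoded sdot opcodes (presumably so the file still assembles with toolchains that do not understand the dotprod extension). As a scalar reference, one such elementary step amounts to the following sketch (hypothetical helper, not part of the kernel):

#include <cstdint>

// Reference for one elementary dotprod step: an 8x4 (rows x depth) int8 LHS
// block times a 4x8 (depth x cols) int8 RHS block, accumulated into the 8x8
// int32 block held in v16 -- v31.
void ElementaryDotprodStep(const std::int8_t lhs[8][4],
                           const std::int8_t rhs[4][8],
                           std::int32_t acc[8][8]) {
  for (int r = 0; r < 8; ++r) {
    for (int c = 0; c < 8; ++c) {
      std::int32_t sum = 0;
      for (int d = 0; d < 4; ++d) {
        sum += static_cast<std::int32_t>(lhs[r][d]) *
               static_cast<std::int32_t>(rhs[d][c]);
      }
      acc[r][c] += sum;  // one lane of one sdot instruction
    }
  }
}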
-#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) - "cmp w12, #32\n" - "blt 78f\n" - - "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v8.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v9.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v12.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v13.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v14.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v15.16b}, [%[rhs_ptr]], #16\n" - "mov w1, #16\n" - - "and w3, w12, #-16\n" - "81:\n" - "add w1, w1, #16\n" - - ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" - ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" - ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" - ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" - "ldr q0, [%[lhs_ptr], #0]\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" - ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" - ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" - "ldr q2, [%[rhs_ptr], #0]\n" - ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" - ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" - ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" - ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" - "ldr q1, [%[lhs_ptr], #16]\n" - - ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n" - ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n" - "ldr q3, [%[rhs_ptr], #16]\n" - ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n" - ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n" - ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n" - ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n" - ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n" - ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n" - ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n" - ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n" - ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n" - ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n" - "ldr q5, [%[lhs_ptr], #48]\n" - ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n" - ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n" - "ldr q7, [%[rhs_ptr], #48]\n" - ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n" - ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n" - "ldr q4, [%[lhs_ptr], #32]\n" - - ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n" - ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n" - "ldr q6, [%[rhs_ptr], #32]\n" - ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n" - ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n" - ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n" - ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n" - ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n" - ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n" - ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n" - ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n" - ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n" - ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n" - "ldr q9, [%[lhs_ptr], #80]\n" - ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n" - ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n" - "ldr q11, [%[rhs_ptr], #80]\n" - ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n" - ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n" - "ldr q8, [%[lhs_ptr], #64]\n" - - ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n" - ".word 0x4fafe19a // sdot v26.4s, 
v12.16b, v15.4b[1]\n" - "ldr q10, [%[rhs_ptr], #64]\n" - ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n" - ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n" - "add %[lhs_ptr], %[lhs_ptr], #128\n" - ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n" - ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n" - "add %[rhs_ptr], %[rhs_ptr], #128\n" - ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n" - ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n" - ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n" - ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n" - "cmp w1, w3\n" - ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n" - ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n" - "ldr q13, [%[lhs_ptr], #-16]\n" - ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n" - ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n" - "ldr q15, [%[rhs_ptr], #-16]\n" - ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n" - ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n" - "ldr q12, [%[lhs_ptr], #-32]\n" - - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - "ldr q14, [%[rhs_ptr], #-32]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - "blt 81b\n" - - ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n" - ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n" - ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n" - ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n" - ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n" - ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n" - ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n" - ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n" - ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n" - ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n" - ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n" - ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n" - ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n" - ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n" - ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n" - ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n" - - ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n" - ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n" - ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n" - ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n" - ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n" - ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n" - ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n" - ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n" - ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n" - ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n" - ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n" - ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n" - ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n" - ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n" - ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n" - ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n" - - ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n" - ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n" - ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n" - ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n" - ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n" - ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n" - ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, 
v14.4b[2]\n" - ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n" - ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n" - ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n" - ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n" - ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n" - ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n" - ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n" - ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n" - ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n" - - "78:\n" - -#endif // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) - - // Ordinary kernel inner loop (over depth), the simpler loop that the - // above was an equivalent 4x-partially-unrolled version of. - - // Reminder - w1 is how many levels of depth we have already loaded - // data for, w12 is the total depth. - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - - // Because of the data that we have already loaded, we can start the - // loop body right away with some multiply-adds. - ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" - ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" - // Each iteration of this loop advances by 4 levels of depth. - "add w1, w1, #4\n" - ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" - ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" - // Loop termination condition. - "cmp w1, w12\n" - ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" - ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" - "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" - ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" - ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" - ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" - ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" - "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - - "blt 2b\n" - - "79:\n" - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last 4 levels of depth, for which the LHS - // and RHS data is already loaded. - - ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" - ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" - ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" - ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" - ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" - ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" - ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" - ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" - ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" - ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. 
We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - "add x5, x1, %x[row], lsl #2\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.4s}, [x1], #16\n" - "ld1 {v15.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - "add v15.4s, v15.4s, v9.4s\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
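The bias-addition above, together with the optional rhs_sums/lhs_sums corrections that follow, implements the zero-point decomposition of equation (7) in the paper linked above. In scalar form, for one accumulator at (row, col), this is roughly as follows (illustrative names; prod_zp_depth stands for depth * lhs_zero_point * rhs_zero_point, precomputed outside the kernel):

#include <cstdint>

// Scalar sketch of the zero-point corrections performed above and in the
// RHS_SUMS / LHS_SUMS blocks that follow.
inline std::int32_t ApplyZeroPointCorrections(
    std::int32_t raw_acc, std::int32_t bias_for_this_row,
    std::int32_t prod_zp_depth, std::int32_t lhs_zero_point,
    std::int32_t rhs_zero_point, std::int32_t lhs_sum_for_this_row,
    std::int32_t rhs_sum_for_this_col) {
  std::int32_t acc = raw_acc;
  acc += bias_for_this_row + prod_zp_depth;      // the add v16..v31, v14/v15 block
  acc -= lhs_zero_point * rhs_sum_for_this_col;  // the mls block (RUY_ASM_FLAG_HAS_RHS_SUMS)
  acc -= rhs_zero_point * lhs_sum_for_this_row;  // the mul + sub block (RUY_ASM_FLAG_HAS_LHS_SUMS)
  return acc;
}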
- "add v16.4s, v16.4s, v14.4s\n" - "add v17.4s, v17.4s, v15.4s\n" - "add v18.4s, v18.4s, v14.4s\n" - "add v19.4s, v19.4s, v15.4s\n" - "add v20.4s, v20.4s, v14.4s\n" - "add v21.4s, v21.4s, v15.4s\n" - "add v22.4s, v22.4s, v14.4s\n" - "add v23.4s, v23.4s, v15.4s\n" - "add v24.4s, v24.4s, v14.4s\n" - "add v25.4s, v25.4s, v15.4s\n" - "add v26.4s, v26.4s, v14.4s\n" - "add v27.4s, v27.4s, v15.4s\n" - "add v28.4s, v28.4s, v14.4s\n" - "add v29.4s, v29.4s, v15.4s\n" - "add v30.4s, v30.4s, v14.4s\n" - "add v31.4s, v31.4s, v15.4s\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ld1 {v14.4s}, [x3], #16\n" - "ld1 {v15.4s}, [x3]\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "mls v17.4s, v10.4s, v14.s[0]\n" - "mls v18.4s, v10.4s, v14.s[1]\n" - "mls v19.4s, v10.4s, v14.s[1]\n" - "mls v20.4s, v10.4s, v14.s[2]\n" - "mls v21.4s, v10.4s, v14.s[2]\n" - "mls v22.4s, v10.4s, v14.s[3]\n" - "mls v23.4s, v10.4s, v14.s[3]\n" - "mls v24.4s, v10.4s, v15.s[0]\n" - "mls v25.4s, v10.4s, v15.s[0]\n" - "mls v26.4s, v10.4s, v15.s[1]\n" - "mls v27.4s, v10.4s, v15.s[1]\n" - "mls v28.4s, v10.4s, v15.s[2]\n" - "mls v29.4s, v10.4s, v15.s[2]\n" - "mls v30.4s, v10.4s, v15.s[3]\n" - "mls v31.4s, v10.4s, v15.s[3]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - // Load 4 lhs_sums values. - "ld1 {v11.4s}, [x2], #16\n" - "ld1 {v12.4s}, [x2]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Compute lhs_sums * rhs_zero_point. - "mul v11.4s, v11.4s, v13.s[1]\n" - "mul v12.4s, v12.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - "sub v17.4s, v17.4s, v12.4s\n" - "sub v18.4s, v18.4s, v11.4s\n" - "sub v19.4s, v19.4s, v12.4s\n" - "sub v20.4s, v20.4s, v11.4s\n" - "sub v21.4s, v21.4s, v12.4s\n" - "sub v22.4s, v22.4s, v11.4s\n" - "sub v23.4s, v23.4s, v12.4s\n" - "sub v24.4s, v24.4s, v11.4s\n" - "sub v25.4s, v25.4s, v12.4s\n" - "sub v26.4s, v26.4s, v11.4s\n" - "sub v27.4s, v27.4s, v12.4s\n" - "sub v28.4s, v28.4s, v11.4s\n" - "sub v29.4s, v29.4s, v12.4s\n" - "sub v30.4s, v30.4s, v11.4s\n" - "sub v31.4s, v31.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. 
- "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ldr q9, [x1]\n" - "ldr q10, [x1, #16]\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n" - "beq 403f\n" - "smax v11.4s, v9.4s, v8.4s\n" - "smax v12.4s, v10.4s, v8.4s\n" - "sshl v16.4s, v16.4s, v11.4s\n" - "sshl v17.4s, v17.4s, v12.4s\n" - "sshl v18.4s, v18.4s, v11.4s\n" - "sshl v19.4s, v19.4s, v12.4s\n" - "sshl v20.4s, v20.4s, v11.4s\n" - "sshl v21.4s, v21.4s, v12.4s\n" - "sshl v22.4s, v22.4s, v11.4s\n" - "sshl v23.4s, v23.4s, v12.4s\n" - "sshl v24.4s, v24.4s, v11.4s\n" - "sshl v25.4s, v25.4s, v12.4s\n" - "sshl v26.4s, v26.4s, v11.4s\n" - "sshl v27.4s, v27.4s, v12.4s\n" - "sshl v28.4s, v28.4s, v11.4s\n" - "sshl v29.4s, v29.4s, v12.4s\n" - "sshl v30.4s, v30.4s, v11.4s\n" - "sshl v31.4s, v31.4s, v12.4s\n" - "403:\n" - - "ldr q14, [x4]\n" // multiplier_fixedpoint - "ldr q15, [x4, #16]\n" // multiplier_fixedpoint - - "smin v11.4s, v9.4s, v8.4s\n" - "smin v12.4s, v10.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - "sqrdmulh v16.4s, v16.4s, v14.4s\n" - "sqrdmulh v17.4s, v17.4s, v15.4s\n" - "sqrdmulh v18.4s, v18.4s, v14.4s\n" - "sqrdmulh v19.4s, v19.4s, v15.4s\n" - "sqrdmulh v20.4s, v20.4s, v14.4s\n" - "sqrdmulh v21.4s, v21.4s, v15.4s\n" - "sqrdmulh v22.4s, v22.4s, v14.4s\n" - "sqrdmulh v23.4s, v23.4s, v15.4s\n" - "sqrdmulh v24.4s, v24.4s, v14.4s\n" - "sqrdmulh v25.4s, v25.4s, v15.4s\n" - "sqrdmulh v26.4s, v26.4s, v14.4s\n" - "sqrdmulh v27.4s, v27.4s, v15.4s\n" - "sqrdmulh v28.4s, v28.4s, v14.4s\n" - "sqrdmulh v29.4s, v29.4s, v15.4s\n" - "sqrdmulh v30.4s, v30.4s, v14.4s\n" - "sqrdmulh v31.4s, v31.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. 
- "and v8.16b, v16.16b, v11.16b\n" - "and v9.16b, v17.16b, v12.16b\n" - "and v14.16b, v18.16b, v11.16b\n" - "and v15.16b, v19.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" - "sqadd v17.4s, v17.4s, v9.4s\n" - "sqadd v18.4s, v18.4s, v14.4s\n" - "sqadd v19.4s, v19.4s, v15.4s\n" - "and v8.16b, v20.16b, v11.16b\n" - "and v9.16b, v21.16b, v12.16b\n" - "and v14.16b, v22.16b, v11.16b\n" - "and v15.16b, v23.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v20.4s, v20.4s, v8.4s\n" - "sqadd v21.4s, v21.4s, v9.4s\n" - "sqadd v22.4s, v22.4s, v14.4s\n" - "sqadd v23.4s, v23.4s, v15.4s\n" - "and v8.16b, v24.16b, v11.16b\n" - "and v9.16b, v25.16b, v12.16b\n" - "and v14.16b, v26.16b, v11.16b\n" - "and v15.16b, v27.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v24.4s, v24.4s, v8.4s\n" - "sqadd v25.4s, v25.4s, v9.4s\n" - "sqadd v26.4s, v26.4s, v14.4s\n" - "sqadd v27.4s, v27.4s, v15.4s\n" - "and v8.16b, v28.16b, v11.16b\n" - "and v9.16b, v29.16b, v12.16b\n" - "and v14.16b, v30.16b, v11.16b\n" - "and v15.16b, v31.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v28.4s, v28.4s, v8.4s\n" - "sqadd v29.4s, v29.4s, v9.4s\n" - "sqadd v30.4s, v30.4s, v14.4s\n" - "sqadd v31.4s, v31.4s, v15.4s\n" -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "srshl v16.4s, v16.4s, v11.4s\n" - "srshl v17.4s, v17.4s, v12.4s\n" - "srshl v18.4s, v18.4s, v11.4s\n" - "srshl v19.4s, v19.4s, v12.4s\n" - "srshl v20.4s, v20.4s, v11.4s\n" - "srshl v21.4s, v21.4s, v12.4s\n" - "srshl v22.4s, v22.4s, v11.4s\n" - "srshl v23.4s, v23.4s, v12.4s\n" - "srshl v24.4s, v24.4s, v11.4s\n" - "srshl v25.4s, v25.4s, v12.4s\n" - "srshl v26.4s, v26.4s, v11.4s\n" - "srshl v27.4s, v27.4s, v12.4s\n" - "srshl v28.4s, v28.4s, v11.4s\n" - "srshl v29.4s, v29.4s, v12.4s\n" - "srshl v30.4s, v30.4s, v11.4s\n" - "srshl v31.4s, v31.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). 
- RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8 - "sqxtun v16.8b, v16.8h\n" - "sqxtun2 v16.16b, v17.8h\n" - "sqxtun v17.8b, v18.8h\n" - "sqxtun2 v17.16b, v19.8h\n" - "sqxtun v18.8b, v20.8h\n" - "sqxtun2 v18.16b, v21.8h\n" - "sqxtun v19.8b, v22.8h\n" - "sqxtun2 v19.16b, v23.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - "umax v17.16b, v17.16b, v14.16b\n" - "umax v18.16b, v18.16b, v14.16b\n" - "umax v19.16b, v19.16b, v14.16b\n" - - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - "umin v17.16b, v17.16b, v15.16b\n" - "umin v18.16b, v18.16b, v15.16b\n" - "umin v19.16b, v19.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - "dup d21, v17.d[1]\n" - "dup d22, v18.d[1]\n" - "dup d23, v19.d[1]\n" - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). 
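Rather than duplicating the eight st1 stores for the edge case, the kernel first picks a destination (address, stride) pair: the real destination with its stride on the fast path, or dst_tmp_buf with a fixed stride of 8 on the edge path, followed by a scalar copy of only the part that fits. A sketch of that structure, with illustrative names:

#include <cstdint>
#include <cstring>

// Sketch of the (address, stride) indirection described above.
inline void Store8x8Block(const std::uint8_t block[8][8] /* [col][row] */,
                          std::uint8_t* dst, int dst_stride, int rows_fit,
                          int cols_fit, std::uint8_t* dst_tmp_buf) {
  const bool full_block = rows_fit == 8 && cols_fit == 8;
  std::uint8_t* out = full_block ? dst : dst_tmp_buf;  // mov x3, ...
  const int out_stride = full_block ? dst_stride : 8;  // mov x4, ...
  for (int c = 0; c < 8; ++c) {
    std::memcpy(out + c * out_stride, block[c], 8);    // st1 {vN.8b}, [x3], x4
  }
  if (!full_block) {
    // Copy only the rows_fit x cols_fit corner that lands inside dst.
    for (int c = 0; c < cols_fit; ++c) {
      for (int r = 0; r < rows_fit; ++r) {
        dst[r + c * dst_stride] = dst_tmp_buf[r + 8 * c];
      }
    }
  }
}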
- RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). 
- RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8 - "sqxtn v16.8b, v16.8h\n" - "sqxtn2 v16.16b, v17.8h\n" - "sqxtn v17.8b, v18.8h\n" - "sqxtn2 v17.16b, v19.8h\n" - "sqxtn v18.8b, v20.8h\n" - "sqxtn2 v18.16b, v21.8h\n" - "sqxtn v19.8b, v22.8h\n" - "sqxtn2 v19.16b, v23.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - "smax v17.16b, v17.16b, v14.16b\n" - "smax v18.16b, v18.16b, v14.16b\n" - "smax v19.16b, v19.16b, v14.16b\n" - - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - "smin v17.16b, v17.16b, v15.16b\n" - "smin v18.16b, v18.16b, v15.16b\n" - "smin v19.16b, v19.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - "dup d21, v17.d[1]\n" - "dup d22, v18.d[1]\n" - "dup d23, v19.d[1]\n" - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 130f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 131f\n" - "130:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "131:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). 
- RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 141f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "150:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "151:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 151b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 150b\n" - "141:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - "saddw v17.4s, v17.4s, v14.4h\n" - "saddw v18.4s, v18.4s, v14.4h\n" - "saddw v19.4s, v19.4s, v14.4h\n" - "saddw v20.4s, v20.4s, v14.4h\n" - "saddw v21.4s, v21.4s, v14.4h\n" - "saddw v22.4s, v22.4s, v14.4h\n" - "saddw v23.4s, v23.4s, v14.4h\n" - "saddw v24.4s, v24.4s, v14.4h\n" - "saddw v25.4s, v25.4s, v14.4h\n" - "saddw v26.4s, v26.4s, v14.4h\n" - "saddw v27.4s, v27.4s, v14.4h\n" - "saddw v28.4s, v28.4s, v14.4h\n" - "saddw v29.4s, v29.4s, v14.4h\n" - "saddw v30.4s, v30.4s, v14.4h\n" - "saddw v31.4s, v31.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). 
- RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Load the clamp_min, clamp_max bounds - "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.8h, w2\n" // clamp_min - "dup v15.8h, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - "smax v17.8h, v17.8h, v14.8h\n" - "smax v18.8h, v18.8h, v14.8h\n" - "smax v19.8h, v19.8h, v14.8h\n" - "smax v20.8h, v20.8h, v14.8h\n" - "smax v21.8h, v21.8h, v14.8h\n" - "smax v22.8h, v22.8h, v14.8h\n" - "smax v23.8h, v23.8h, v14.8h\n" - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - "smin v17.8h, v17.8h, v15.8h\n" - "smin v18.8h, v18.8h, v15.8h\n" - "smin v19.8h, v19.8h, v15.8h\n" - "smin v20.8h, v20.8h, v15.8h\n" - "smin v21.8h, v21.8h, v15.8h\n" - "smin v22.8h, v22.8h, v15.8h\n" - "smin v23.8h, v23.8h, v15.8h\n" - - // Compute how much of the 8x8 block of destination 16bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 230f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #16\n" - "b 231f\n" - "230:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "231:\n" - - // Write our 16bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 241f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. 
Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "250:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "251:\n" - "ldrsh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 251b\n" - "add w6, w6, #1\n" - "add x3, x3, #16\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 250b\n" - "241:\n" - "add %[dst_ptr], %[dst_ptr], #16\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // Compute how much of the 8x8 block of destination 32it values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 330f\n" - // Not all of the 8x8 block fits. - // Write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "st1 {v16.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v16) - "st1 {v17.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v17) - "st1 {v18.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v18) - "st1 {v19.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v19) - "st1 {v20.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v20) - "st1 {v21.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v21) - "st1 {v22.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v22) - "st1 {v23.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v23) - "st1 {v24.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v24) - "st1 {v25.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v25) - "st1 {v26.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v26) - "st1 {v27.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v27) - "st1 {v28.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v28) - "st1 {v29.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v29) - "st1 {v30.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v30) - "st1 {v31.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v31) - - "b 331f\n" - - "330:\n" - // Yes, all of the 8x8 block fits. 
- "mov x4, %[dst_ptr]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v16.4s, v17.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v18.4s, v19.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v20.4s, v21.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v22.4s, v23.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v24.4s, v25.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v26.4s, v27.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v28.4s, v29.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - "add x4, x4, x11\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov x3, x4\n" - "st1 {v30.4s, v31.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - "331:\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 341f\n" - - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "350:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "351:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 351b\n" - "add w6, w6, #1\n" - "add x3, x3, #32\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 350b\n" - "341:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? 
- "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. - "mov w1, #4\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} - -// Similar to the above 8-bit dotprod kernel, but specialized for the case of -// RHS cols == 1. -// Relevant target CPUs for this kernel include ARM Cortex-A76, -// since these are 64-bit, out-of-order and with dotprod support. -void Kernel8bitNeonDotprodOutOfOrder1Col(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label( - "Kernel (kNeonDotprod, optimized for out-of-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are int32 accumulators. - // During accumulation, v0 -- v15 are used to load int8 data from LHS and - // RHS. At least v0 and v1 are used to load a 8x4 block of LHS, and v2 and - // v3 are used to load a 4x8 block of RHS, like this: - // - // int8 RHS 4x1 block - // /-------\ - // |v2.b[0]| - // | ... | - // |v2.b[3]| - // \-------/ - // int8 LHS 8x4 block - // /---------------------\ /--------\ - // |v0.b[0] ... v0.b[3] | |v16.s[0]| - // | ... ... | | ... | - // |v0.b[12] ... v0.b[15]| |v16.s[3]| - // |v1.b[0] ... v1.b[3] | |v17.s[0]| - // | ... ... | | ... | - // |v1.b[12] ... v1.b[15]| |v17.s[3]| - // \---------------------/ \--------/ - // int32 accumulators 8x1 block - // - // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step - // is repeated 4 times, using 4x more registers for LHS and RHS, so that - // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. - // - // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are - // unused, and v8 -- v15 are used for loading parameters used for the - // post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. 
- "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.8b}, [%[rhs_ptr]]\n" - "add %[rhs_ptr], %[rhs_ptr], #32\n" - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. - "mov w1, #4\n" - - // Perform the first few multiply-adds on the data that we have already - // loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - // Ordinary kernel inner loop (over depth), the simpler loop that the - // above was an equivalent 4x-partially-unrolled version of. - - // Reminder - w1 is how many levels of depth we have already loaded - // data for, w12 is the total depth. - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - - // Because of the data that we have already loaded, we can start the - // loop body right away with some multiply-adds. - // Each iteration of this loop advances by 4 levels of depth. - "add w1, w1, #4\n" - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - // Loop termination condition. - "cmp w1, w12\n" - "ld1 {v2.8b}, [%[rhs_ptr]]\n" - "add %[rhs_ptr], %[rhs_ptr], #32\n" - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - - "blt 2b\n" - - "79:\n" - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last 4 levels of depth, for which the LHS - // and RHS data is already loaded. - - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. 
If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - "add x5, x1, %x[row], lsl #2\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.4s}, [x1], #16\n" - "ld1 {v15.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.8b}, [%[rhs_ptr]]\n" - "add %[rhs_ptr], %[rhs_ptr], #32\n" - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - "add v15.4s, v15.4s, v9.4s\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "add v16.4s, v16.4s, v14.4s\n" - "add v17.4s, v17.4s, v15.4s\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ld1 {v14.4s}, [x3], #16\n" - "ld1 {v15.4s}, [x3]\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "mls v17.4s, v10.4s, v14.s[0]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - // Load 4 lhs_sums values. - "ld1 {v11.4s}, [x2], #16\n" - "ld1 {v12.4s}, [x2]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Compute lhs_sums * rhs_zero_point. 
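// A scalar C++ sketch of the zero-point corrections applied here and just
// below, following equation (7) of https://arxiv.org/pdf/1712.05877.pdf:
//   sum_k (lhs_k - Z1)(rhs_k - Z2)
//     = sum_k lhs_k*rhs_k - Z1*sum_k rhs_k - Z2*sum_k lhs_k + depth*Z1*Z2.
// Names are illustrative; the kernel folds the depth*Z1*Z2 term into the
// bias and applies the two sums terms (when the corresponding flags are set)
// with the mls/sub instructions.
#include <cstdint>
inline std::int32_t CorrectedAccumSketch(std::int32_t raw_accum,
                                         std::int32_t bias,
                                         std::int32_t lhs_sum,
                                         std::int32_t rhs_sum,
                                         std::int32_t lhs_zero_point,
                                         std::int32_t rhs_zero_point,
                                         std::int32_t depth) {
  std::int32_t acc = raw_accum + bias;
  acc += depth * lhs_zero_point * rhs_zero_point;  // the NZ1Z2 term
  acc -= lhs_zero_point * rhs_sum;                 // rhs_sums * lhs_zero_point
  acc -= rhs_zero_point * lhs_sum;                 // lhs_sums * rhs_zero_point
  return acc;
}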
- "mul v11.4s, v11.4s, v13.s[1]\n" - "mul v12.4s, v12.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - "sub v17.4s, v17.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ldr q9, [x1]\n" - "ldr q10, [x1, #16]\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n" - "beq 403f\n" - "smax v11.4s, v9.4s, v8.4s\n" - "smax v12.4s, v10.4s, v8.4s\n" - "sshl v16.4s, v16.4s, v11.4s\n" - "sshl v17.4s, v17.4s, v12.4s\n" - "403:\n" - - "ldr q14, [x4]\n" // multiplier_fixedpoint - "ldr q15, [x4, #16]\n" // multiplier_fixedpoint - - "smin v11.4s, v9.4s, v8.4s\n" - "smin v12.4s, v10.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - "sqrdmulh v16.4s, v16.4s, v14.4s\n" - "sqrdmulh v17.4s, v17.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. - "and v8.16b, v16.16b, v11.16b\n" - "and v9.16b, v17.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" - "sqadd v17.4s, v17.4s, v9.4s\n" - -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. 
- "srshl v16.4s, v16.4s, v11.4s\n" - "srshl v17.4s, v17.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - // All data in v16 at this point. - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8, leaving all data in the - // lower half of v16. - "sqxtun v16.8b, v16.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - - // Compute how much of the 8x1 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x1, there are some 8x1 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x1 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - - // Test if w1==8, i.e. if all of the 8x1 block fits. - "cmp w1, w3\n" - // Yes, all of the 8x1 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v16.8b}, [x3]\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. 
- - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "add v16.8h, v16.8h, v14.8h\n" - - // Cast-and-saturate from int16 to uint8 - "sqxtn v16.8b, v16.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - - // Compute how much of the 8x1 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x1 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - - // Test if w1==8, i.e. if all of the 8x1 block fits. - "cmp w1, w3\n" - // Yes, all of the 8x1 block fits, go to fast path. - "beq 130f\n" - // Not all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 131f\n" - "130:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "131:\n" - - // Write our 8bit values to the destination - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v16.8b}, [x3]\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 141f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "150:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "151:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 151b\n" - "141:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. 
- - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - "saddw v17.4s, v17.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - - // Load the clamp_min, clamp_max bounds - "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.8h, w2\n" // clamp_min - "dup v15.8h, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - - // Compute how much of the 8x1 block of destination 16bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x1 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x1 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - - // Test if w1==8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - // Yes, all of the 8x1 block fits, go to fast path. - "beq 230f\n" - // Not all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #16\n" - "b 231f\n" - "230:\n" - // Yes, all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "231:\n" - - // Write our 16bit values to the destination - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v16.8h}, [x3]\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - - // If all of the 8x1 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 241f\n" - // Not all of the 8x1 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "250:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "251:\n" - "ldrsh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 251b\n" - "241:\n" - "add %[dst_ptr], %[dst_ptr], #16\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // Compute how much of the 8x1 block of destination 32 bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x1, there are some 8x1 blocks along the boundaries that do - // not fit entirely. 
- "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x1 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - // Yes, all of the 8x1 block fits, go to fast path. - "beq 330f\n" - // Not all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #16\n" - - // Write our 32bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.4s}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.4s}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - - "b 331f\n" - - "330:\n" - // Yes, all of the 8x1 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x4, %[dst_ptr]\n" - "mov x3, x4\n" - - // Write our 32bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.4s, v17.4s}, [x3], #32\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - - "331:\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 341f\n" - - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "350:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "mov w5, #0\n" - "351:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 351b\n" - "341:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 4. 
- "mov w1, #4\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17"); -} - -// Variant of the above Kernel8bitNeonDotprodOutOfOrder, tuned for in-order -// CPUs. Specifically here, the relevant in-order CPUs are ARM Cortex-A55r1, -// since these are 64-bit and support dotprod. -// -// While this kernel does not have a direct equivalent in gemmlowp, it was -// developed based on insights that David Mansell at ARM shared with their -// contribution of gemmlowp kernels tuned for Cortex-A55r1, with very helpful -// comments. Specifically, see this comment about tuning for Cortex-A55r1: -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412 -void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label( - "Kernel (kNeonDotprod, optimized for in-order cores)"); - - CheckOffsetsInKernelParams8bit(params); - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - void* dst_col_ptr = params.dst_base_ptr; - void* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are int32 accumulators. - // During accumulation, v0 -- v3 are used to load int8 data from LHS and - // RHS. - // - // int8 RHS 4x8 block - // /-----------------------------------------\ - // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| - // | ... ... | - // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| - // \-----------------------------------------/ - // int8 LHS 8x4 block - // /---------------------\ /-----------------------------------------\ - // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| - // | ... ... | | ... ... | - // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]| - // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| - // | ... ... | | ... ... | - // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]| - // \---------------------/ \-----------------------------------------/ - // int32 accumulators 8x8 block - // - // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because - // we did not observe a benefit of such partial unrolling on in-order CPUs. - // - // v4 -- v7 are unused, and v8 -- v15 are used for loading parameters used for - // the post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. 
- "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - RUY_MAKE_ZERO(v16) - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - RUY_MAKE_ZERO(v17) - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - RUY_MAKE_ZERO(v18) - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - RUY_MAKE_ZERO(v19) - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - RUY_MAKE_ZERO(v20) - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - RUY_MAKE_ZERO(v21) - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - RUY_MAKE_ZERO(v22) - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" - "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" - "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" - - // Clear accumulators. - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - // Perform the first few multiply-adds on the data that we have already - // loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - RUY_MAKE_ZERO(v28) - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - RUY_MAKE_ZERO(v29) - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - RUY_MAKE_ZERO(v30) - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - RUY_MAKE_ZERO(v31) - - - "1:\n" - - "add x5, %[lhs_ptr], x12, lsl #3\n" - "sub x5, x5, #32\n" - "cmp %[lhs_ptr], x5\n" - - "beq 79f\n" - - // Main accumulation loop - "2:\n" - ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" - "ldr x1, [%[lhs_ptr], #8]\n" - ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" - "ldr x3, [%[rhs_ptr], #8]\n" - ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" - "ldr x4, [%[rhs_ptr], #24]\n" - ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" - "ldr d0, [%[lhs_ptr], #0]\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - "ins v0.d[1], x1\n" - ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" - "ldr x2, [%[lhs_ptr], #24]\n" - ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" - "add %[lhs_ptr], %[lhs_ptr], #32\n" - ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" - "ldr d2, [%[rhs_ptr], #0]\n" - ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" - "ins v2.d[1], x3\n" - ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" - "cmp %[lhs_ptr], x5\n" - ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" - "add %[rhs_ptr], %[rhs_ptr], #32\n" - ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" - "ldr d3, [%[rhs_ptr], #-16]\n" - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - "ldr d1, [%[lhs_ptr], #-16]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - "ins v3.d[1], x4\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - "ins v1.d[1], x2\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - "blt 2b\n" - - // Last accumulation steps, nothing left to load. - "79:\n" - ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" - "cmp %w[row], w7\n" // Have we finished the last row? 
- ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" - ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" - ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" - ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" - ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" - ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" - ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" - ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" - ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" - ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - // Load some parameters needed for the end work on current block. - RUY_MAKE_ZERO(v8) - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" - "ins v13.h[4], w4\n" // dst_zero_point - "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "dup v9.4s, w3\n" // create prod_zp_depth_vec - "add x5, x4, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "csel x4, x4, x5, eq\n" - - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - "add x5, x1, %x[row], lsl #2\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. 
- "ld1 {v14.2s}, [x1], #8\n" - "ldr x5, [x1], #8\n" - "ins v14.d[1], x5\n" - "ld1 {v15.2s}, [x1], #8\n" - "ldr x5, [x1], #8\n" - "ins v15.d[1], x5\n" - - // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), - // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "add v14.4s, v14.4s, v9.4s\n" - "add v15.4s, v15.4s, v9.4s\n" - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "add v16.4s, v16.4s, v14.4s\n" - "add v17.4s, v17.4s, v15.4s\n" - "add v18.4s, v18.4s, v14.4s\n" - "add v19.4s, v19.4s, v15.4s\n" - "add v20.4s, v20.4s, v14.4s\n" - "add v21.4s, v21.4s, v15.4s\n" - "add v22.4s, v22.4s, v14.4s\n" - "add v23.4s, v23.4s, v15.4s\n" - "add v24.4s, v24.4s, v14.4s\n" - "add v25.4s, v25.4s, v15.4s\n" - "add v26.4s, v26.4s, v14.4s\n" - "add v27.4s, v27.4s, v15.4s\n" - "add v28.4s, v28.4s, v14.4s\n" - "add v29.4s, v29.4s, v15.4s\n" - "add v30.4s, v30.4s, v14.4s\n" - "add v31.4s, v31.4s, v15.4s\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" - "beq 401f\n" - "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" - "add x3, x3, %x[col], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" - "dup v10.4s, w5\n" // create lhs_zero_point_vec - // Load 8 rhs_sums values. - "ld1 {v14.2s}, [x3], #8\n" - "ldr x7, [x3], #8\n" - "ld1 {v15.2s}, [x3], #8\n" - "ins v14.d[1], x7\n" - "ldr x7, [x3], #8\n" - "ins v15.d[1], x7\n" - // Subtract rhs_sums * lhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "mls v16.4s, v10.4s, v14.s[0]\n" - "mls v17.4s, v10.4s, v14.s[0]\n" - "mls v18.4s, v10.4s, v14.s[1]\n" - "mls v19.4s, v10.4s, v14.s[1]\n" - "mls v20.4s, v10.4s, v14.s[2]\n" - "mls v21.4s, v10.4s, v14.s[2]\n" - "mls v22.4s, v10.4s, v14.s[3]\n" - "mls v23.4s, v10.4s, v14.s[3]\n" - "mls v24.4s, v10.4s, v15.s[0]\n" - "mls v25.4s, v10.4s, v15.s[0]\n" - "mls v26.4s, v10.4s, v15.s[1]\n" - "mls v27.4s, v10.4s, v15.s[1]\n" - "mls v28.4s, v10.4s, v15.s[2]\n" - "mls v29.4s, v10.4s, v15.s[2]\n" - "mls v30.4s, v10.4s, v15.s[3]\n" - "mls v31.4s, v10.4s, v15.s[3]\n" - "401:\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" - "beq 402f\n" - "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" - "add x2, x2, %x[row], lsl #2\n" - "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" - "ins v13.s[1], w5\n" // rhs_zero_point - // Load 8 lhs_sums values. - "ld1 {v11.2s}, [x2], #8\n" - "ldr x6, [x2], #8\n" - "ins v11.d[1], x6\n" - "ld1 {v12.2s}, [x2], #8\n" - "ldr x6, [x2], #8\n" - "ins v12.d[1], x6\n" - // Compute lhs_sums * rhs_zero_point. 
- "mul v11.4s, v11.4s, v13.s[1]\n" - "mul v12.4s, v12.4s, v13.s[1]\n" - // Subtract lhs_sums * rhs_zero_point, per - // equation (7) in https://arxiv.org/pdf/1712.05877.pdf - "sub v16.4s, v16.4s, v11.4s\n" - "sub v17.4s, v17.4s, v12.4s\n" - "sub v18.4s, v18.4s, v11.4s\n" - "sub v19.4s, v19.4s, v12.4s\n" - "sub v20.4s, v20.4s, v11.4s\n" - "sub v21.4s, v21.4s, v12.4s\n" - "sub v22.4s, v22.4s, v11.4s\n" - "sub v23.4s, v23.4s, v12.4s\n" - "sub v24.4s, v24.4s, v11.4s\n" - "sub v25.4s, v25.4s, v12.4s\n" - "sub v26.4s, v26.4s, v11.4s\n" - "sub v27.4s, v27.4s, v12.4s\n" - "sub v28.4s, v28.4s, v11.4s\n" - "sub v29.4s, v29.4s, v12.4s\n" - "sub v30.4s, v30.4s, v11.4s\n" - "sub v31.4s, v31.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" - - "402:\n" - - // At this point we have computed the final int32 values. Now we - // start down-quantizing them to obtain the final 8bit values from them. - - // As part of this down-quantization, our int32 values will be - // multiplied by a multiplier that has a fixed-point component and an - // exponent component. - - //Load the exponent part of the multiplier. - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" - "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" - "add x5, x1, %x[row], lsl #2\n" - "csel x1, x1, x5, eq\n" - - "ldr q9, [x1]\n" - "ldr q10, [x1, #16]\n" - - "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n" - "beq 403f\n" - "smax v11.4s, v9.4s, v8.4s\n" - "smax v12.4s, v10.4s, v8.4s\n" - "sshl v16.4s, v16.4s, v11.4s\n" - "sshl v17.4s, v17.4s, v12.4s\n" - "sshl v18.4s, v18.4s, v11.4s\n" - "sshl v19.4s, v19.4s, v12.4s\n" - "sshl v20.4s, v20.4s, v11.4s\n" - "sshl v21.4s, v21.4s, v12.4s\n" - "sshl v22.4s, v22.4s, v11.4s\n" - "sshl v23.4s, v23.4s, v12.4s\n" - "sshl v24.4s, v24.4s, v11.4s\n" - "sshl v25.4s, v25.4s, v12.4s\n" - "sshl v26.4s, v26.4s, v11.4s\n" - "sshl v27.4s, v27.4s, v12.4s\n" - "sshl v28.4s, v28.4s, v11.4s\n" - "sshl v29.4s, v29.4s, v12.4s\n" - "sshl v30.4s, v30.4s, v11.4s\n" - "sshl v31.4s, v31.4s, v12.4s\n" - "403:\n" - - "ldr q14, [x4]\n" // multiplier_fixedpoint - "ldr q15, [x4, #16]\n" // multiplier_fixedpoint - - "smin v11.4s, v9.4s, v8.4s\n" - "smin v12.4s, v10.4s, v8.4s\n" - - // Apply the fixed-point part of the multiplier. - // - // ... and, interleaved into that: - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. 
- "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" - "sqrdmulh v16.4s, v16.4s, v14.4s\n" - "ldr x1, [%[lhs_ptr]], #8\n" - "sqrdmulh v17.4s, v17.4s, v15.4s\n" - "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" - "sqrdmulh v18.4s, v18.4s, v14.4s\n" - "ldr x2, [%[lhs_ptr]], #8\n" - "sqrdmulh v19.4s, v19.4s, v15.4s\n" - "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" - "sqrdmulh v20.4s, v20.4s, v14.4s\n" - "ldr x5, [%[rhs_ptr]], #8\n" - "sqrdmulh v21.4s, v21.4s, v15.4s\n" - "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" - "sqrdmulh v22.4s, v22.4s, v14.4s\n" - "ldr x6, [%[rhs_ptr]], #8\n" - "sqrdmulh v23.4s, v23.4s, v15.4s\n" - "sqrdmulh v24.4s, v24.4s, v14.4s\n" - "sqrdmulh v25.4s, v25.4s, v15.4s\n" - "sqrdmulh v26.4s, v26.4s, v14.4s\n" - "sqrdmulh v27.4s, v27.4s, v15.4s\n" - "sqrdmulh v28.4s, v28.4s, v14.4s\n" - "sqrdmulh v29.4s, v29.4s, v15.4s\n" - "sqrdmulh v30.4s, v30.4s, v14.4s\n" - "sqrdmulh v31.4s, v31.4s, v15.4s\n" - - // We have some rounding division-by-power-of-two to do. This should - // always use "round to nearest". We allow for some - // freedom in how ties are broken, to strike a good compromise of - // performance on given hardware vs. perfect agreement of results - // across hardware. - // - // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation - // defined tie-breaks to help performance. On NEON, this means that we - // can just use the NEON rounding instructions, such as srshl. They - // happen to be breaking ties upward. - // - // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict - // break-ties-away-from zero, as described in Appendix B of - // https://arxiv.org/pdf/1712.05877.pdf - // When we wrote that, we thought that that would be better unbiased - // than the NEON upwards tie-breaks, and we had observed some - // improvement on some model. However, that is only more unbiased for - // data centered at zero, which was likely the case in that model, - // but is not always the case. If we wanted something more consistently - // unbiased then we should try breaking ties toward-nearest-even. -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - // Fix up values to be right-shifted, so that the (round to nearest, - // break ties upward) behavior of srshl applied to these fixed-up - // values, produces the same result as the desired (round to nearest, - // break ties away from zero) behavior on the original values. 
- "and v8.16b, v16.16b, v11.16b\n" - "and v9.16b, v17.16b, v12.16b\n" - "and v14.16b, v18.16b, v11.16b\n" - "and v15.16b, v19.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v16.4s, v16.4s, v8.4s\n" - "sqadd v17.4s, v17.4s, v9.4s\n" - "sqadd v18.4s, v18.4s, v14.4s\n" - "sqadd v19.4s, v19.4s, v15.4s\n" - "and v8.16b, v20.16b, v11.16b\n" - "and v9.16b, v21.16b, v12.16b\n" - "and v14.16b, v22.16b, v11.16b\n" - "and v15.16b, v23.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v20.4s, v20.4s, v8.4s\n" - "sqadd v21.4s, v21.4s, v9.4s\n" - "sqadd v22.4s, v22.4s, v14.4s\n" - "sqadd v23.4s, v23.4s, v15.4s\n" - "and v8.16b, v24.16b, v11.16b\n" - "and v9.16b, v25.16b, v12.16b\n" - "and v14.16b, v26.16b, v11.16b\n" - "and v15.16b, v27.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v24.4s, v24.4s, v8.4s\n" - "sqadd v25.4s, v25.4s, v9.4s\n" - "sqadd v26.4s, v26.4s, v14.4s\n" - "sqadd v27.4s, v27.4s, v15.4s\n" - "and v8.16b, v28.16b, v11.16b\n" - "and v9.16b, v29.16b, v12.16b\n" - "and v14.16b, v30.16b, v11.16b\n" - "and v15.16b, v31.16b, v12.16b\n" - "sshr v8.4s, v8.4s, #31\n" - "sshr v9.4s, v9.4s, #31\n" - "sshr v14.4s, v14.4s, #31\n" - "sshr v15.4s, v15.4s, #31\n" - "sqadd v28.4s, v28.4s, v8.4s\n" - "sqadd v29.4s, v29.4s, v9.4s\n" - "sqadd v30.4s, v30.4s, v14.4s\n" - "sqadd v31.4s, v31.4s, v15.4s\n" -#endif - // At this point we have reduced the problem of correctly implementing - // rounding divide-by-power-of-two, to what the SRSHL instruction can - // do. - "srshl v16.4s, v16.4s, v11.4s\n" - "srshl v17.4s, v17.4s, v12.4s\n" - "srshl v18.4s, v18.4s, v11.4s\n" - "srshl v19.4s, v19.4s, v12.4s\n" - "srshl v20.4s, v20.4s, v11.4s\n" - "srshl v21.4s, v21.4s, v12.4s\n" - "srshl v22.4s, v22.4s, v11.4s\n" - "srshl v23.4s, v23.4s, v12.4s\n" - "srshl v24.4s, v24.4s, v11.4s\n" - "srshl v25.4s, v25.4s, v12.4s\n" - "srshl v26.4s, v26.4s, v11.4s\n" - "srshl v27.4s, v27.4s, v12.4s\n" - "ins v0.d[1], x1\n" - "srshl v28.4s, v28.4s, v11.4s\n" - "ins v1.d[1], x2\n" - "srshl v29.4s, v29.4s, v12.4s\n" - "ins v2.d[1], x5\n" - "srshl v30.4s, v30.4s, v11.4s\n" - "ins v3.d[1], x6\n" - "srshl v31.4s, v31.4s, v12.4s\n" - - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" - "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" - "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // Destination zero_point - "dup v14.8h, v13.h[4]\n" - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). 
- RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - // Cast-and-saturate from int16 to uint8 - "sqxtun v16.8b, v16.8h\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "sqxtun2 v16.16b, v17.8h\n" - "sqxtun v17.8b, v18.8h\n" - "sqxtun2 v17.16b, v19.8h\n" - "sqxtun v18.8b, v20.8h\n" - "sqxtun2 v18.16b, v21.8h\n" - "sqxtun v19.8b, v22.8h\n" - "sqxtun2 v19.16b, v23.8h\n" - - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - // Apply the clamp_min bound - "umax v16.16b, v16.16b, v14.16b\n" - "sub w2, %w[dst_cols], %w[col]\n" - "umax v17.16b, v17.16b, v14.16b\n" - "mov w3, #8\n" - "umax v18.16b, v18.16b, v14.16b\n" - "cmp w1, #8\n" - "umax v19.16b, v19.16b, v14.16b\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - // Apply the clamp_max bound - "umin v16.16b, v16.16b, v15.16b\n" - "cmp w2, #8\n" - "umin v17.16b, v17.16b, v15.16b\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - "umin v18.16b, v18.16b, v15.16b\n" - "umin v19.16b, v19.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - "dup d21, v17.d[1]\n" - "dup d22, v18.d[1]\n" - "dup d23, v19.d[1]\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. 
- ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // Destination zero_point - "dup v14.8h, v13.h[4]\n" - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). 
- RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Add the destination zero point - "add v16.8h, v16.8h, v14.8h\n" - "add v17.8h, v17.8h, v14.8h\n" - "add v18.8h, v18.8h, v14.8h\n" - "add v19.8h, v19.8h, v14.8h\n" - "add v20.8h, v20.8h, v14.8h\n" - "add v21.8h, v21.8h, v14.8h\n" - "add v22.8h, v22.8h, v14.8h\n" - "add v23.8h, v23.8h, v14.8h\n" - - // Load the clamp_min, clamp_max bounds - "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - // Cast-and-saturate from int16 to uint8 - "sqxtn v16.8b, v16.8h\n" - "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "sqxtn2 v16.16b, v17.8h\n" - "sqxtn v17.8b, v18.8h\n" - "sqxtn2 v17.16b, v19.8h\n" - "sqxtn v18.8b, v20.8h\n" - "sqxtn2 v18.16b, v21.8h\n" - "sqxtn v19.8b, v22.8h\n" - "sqxtn2 v19.16b, v23.8h\n" - - "dup v14.16b, w2\n" // clamp_min - "dup v15.16b, w3\n" // clamp_max - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - // Apply the clamp_min bound - "smax v16.16b, v16.16b, v14.16b\n" - "sub w2, %w[dst_cols], %w[col]\n" - "smax v17.16b, v17.16b, v14.16b\n" - "mov w3, #8\n" - "smax v18.16b, v18.16b, v14.16b\n" - "cmp w1, #8\n" - "smax v19.16b, v19.16b, v14.16b\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - // Apply the clamp_max bound - "smin v16.16b, v16.16b, v15.16b\n" - "cmp w2, #8\n" - "smin v17.16b, v17.16b, v15.16b\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - "smin v18.16b, v18.16b, v15.16b\n" - "smin v19.16b, v19.16b, v15.16b\n" - - // Make it so that all of the final 8bit values are stored in the - // first 64bits of 128bit NEON registers, so they can be stored - // by 64bit st1 store instructions with byte alignment. - "dup d20, v16.d[1]\n" - "dup d21, v17.d[1]\n" - "dup d22, v18.d[1]\n" - "dup d23, v19.d[1]\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 130f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #8\n" - "b 131f\n" - "130:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "131:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. 
- ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8b}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 141f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "150:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "151:\n" - "ldrb w7, [x3, w5, uxtw]\n" - "strb w7, [x4, w5, uxtw]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 151b\n" - "add w6, w6, #1\n" - "add x3, x3, #8\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 150b\n" - "141:\n" - "add %[dst_ptr], %[dst_ptr], #8\n" - - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" - - // Add the destination zero point - "dup v14.8h, v13.h[4]\n" - "saddw v16.4s, v16.4s, v14.4h\n" - "saddw v17.4s, v17.4s, v14.4h\n" - "saddw v18.4s, v18.4s, v14.4h\n" - "saddw v19.4s, v19.4s, v14.4h\n" - "saddw v20.4s, v20.4s, v14.4h\n" - "saddw v21.4s, v21.4s, v14.4h\n" - "saddw v22.4s, v22.4s, v14.4h\n" - "saddw v23.4s, v23.4s, v14.4h\n" - "saddw v24.4s, v24.4s, v14.4h\n" - "saddw v25.4s, v25.4s, v14.4h\n" - "saddw v26.4s, v26.4s, v14.4h\n" - "saddw v27.4s, v27.4s, v14.4h\n" - "saddw v28.4s, v28.4s, v14.4h\n" - "saddw v29.4s, v29.4s, v14.4h\n" - "saddw v30.4s, v30.4s, v14.4h\n" - "saddw v31.4s, v31.4s, v14.4h\n" - - // Cast-and-saturate from int32 to int16 - "sqxtn v16.4h, v16.4s\n" - "sqxtn2 v16.8h, v17.4s\n" - "sqxtn v17.4h, v18.4s\n" - "sqxtn2 v17.8h, v19.4s\n" - "sqxtn v18.4h, v20.4s\n" - "sqxtn2 v18.8h, v21.4s\n" - "sqxtn v19.4h, v22.4s\n" - "sqxtn2 v19.8h, v23.4s\n" - "sqxtn v20.4h, v24.4s\n" - "sqxtn2 v20.8h, v25.4s\n" - "sqxtn v21.4h, v26.4s\n" - "sqxtn2 v21.8h, v27.4s\n" - "sqxtn v22.4h, v28.4s\n" - "sqxtn2 v22.8h, v29.4s\n" - "sqxtn v23.4h, v30.4s\n" - "sqxtn2 v23.8h, v31.4s\n" - - // At this point, v24 -- v31 aren't used anymore for the current block, - // so we can start clearing these accumulators for the next block - // (next iteration of the main loop). 
- RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // Load the clamp_min, clamp_max bounds - "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.8h, w2\n" // clamp_min - "dup v15.8h, w3\n" // clamp_max - - // Apply the clamp_min bound - "smax v16.8h, v16.8h, v14.8h\n" - "smax v17.8h, v17.8h, v14.8h\n" - "smax v18.8h, v18.8h, v14.8h\n" - "smax v19.8h, v19.8h, v14.8h\n" - "smax v20.8h, v20.8h, v14.8h\n" - "smax v21.8h, v21.8h, v14.8h\n" - "smax v22.8h, v22.8h, v14.8h\n" - "smax v23.8h, v23.8h, v14.8h\n" - // Apply the clamp_max bound - "smin v16.8h, v16.8h, v15.8h\n" - "smin v17.8h, v17.8h, v15.8h\n" - "smin v18.8h, v18.8h, v15.8h\n" - "smin v19.8h, v19.8h, v15.8h\n" - "smin v20.8h, v20.8h, v15.8h\n" - "smin v21.8h, v21.8h, v15.8h\n" - "smin v22.8h, v22.8h, v15.8h\n" - "smin v23.8h, v23.8h, v15.8h\n" - - // Compute how much of the 8x8 block of destination 16bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 230f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #16\n" - "b 231f\n" - "230:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "231:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v16.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v16) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v17.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v17) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v18.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v18) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v19.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v19) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v20.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v21.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v22.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "st1 {v23.8h}, [x3], x4\n" - RUY_MAKE_ZERO(v23) - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 241f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. 
Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "250:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "251:\n" - "ldrsh w7, [x3, x5, lsl #1]\n" - "strh w7, [x4, x5, lsl #1]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 251b\n" - "add w6, w6, #1\n" - "add x3, x3, #16\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 250b\n" - "241:\n" - "add %[dst_ptr], %[dst_ptr], #16\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" - - RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" - - "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" - "ldr x1, [%[lhs_ptr]], #8\n" - "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" - "ldr x2, [%[lhs_ptr]], #8\n" - "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" - "ldr x5, [%[rhs_ptr]], #8\n" - "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" - "ldr x6, [%[rhs_ptr]], #8\n" - "ins v0.d[1], x1\n" - "ins v1.d[1], x2\n" - "ins v2.d[1], x5\n" - "ins v3.d[1], x6\n" - - // Since the store type is the same as the accum type, no need for - // downcast. There's also no need for clamp by min/max. - - // Compute how much of the 8x8 block of destination 32it values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 330f\n" - // Not all of the 8x8 block fits. - // Write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "st1 {v16.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v16) - "st1 {v17.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v17) - "st1 {v18.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v18) - "st1 {v19.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v19) - "st1 {v20.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v20) - "st1 {v21.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v21) - "st1 {v22.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v22) - "st1 {v23.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v23) - "st1 {v24.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v24) - "st1 {v25.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v25) - "st1 {v26.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v26) - "st1 {v27.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v27) - "st1 {v28.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v28) - "st1 {v29.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v29) - "st1 {v30.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v30) - "st1 {v31.4s}, [x3], #16\n" - RUY_MAKE_ZERO(v31) - - "b 331f\n" - - "330:\n" - // Yes, all of the 8x8 block fits. 
- "mov x4, %[dst_ptr]\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v16.4s, v17.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v18.4s, v19.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v20.4s, v21.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v22.4s, v23.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v24.4s, v25.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v26.4s, v27.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v28.4s, v29.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "st1 {v30.4s, v31.4s}, [x4], x11\n" - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - "331:\n" - - // For the next block: perform the first few multiply-adds on the data - // that we have already loaded. - ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" - ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" - ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" - ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 341f\n" - - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "350:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "351:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 351b\n" - "add w6, w6, #1\n" - "add x3, x3, #32\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 350b\n" - "341:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? 
- "cmp %w[col], w8\n" - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), - [dst_type_id] "r"(params.dst_type_id) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} -#undef RUY_OFFSET_BIAS -#undef RUY_OFFSET_LHS_SUMS -#undef RUY_OFFSET_RHS_SUMS -#undef RUY_OFFSET_LHS_BASE_PTR -#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT -#undef RUY_OFFSET_MULTIPLIER_EXPONENT -#undef RUY_OFFSET_RHS_BASE_PTR -#undef RUY_OFFSET_DST_BASE_PTR -#undef RUY_OFFSET_LHS_ZERO_POINT -#undef RUY_OFFSET_RHS_ZERO_POINT -#undef RUY_OFFSET_DST_ZERO_POINT -#undef RUY_OFFSET_PROD_ZP_DEPTH -#undef RUY_OFFSET_START_ROW -#undef RUY_OFFSET_START_COL -#undef RUY_OFFSET_LAST_ROW -#undef RUY_OFFSET_LAST_COL -#undef RUY_OFFSET_DST_ROWS -#undef RUY_OFFSET_DST_COLS -#undef RUY_OFFSET_LHS_STRIDE -#undef RUY_OFFSET_RHS_STRIDE -#undef RUY_OFFSET_DST_STRIDE -#undef RUY_OFFSET_DEPTH -#undef RUY_OFFSET_CLAMP_MIN -#undef RUY_OFFSET_CLAMP_MAX -#undef RUY_OFFSET_FLAGS - -#define RUY_OFFSET_LHS_BASE_PTR 0 -#define RUY_OFFSET_RHS_BASE_PTR 8 -#define RUY_OFFSET_DST_BASE_PTR 16 -#define RUY_OFFSET_BIAS 24 -#define RUY_OFFSET_START_ROW 32 -#define RUY_OFFSET_START_COL 36 -#define RUY_OFFSET_LAST_ROW 40 -#define RUY_OFFSET_LAST_COL 44 -#define RUY_OFFSET_LHS_STRIDE 56 -#define RUY_OFFSET_RHS_STRIDE 60 -#define RUY_OFFSET_DST_STRIDE 64 -#define RUY_OFFSET_DEPTH 68 -#define RUY_OFFSET_CLAMP_MIN 72 -#define RUY_OFFSET_CLAMP_MAX 76 -#define RUY_OFFSET_FLAGS 80 - -template -void CheckOffsetsInKernelParamsFloat(const Params&) { - static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); - static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, ""); - static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, ""); - static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); - static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); - static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, ""); - static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); - static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); - static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); - static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); - static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); - static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); - static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); - static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); - static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); -} - -// Just a plain float kernel; good enough for out-of-order cores. 
-// The closest to it in the gemmlowp collection would be -// NEON_64bit_GEMM_Float32_WithScalar, -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3925 -// -// Besides ruy-ification, the main nuance here is that we stick to a 8x8 -// width instead of the wider 12x8 that the register space permits and that -// the aforementioned gemmlowp kernel uses. Ruy likes powers of two for now -// and we don't have evidence that going beyond 8x8 is needed. -void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params) { - CheckOffsetsInKernelParamsFloat(params); - profiler::ScopeLabel label( - "Kernel (kNeon, optimized for out-of-order cores)"); - - const float* lhs_col_ptr = params.lhs_base_ptr; - const float* rhs_col_ptr = params.rhs_base_ptr; - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - float* dst_col_ptr = params.dst_base_ptr; - float* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are accumulators. - // During accumulation, v0 -- v15 are used to load data from LHS and RHS. - // At least v0 and v1 are used to load a 8x1 block of LHS, and v2 and - // v3 are used to load a 1x8 block of RHS, like this: - // - // RHS 1x8 block - // /-----------------------------------------\ - // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| - // \-----------------------------------------/ - // LHS 8x1 block - // /---------------------\ /-----------------------------------------\ - // | v0.s[0] | |v16.s[0] ... v30.s[0]| - // | ... | | ... ... | - // | v0.s[3] | |v16.s[3] ... v30.s[3]| - // | v1.s[0] | |v17.s[0] ... v31.s[0]| - // | ... | | ... ... | - // | v1.s[3] | |v17.s[3] ... v31.s[3]| - // \---------------------/ \-----------------------------------------/ - // accumulators 8x8 block - // - // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step - // is repeated 4 times, using 4x more registers for LHS and RHS, so that - // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. - // - // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are - // unused, and v8 -- v15 are used for floading parameters used for the - // post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - - // Clear accumulators. 
- RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 1. - "mov w1, #1\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - "fmla v16.4s, v0.4s, v2.s[0]\n" - "fmla v18.4s, v0.4s, v2.s[1]\n" - "fmla v20.4s, v0.4s, v2.s[2]\n" - "fmla v22.4s, v0.4s, v2.s[3]\n" - -#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) - "cmp w12, #8\n" - "blt 78f\n" - "and w2, w12, #-4\n" - - "ld1 {v4.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v5.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v6.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v7.4s}, [%[rhs_ptr]], #16\n" - - "ld1 {v8.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v9.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v10.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v11.4s}, [%[rhs_ptr]], #16\n" - - "ld1 {v12.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v13.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v14.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v15.4s}, [%[rhs_ptr]], #16\n" - "mov w1, #4\n" - - "80:\n" - - "add %[lhs_ptr], %[lhs_ptr], #128\n" - "add %[rhs_ptr], %[rhs_ptr], #128\n" - - "fmla v24.4s, v0.4s, v3.s[0]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "ldr q0, [%[lhs_ptr], #-128]\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ldr q3, [%[rhs_ptr], #-112]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - "ldr q1, [%[lhs_ptr], #-112]\n" - "fmla v16.4s, v4.4s, v6.s[0]\n" - "fmla v18.4s, v4.4s, v6.s[1]\n" - "ldr q2, [%[rhs_ptr], #-128]\n" - "fmla v20.4s, v4.4s, v6.s[2]\n" - "fmla v22.4s, v4.4s, v6.s[3]\n" - - "fmla v24.4s, v4.4s, v7.s[0]\n" - "fmla v26.4s, v4.4s, v7.s[1]\n" - "fmla v28.4s, v4.4s, v7.s[2]\n" - "fmla v30.4s, v4.4s, v7.s[3]\n" - "ldr q4, [%[lhs_ptr], #-96]\n" - "fmla v25.4s, v5.4s, v7.s[0]\n" - "fmla v27.4s, v5.4s, v7.s[1]\n" - "fmla v29.4s, v5.4s, v7.s[2]\n" - "fmla v31.4s, v5.4s, v7.s[3]\n" - "ldr q7, [%[rhs_ptr], #-80]\n" - "fmla v17.4s, v5.4s, v6.s[0]\n" - "fmla v19.4s, v5.4s, v6.s[1]\n" - "fmla v21.4s, v5.4s, v6.s[2]\n" - "fmla v23.4s, v5.4s, v6.s[3]\n" - "ldr q5, [%[lhs_ptr], #-80]\n" - "fmla v16.4s, v8.4s, v10.s[0]\n" - "fmla v18.4s, v8.4s, v10.s[1]\n" - "ldr q6, [%[rhs_ptr], #-96]\n" - "fmla v20.4s, v8.4s, v10.s[2]\n" - "fmla v22.4s, v8.4s, v10.s[3]\n" - - "fmla v24.4s, v8.4s, v11.s[0]\n" - "fmla v26.4s, v8.4s, v11.s[1]\n" - "fmla v28.4s, v8.4s, v11.s[2]\n" - "fmla v30.4s, v8.4s, v11.s[3]\n" - "ldr q8, [%[lhs_ptr], #-64]\n" - "fmla v25.4s, v9.4s, v11.s[0]\n" - "fmla v27.4s, v9.4s, v11.s[1]\n" - "fmla v29.4s, v9.4s, v11.s[2]\n" - "fmla v31.4s, v9.4s, v11.s[3]\n" - "ldr q11, [%[rhs_ptr], #-48]\n" - "fmla v17.4s, v9.4s, v10.s[0]\n" - "fmla v19.4s, v9.4s, v10.s[1]\n" - "fmla v21.4s, v9.4s, v10.s[2]\n" - "fmla v23.4s, v9.4s, v10.s[3]\n" - "ldr q9, [%[lhs_ptr], #-48]\n" - "fmla v16.4s, v12.4s, v14.s[0]\n" - "fmla v18.4s, v12.4s, v14.s[1]\n" - "ldr q10, [%[rhs_ptr], #-64]\n" - "fmla v20.4s, v12.4s, v14.s[2]\n" - "fmla v22.4s, v12.4s, v14.s[3]\n" - - "fmla v24.4s, v12.4s, v15.s[0]\n" - "fmla v26.4s, 
v12.4s, v15.s[1]\n" - "fmla v28.4s, v12.4s, v15.s[2]\n" - "fmla v30.4s, v12.4s, v15.s[3]\n" - "ldr q12, [%[lhs_ptr], #-32]\n" - "fmla v25.4s, v13.4s, v15.s[0]\n" - "fmla v27.4s, v13.4s, v15.s[1]\n" - "fmla v29.4s, v13.4s, v15.s[2]\n" - "fmla v31.4s, v13.4s, v15.s[3]\n" - "ldr q15, [%[rhs_ptr], #-16]\n" - "fmla v17.4s, v13.4s, v14.s[0]\n" - "fmla v19.4s, v13.4s, v14.s[1]\n" - "fmla v21.4s, v13.4s, v14.s[2]\n" - "fmla v23.4s, v13.4s, v14.s[3]\n" - "ldr q13, [%[lhs_ptr], #-16]\n" - "fmla v16.4s, v0.4s, v2.s[0]\n" - "fmla v18.4s, v0.4s, v2.s[1]\n" - "ldr q14, [%[rhs_ptr], #-32]\n" - "fmla v20.4s, v0.4s, v2.s[2]\n" - "fmla v22.4s, v0.4s, v2.s[3]\n" - - "add w1, w1, #4\n" - "cmp w1, w2\n" - "blt 80b\n" - - "fmla v16.4s, v4.4s, v6.s[0]\n" - "fmla v18.4s, v4.4s, v6.s[1]\n" - "fmla v20.4s, v4.4s, v6.s[2]\n" - "fmla v22.4s, v4.4s, v6.s[3]\n" - "fmla v24.4s, v4.4s, v7.s[0]\n" - "fmla v26.4s, v4.4s, v7.s[1]\n" - "fmla v28.4s, v4.4s, v7.s[2]\n" - "fmla v30.4s, v4.4s, v7.s[3]\n" - "fmla v25.4s, v5.4s, v7.s[0]\n" - "fmla v27.4s, v5.4s, v7.s[1]\n" - "fmla v29.4s, v5.4s, v7.s[2]\n" - "fmla v31.4s, v5.4s, v7.s[3]\n" - "fmla v17.4s, v5.4s, v6.s[0]\n" - "fmla v19.4s, v5.4s, v6.s[1]\n" - "fmla v21.4s, v5.4s, v6.s[2]\n" - "fmla v23.4s, v5.4s, v6.s[3]\n" - - "fmla v16.4s, v8.4s, v10.s[0]\n" - "fmla v18.4s, v8.4s, v10.s[1]\n" - "fmla v20.4s, v8.4s, v10.s[2]\n" - "fmla v22.4s, v8.4s, v10.s[3]\n" - "fmla v24.4s, v8.4s, v11.s[0]\n" - "fmla v26.4s, v8.4s, v11.s[1]\n" - "fmla v28.4s, v8.4s, v11.s[2]\n" - "fmla v30.4s, v8.4s, v11.s[3]\n" - "fmla v25.4s, v9.4s, v11.s[0]\n" - "fmla v27.4s, v9.4s, v11.s[1]\n" - "fmla v29.4s, v9.4s, v11.s[2]\n" - "fmla v31.4s, v9.4s, v11.s[3]\n" - "fmla v17.4s, v9.4s, v10.s[0]\n" - "fmla v19.4s, v9.4s, v10.s[1]\n" - "fmla v21.4s, v9.4s, v10.s[2]\n" - "fmla v23.4s, v9.4s, v10.s[3]\n" - - "fmla v16.4s, v12.4s, v14.s[0]\n" - "fmla v18.4s, v12.4s, v14.s[1]\n" - "fmla v20.4s, v12.4s, v14.s[2]\n" - "fmla v22.4s, v12.4s, v14.s[3]\n" - "fmla v24.4s, v12.4s, v15.s[0]\n" - "fmla v26.4s, v12.4s, v15.s[1]\n" - "fmla v28.4s, v12.4s, v15.s[2]\n" - "fmla v30.4s, v12.4s, v15.s[3]\n" - "fmla v25.4s, v13.4s, v15.s[0]\n" - "fmla v27.4s, v13.4s, v15.s[1]\n" - "fmla v29.4s, v13.4s, v15.s[2]\n" - "fmla v31.4s, v13.4s, v15.s[3]\n" - "fmla v17.4s, v13.4s, v14.s[0]\n" - "fmla v19.4s, v13.4s, v14.s[1]\n" - "fmla v21.4s, v13.4s, v14.s[2]\n" - "fmla v23.4s, v13.4s, v14.s[3]\n" - - "78:\n" -#endif - - // Accumulation loop - "cmp w1, w12\n" - "beq 79f\n" - - "2:\n" - "fmla v24.4s, v0.4s, v3.s[0]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "ld1 {v4.4s}, [%[rhs_ptr]], #16\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "add w1, w1, #1\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "cmp w1, w12\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "fmla v16.4s, v0.4s, v4.s[0]\n" - "fmla v18.4s, v0.4s, v4.s[1]\n" - "mov v2.16b, v4.16b\n" - "fmla v20.4s, v0.4s, v4.s[2]\n" - "fmla v22.4s, v0.4s, v4.s[3]\n" - "blt 2b\n" - - "79:\n" - - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last level of depth, for which the LHS - // and RHS data is already loaded. 
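Each group of fmla instructions above performs one depth step of the same elementary operation: an outer product of an 8x1 LHS slice with a 1x8 RHS slice, accumulated into the 8x8 block held in v16 -- v31; the RUY_OPT_MAX_STREAMING section simply unrolls four such steps per iteration to keep more loads in flight. A reference C++ sketch of a single step (array layout and names are illustrative):

  // acc is the 8x8 accumulator block, indexed [col][row] to match the
  // column-major destination; lhs and rhs each hold 8 floats for this
  // depth level.
  void AccumulateOneDepthStep(const float* lhs, const float* rhs,
                              float acc[8][8]) {
    for (int c = 0; c < 8; ++c) {
      for (int r = 0; r < 8; ++r) {
        acc[c][r] += lhs[r] * rhs[c];  // one lane of one fmla above
      }
    }
  }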
- - "fmla v24.4s, v0.4s, v3.s[0]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "add x5, x1, %x[row], lsl #2\n" - - "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.4s}, [x1], #16\n" - "ld1 {v15.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
- "fadd v16.4s, v16.4s, v14.4s\n" - "fadd v17.4s, v17.4s, v15.4s\n" - "fadd v18.4s, v18.4s, v14.4s\n" - "fadd v19.4s, v19.4s, v15.4s\n" - "fadd v20.4s, v20.4s, v14.4s\n" - "fadd v21.4s, v21.4s, v15.4s\n" - "fadd v22.4s, v22.4s, v14.4s\n" - "fadd v23.4s, v23.4s, v15.4s\n" - "fadd v24.4s, v24.4s, v14.4s\n" - "fadd v25.4s, v25.4s, v15.4s\n" - "fadd v26.4s, v26.4s, v14.4s\n" - "fadd v27.4s, v27.4s, v15.4s\n" - "fadd v28.4s, v28.4s, v14.4s\n" - "fadd v29.4s, v29.4s, v15.4s\n" - "fadd v30.4s, v30.4s, v14.4s\n" - "fadd v31.4s, v31.4s, v15.4s\n" - - // Load the clamp_min, clamp_max bounds - "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.4s, w2\n" // clamp_min - "dup v15.4s, w3\n" // clamp_max - - // Apply the clamp_min bound - "fmax v16.4s, v16.4s, v14.4s\n" - "fmax v17.4s, v17.4s, v14.4s\n" - "fmax v18.4s, v18.4s, v14.4s\n" - "fmax v19.4s, v19.4s, v14.4s\n" - "fmax v20.4s, v20.4s, v14.4s\n" - "fmax v21.4s, v21.4s, v14.4s\n" - "fmax v22.4s, v22.4s, v14.4s\n" - "fmax v23.4s, v23.4s, v14.4s\n" - "fmax v24.4s, v24.4s, v14.4s\n" - "fmax v25.4s, v25.4s, v14.4s\n" - "fmax v26.4s, v26.4s, v14.4s\n" - "fmax v27.4s, v27.4s, v14.4s\n" - "fmax v28.4s, v28.4s, v14.4s\n" - "fmax v29.4s, v29.4s, v14.4s\n" - "fmax v30.4s, v30.4s, v14.4s\n" - "fmax v31.4s, v31.4s, v14.4s\n" - - // Apply the clamp_max bound - "fmin v16.4s, v16.4s, v15.4s\n" - "fmin v17.4s, v17.4s, v15.4s\n" - "fmin v18.4s, v18.4s, v15.4s\n" - "fmin v19.4s, v19.4s, v15.4s\n" - "fmin v20.4s, v20.4s, v15.4s\n" - "fmin v21.4s, v21.4s, v15.4s\n" - "fmin v22.4s, v22.4s, v15.4s\n" - "fmin v23.4s, v23.4s, v15.4s\n" - "fmin v24.4s, v24.4s, v15.4s\n" - "fmin v25.4s, v25.4s, v15.4s\n" - "fmin v26.4s, v26.4s, v15.4s\n" - "fmin v27.4s, v27.4s, v15.4s\n" - "fmin v28.4s, v28.4s, v15.4s\n" - "fmin v29.4s, v29.4s, v15.4s\n" - "fmin v30.4s, v30.4s, v15.4s\n" - "fmin v31.4s, v31.4s, v15.4s\n" - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #32\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). 
- RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "str q16, [x3, #0]\n" - "str q17, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - "str q18, [x3, #0]\n" - "str q19, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - "str q20, [x3, #0]\n" - "str q21, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - "str q22, [x3, #0]\n" - "str q23, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - "str q24, [x3, #0]\n" - "str q25, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - "str q26, [x3, #0]\n" - "str q27, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - "str q28, [x3, #0]\n" - "str q29, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - "str q30, [x3, #0]\n" - "str q31, [x3, #16]\n" - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #32\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that we have already loaded - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently 1. 
- "mov w1, #1\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} - -// Variant of KernelFloatNeonOutOfOrder tuned for in-order CPUs that do not -// support dotprod (while dotprod by itself is not relevant to floating-point, -// this additional bit of information that we have about the target happens to -// be useful here). -// -// So a typical target CPU here would be ARM Cortex-A53 or the original -// Cortex-A55. -// -// This kernel is similar to and inspired by gemmlowp's -// NEON_64bit_GEMM_Float32_WithScalar_A53. -// which was contributed by David Mansell with very helpful -// comments. Specifically, see this comment about tuning for Cortex-A53: -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215 -void KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)"); - - CheckOffsetsInKernelParamsFloat(params); - - const float* lhs_col_ptr = params.lhs_base_ptr; - const float* rhs_col_ptr = params.rhs_base_ptr; - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - float* dst_col_ptr = params.dst_base_ptr; - float* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are accumulators. - // During accumulation, v0 -- v3 are used to load data from LHS and RHS. - // - // RHS 1x8 block - // /-----------------------------------------\ - // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| - // \-----------------------------------------/ - // LHS 8x1 block - // /---------------------\ /-----------------------------------------\ - // | v0.s[0] | |v16.s[0] ... v30.s[0]| - // | ... | | ... ... | - // | v0.s[3] | |v16.s[3] ... v30.s[3]| - // | v1.s[0] | |v17.s[0] ... v31.s[0]| - // | ... | | ... ... | - // | v1.s[3] | |v17.s[3] ... v31.s[3]| - // \---------------------/ \-----------------------------------------/ - // accumulators 8x8 block - // - // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because - // we did not observe a benefit of such partial unrolling on in-order CPUs. - // - // v4 -- v7 are unused, and v8 -- v15 are used for floading parameters used - // for the post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. 
- "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v17) - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v18) - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v19) - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n") - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n") - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n") - RUY_MAKE_ZERO(v23) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n") - RUY_MAKE_ZERO(v24) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n") - RUY_MAKE_ZERO(v25) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n") - RUY_MAKE_ZERO(v26) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") - RUY_MAKE_ZERO(v27) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // w1 is the number of levels of depth that remain to load - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently depth - 1. - "sub w1, w12, #1\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - "cmp w1, #0\n" - "fmla v16.4s, v0.4s, v2.s[0]\n" - "fmla v18.4s, v0.4s, v2.s[1]\n" - "fmla v20.4s, v0.4s, v2.s[2]\n" - "fmla v22.4s, v0.4s, v2.s[3]\n" - - // Accumulation loop - "beq 79f\n" - - "2:\n" - - "fmla v24.4s, v0.4s, v3.s[0]\n" - "ldr x2, [%[lhs_ptr], #8]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "ldr x3, [%[lhs_ptr], #24]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "ldr x5, [%[rhs_ptr], #24]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "ldr x4, [%[rhs_ptr], #8]\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "subs w1, w1, #1\n" - "ldr d0, [%[lhs_ptr]], #32\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ins v0.d[1], x2\n" - "ldr d3, [%[rhs_ptr], #16]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "ins v3.d[1], x5\n" - "ldr d4, [%[rhs_ptr]], #32\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - "fmla v16.4s, v0.4s, v4.s[0]\n" - "ins v4.d[1], x4\n" - "ldr d1, [%[lhs_ptr], #-16]\n" - "fmla v18.4s, v0.4s, v4.s[1]\n" - "fmla v20.4s, v0.4s, v4.s[2]\n" - "ins v1.d[1], x3\n" - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") - "mov v2.16b, v4.16b\n" - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") - "fmla v22.4s, v0.4s, v4.s[3]\n" - "bne 2b\n" - - "79:\n" - - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last level of depth, for which the LHS - // and RHS data is already loaded. 
- - "fmla v24.4s, v0.4s, v3.s[0]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "add x5, x1, %x[row], lsl #2\n" - - "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.4s}, [x1], #16\n" - "ld1 {v15.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. 
- "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "fadd v16.4s, v16.4s, v14.4s\n" - "fadd v17.4s, v17.4s, v15.4s\n" - "fadd v18.4s, v18.4s, v14.4s\n" - "fadd v19.4s, v19.4s, v15.4s\n" - "fadd v20.4s, v20.4s, v14.4s\n" - "fadd v21.4s, v21.4s, v15.4s\n" - "fadd v22.4s, v22.4s, v14.4s\n" - "fadd v23.4s, v23.4s, v15.4s\n" - "fadd v24.4s, v24.4s, v14.4s\n" - "fadd v25.4s, v25.4s, v15.4s\n" - "fadd v26.4s, v26.4s, v14.4s\n" - "fadd v27.4s, v27.4s, v15.4s\n" - "fadd v28.4s, v28.4s, v14.4s\n" - "fadd v29.4s, v29.4s, v15.4s\n" - "fadd v30.4s, v30.4s, v14.4s\n" - "fadd v31.4s, v31.4s, v15.4s\n" - - // Load the clamp_min, clamp_max bounds - "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.4s, w2\n" // clamp_min - "dup v15.4s, w3\n" // clamp_max - - // Apply the clamp_min bound - "fmax v16.4s, v16.4s, v14.4s\n" - "fmax v17.4s, v17.4s, v14.4s\n" - "fmax v18.4s, v18.4s, v14.4s\n" - "fmax v19.4s, v19.4s, v14.4s\n" - "fmax v20.4s, v20.4s, v14.4s\n" - "fmax v21.4s, v21.4s, v14.4s\n" - "fmax v22.4s, v22.4s, v14.4s\n" - "fmax v23.4s, v23.4s, v14.4s\n" - "fmax v24.4s, v24.4s, v14.4s\n" - "fmax v25.4s, v25.4s, v14.4s\n" - "fmax v26.4s, v26.4s, v14.4s\n" - "fmax v27.4s, v27.4s, v14.4s\n" - "fmax v28.4s, v28.4s, v14.4s\n" - "fmax v29.4s, v29.4s, v14.4s\n" - "fmax v30.4s, v30.4s, v14.4s\n" - "fmax v31.4s, v31.4s, v14.4s\n" - - // Apply the clamp_max bound - "fmin v16.4s, v16.4s, v15.4s\n" - "fmin v17.4s, v17.4s, v15.4s\n" - "fmin v18.4s, v18.4s, v15.4s\n" - "fmin v19.4s, v19.4s, v15.4s\n" - "fmin v20.4s, v20.4s, v15.4s\n" - "fmin v21.4s, v21.4s, v15.4s\n" - "fmin v22.4s, v22.4s, v15.4s\n" - "fmin v23.4s, v23.4s, v15.4s\n" - "fmin v24.4s, v24.4s, v15.4s\n" - "fmin v25.4s, v25.4s, v15.4s\n" - "fmin v26.4s, v26.4s, v15.4s\n" - "fmin v27.4s, v27.4s, v15.4s\n" - "fmin v28.4s, v28.4s, v15.4s\n" - "fmin v29.4s, v29.4s, v15.4s\n" - "fmin v30.4s, v30.4s, v15.4s\n" - "fmin v31.4s, v31.4s, v15.4s\n" - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #32\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). 
- RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "str q16, [x3, #0]\n" - "str q17, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - "str q18, [x3, #0]\n" - "str q19, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - "str q20, [x3, #0]\n" - "str q21, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - "str q22, [x3, #0]\n" - "str q23, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - "str q24, [x3, #0]\n" - "str q25, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - "str q26, [x3, #0]\n" - "str q27, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - "str q28, [x3, #0]\n" - "str q29, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - "str q30, [x3, #0]\n" - "str q31, [x3, #16]\n" - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #32\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that remain to load - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently depth - 1. 
- "sub w1, w12, #1\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} - -// Variant of KernelFloatNeonInOrder tuned for in-order CPUs that do -// support dotprod (while dotprod by itself is not relevant to floating-point, -// this additional bit of information that we have about the target happens to -// be useful here). -// -// So a typical target CPU here would be ARM Cortex-A55r1. -// -// This kernel is similar to and inspired by gemmlowp's -// NEON_64bit_GEMM_Float32_WithScalar_A55r1. -// which was contributed by David Mansell with very helpful -// comments. Specifically, see this comment about tuning for Cortex-A55r1: -// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412 -void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label( - "Kernel (kNeonDotprod, optimized for in-order cores)"); - - CheckOffsetsInKernelParamsFloat(params); - - const float* lhs_col_ptr = params.lhs_base_ptr; - const float* rhs_col_ptr = params.rhs_base_ptr; - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - float* dst_col_ptr = params.dst_base_ptr; - float* dst_ptr = dst_col_ptr; - int row = params.start_row; - int col = params.start_col; - - // The asm kernel below has the following NEON register allocation: - // - // v16 -- v31 are accumulators. - // During accumulation, v0 -- v3 are used to load data from LHS and RHS. - // - // RHS 1x8 block - // /-----------------------------------------\ - // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| - // \-----------------------------------------/ - // LHS 8x1 block - // /---------------------\ /-----------------------------------------\ - // | v0.s[0] | |v16.s[0] ... v30.s[0]| - // | ... | | ... ... | - // | v0.s[3] | |v16.s[3] ... v30.s[3]| - // | v1.s[0] | |v17.s[0] ... v31.s[0]| - // | ... | | ... ... | - // | v1.s[3] | |v17.s[3] ... v31.s[3]| - // \---------------------/ \-----------------------------------------/ - // accumulators 8x8 block - // - // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because - // we did not observe a benefit of such partial unrolling on in-order CPUs. - // - // v4 -- v7 are unused, and v8 -- v15 are used for floading parameters used - // for the post-accumulation part of the kernel. - asm volatile( -#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" - - // clang-format off - - // Load some parameters into registers. 
- "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" - "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" - "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" - "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" - "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" - - - // Clear accumulators. - RUY_MAKE_ZERO(v16) - // Load the first 32 bytes of LHS and RHS data. - "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v17) - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - RUY_MAKE_ZERO(v18) - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v19) - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - RUY_MAKE_ZERO(v20) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n") - RUY_MAKE_ZERO(v21) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n") - RUY_MAKE_ZERO(v22) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n") - RUY_MAKE_ZERO(v23) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n") - RUY_MAKE_ZERO(v24) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n") - RUY_MAKE_ZERO(v25) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n") - RUY_MAKE_ZERO(v26) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") - RUY_MAKE_ZERO(v27) - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // w1 is the number of levels of depth that remain to load - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently depth - 1. - "sub w1, w12, #1\n" - - // Main loop of the whole GEMM, over rows and columns of the - // destination matrix. - "1:\n" - - "cmp w1, #0\n" - "fmla v16.4s, v0.4s, v2.s[0]\n" - "fmla v18.4s, v0.4s, v2.s[1]\n" - "fmla v20.4s, v0.4s, v2.s[2]\n" - "fmla v22.4s, v0.4s, v2.s[3]\n" - - // Accumulation loop - "beq 79f\n" - - "2:\n" - - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") - "fmla v24.4s, v0.4s, v3.s[0]\n" - "ldr x2, [%[lhs_ptr], #8]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "ldr x3, [%[lhs_ptr], #24]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "ldr x5, [%[rhs_ptr], #24]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "ldr d0, [%[lhs_ptr]], #32\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "ldr x4, [%[rhs_ptr], #8]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "subs w1, w1, #1\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "ins v0.d[1], x2\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ldr d3, [%[rhs_ptr], #16]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "ins v3.d[1], x5\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "ldr d4, [%[rhs_ptr]], #32\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "ins v4.d[1], x4\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") - "fmla v16.4s, v0.4s, v4.s[0]\n" - "ldr d1, [%[lhs_ptr], #-16]\n" - "fmla v18.4s, v0.4s, v4.s[1]\n" - "ins v1.d[1], x3\n" - "fmla v20.4s, v0.4s, v4.s[2]\n" - "mov v2.16b, v4.16b\n" - "fmla v22.4s, v0.4s, v4.s[3]\n" - "bne 2b\n" - - "79:\n" - - // End of the inner loop on depth. Now perform the remaining - // multiply-adds of the last level of depth, for which the LHS - // and RHS data is already loaded. 
- - "fmla v24.4s, v0.4s, v3.s[0]\n" - "fmla v26.4s, v0.4s, v3.s[1]\n" - "fmla v28.4s, v0.4s, v3.s[2]\n" - "fmla v30.4s, v0.4s, v3.s[3]\n" - "fmla v25.4s, v1.4s, v3.s[0]\n" - "fmla v27.4s, v1.4s, v3.s[1]\n" - "fmla v29.4s, v1.4s, v3.s[2]\n" - "fmla v31.4s, v1.4s, v3.s[3]\n" - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "fmla v17.4s, v1.4s, v2.s[0]\n" - "fmla v19.4s, v1.4s, v2.s[1]\n" - "fmla v21.4s, v1.4s, v2.s[2]\n" - "fmla v23.4s, v1.4s, v2.s[3]\n" - - // End of accumulation. The registers v16 -- v31 contain the final - // int32 accumulator values of the current 8x8 destination block. - // We now have to compute the final 8-bit values from these int32 - // accumulators, and advance to the next 8x8 block. We intertwine - // these two aspects whenever possible for optimal pipelining, both - // at the data flow level (prefetch data for next block as early as - // possible) and instruction pipelining level (some of the next-block - // work can dual-issue with some of the final work on the current - // block). - - // Logic to advance to the next block in preparation for the next - // iteration of the main loop. For now, we only want to compute - // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are - // not yet ready to update the values of row and col, as we still need - // the current values for the rest of the work on the current block. - - "cmp %w[row], w7\n" // Have we finished the last row? - "bge 4f\n" // If finished last row, go to 4 - // Not finished last row: then advance to next row. - "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" - "b 5f\n" - "4:\n" // Finished last row... - "mov %[lhs_col_ptr], x5\n" // Go back to first row - // Now we need to advance to the next column. If we already - // finished the last column, then in principle we are done, however - // we can't just return here, as we need to allow the end work of the - // current block to complete. The good news is that at this point it - // doesn't matter what data we load for the next column, since - // we will exit from the main loop below before actually storing - // anything computed from that data. - "cmp %w[col], w8\n" // Have we finished the last column? - "bge 5f\n" // If yes, just carry on without updating the column pointer. - // Not finished last column: then advance to next column. - "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" - "5:\n" - - // Set the LHS and RHS data pointers to the start of the columns just - // computed. - "mov %[lhs_ptr], %[lhs_col_ptr]\n" - "mov %[rhs_ptr], %[rhs_col_ptr]\n" - - // Load some parameters needed for the end work on current block. - "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" - "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" - - // Offset these base pointers as needed given the current row, col. - "add x5, x1, %x[row], lsl #2\n" - - "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" - "csel x1, x1, x5, eq\n" - - // Load 8 bias values. - "ld1 {v14.4s}, [x1], #16\n" - "ld1 {v15.4s}, [x1]\n" - - // Now that we know what LHS and RHS data the next iteration of the - // main loop will need to load, we start loading the first 32 bytes of - // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore - // in the rest of the work on the current block. 
- "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" - "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" - "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" - - // Perform the bias-addition (per the above, we have just folded into - // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) - "fadd v16.4s, v16.4s, v14.4s\n" - "fadd v17.4s, v17.4s, v15.4s\n" - "fadd v18.4s, v18.4s, v14.4s\n" - "fadd v19.4s, v19.4s, v15.4s\n" - "fadd v20.4s, v20.4s, v14.4s\n" - "fadd v21.4s, v21.4s, v15.4s\n" - "fadd v22.4s, v22.4s, v14.4s\n" - "fadd v23.4s, v23.4s, v15.4s\n" - "fadd v24.4s, v24.4s, v14.4s\n" - "fadd v25.4s, v25.4s, v15.4s\n" - "fadd v26.4s, v26.4s, v14.4s\n" - "fadd v27.4s, v27.4s, v15.4s\n" - "fadd v28.4s, v28.4s, v14.4s\n" - "fadd v29.4s, v29.4s, v15.4s\n" - "fadd v30.4s, v30.4s, v14.4s\n" - "fadd v31.4s, v31.4s, v15.4s\n" - - // Load the clamp_min, clamp_max bounds - "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" - "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" - "dup v14.4s, w2\n" // clamp_min - "dup v15.4s, w3\n" // clamp_max - - // Apply the clamp_min bound - "fmax v16.4s, v16.4s, v14.4s\n" - "fmax v17.4s, v17.4s, v14.4s\n" - "fmax v18.4s, v18.4s, v14.4s\n" - "fmax v19.4s, v19.4s, v14.4s\n" - "fmax v20.4s, v20.4s, v14.4s\n" - "fmax v21.4s, v21.4s, v14.4s\n" - "fmax v22.4s, v22.4s, v14.4s\n" - "fmax v23.4s, v23.4s, v14.4s\n" - "fmax v24.4s, v24.4s, v14.4s\n" - "fmax v25.4s, v25.4s, v14.4s\n" - "fmax v26.4s, v26.4s, v14.4s\n" - "fmax v27.4s, v27.4s, v14.4s\n" - "fmax v28.4s, v28.4s, v14.4s\n" - "fmax v29.4s, v29.4s, v14.4s\n" - "fmax v30.4s, v30.4s, v14.4s\n" - "fmax v31.4s, v31.4s, v14.4s\n" - - // Apply the clamp_max bound - "fmin v16.4s, v16.4s, v15.4s\n" - "fmin v17.4s, v17.4s, v15.4s\n" - "fmin v18.4s, v18.4s, v15.4s\n" - "fmin v19.4s, v19.4s, v15.4s\n" - "fmin v20.4s, v20.4s, v15.4s\n" - "fmin v21.4s, v21.4s, v15.4s\n" - "fmin v22.4s, v22.4s, v15.4s\n" - "fmin v23.4s, v23.4s, v15.4s\n" - "fmin v24.4s, v24.4s, v15.4s\n" - "fmin v25.4s, v25.4s, v15.4s\n" - "fmin v26.4s, v26.4s, v15.4s\n" - "fmin v27.4s, v27.4s, v15.4s\n" - "fmin v28.4s, v28.4s, v15.4s\n" - "fmin v29.4s, v29.4s, v15.4s\n" - "fmin v30.4s, v30.4s, v15.4s\n" - "fmin v31.4s, v31.4s, v15.4s\n" - - // Compute how much of the 8x8 block of destination 8bit values that - // we have computed, fit in the destination matrix. Typically, all of - // it fits, but when the destination matrix shape is not a multiple - // of 8x8, there are some 8x8 blocks along the boundaries that do - // not fit entirely. - "sub w1, %w[dst_rows], %w[row]\n" - "sub w2, %w[dst_cols], %w[col]\n" - "mov w3, #8\n" - "cmp w1, #8\n" - // Compute w1 = how many rows of the 8x8 block fit - "csel w1, w1, w3, le\n" - "cmp w2, #8\n" - // Compute w2 = how many cols of the 8x8 block fit - "csel w2, w2, w3, le\n" - - // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. - "cmp w1, w3\n" - "ccmp w2, w3, 0, eq\n" - // Yes, all of the 8x8 block fits, go to fast path. - "beq 30f\n" - // Not all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write to dst_tmp_buf - "mov x3, %[dst_tmp_buf]\n" - "mov x4, #32\n" - "b 31f\n" - "30:\n" - // Yes, all of the 8x8 block fits. - // Set (x3 address, x4 stride) to write directly to destination matrix. - "mov x3, %[dst_ptr]\n" - "mov x4, x11\n" - "31:\n" - - // Write our 8bit values to the destination described by - // (x3 address, x4 stride). 
- RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - "str q16, [x3, #0]\n" - "str q17, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v16) - RUY_MAKE_ZERO(v17) - "str q18, [x3, #0]\n" - "str q19, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v18) - RUY_MAKE_ZERO(v19) - "str q20, [x3, #0]\n" - "str q21, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v20) - RUY_MAKE_ZERO(v21) - "str q22, [x3, #0]\n" - "str q23, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v22) - RUY_MAKE_ZERO(v23) - "str q24, [x3, #0]\n" - "str q25, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v24) - RUY_MAKE_ZERO(v25) - "str q26, [x3, #0]\n" - "str q27, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v26) - RUY_MAKE_ZERO(v27) - "str q28, [x3, #0]\n" - "str q29, [x3, #16]\n" - "add x3, x3, x4\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") - RUY_MAKE_ZERO(v28) - RUY_MAKE_ZERO(v29) - "str q30, [x3, #0]\n" - "str q31, [x3, #16]\n" - RUY_MAKE_ZERO(v30) - RUY_MAKE_ZERO(v31) - - // If all of the 8x8 block fits, we just finished writing it to the - // destination, so we skip the next part. - "beq 41f\n" - // Not all of the 8x8 block fits in the destination matrix. We just - // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over - // it to copy into the destination matrix the part that fits. - "mov x3, %[dst_tmp_buf]\n" - "mov x4, %[dst_ptr]\n" - "mov w6, #0\n" - "50:\n" - RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") - "mov w5, #0\n" - "51:\n" - "ldr w7, [x3, x5, lsl #2]\n" - "str w7, [x4, x5, lsl #2]\n" - "add w5, w5, #1\n" - "cmp w5, w1\n" - "blt 51b\n" - "add w6, w6, #1\n" - "add x3, x3, #32\n" - "add x4, x4, x11\n" - "cmp w6, w2\n" - "blt 50b\n" - "41:\n" - "add %[dst_ptr], %[dst_ptr], #32\n" - // At this point we have completely finished writing values to the - // destination matrix for the current block. - - // Reload some params --- we had used x5 -- x7 for a few other things - // since the last time we had loaded them. - "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" - "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" - "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" - - // Move to the next block of the destination matrix, for the next iter - // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already - // been updated earlier. - // Have we reached the end row? - "cmp %w[row], w7\n" - "beq 20f\n" // yes, end row. - // Not end row. Move to the next row. - "add %w[row], %w[row], #8\n" - "b 21f\n" - "20:\n" - // Was already at end row. - "mov %w[row], w6\n" // Move back to first row. - "add %w[col], %w[col], #8\n" // Move to the next column. - "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" - "mov %[dst_ptr], %[dst_col_ptr]\n" - "21:\n" - - // Main loop exit condition: have we hit the end column? - "cmp %w[col], w8\n" - - // w1 is the number of levels of depth that remain to load - // LHS and RHS data for. Corresponding to the initial ld1 instructions - // above, this is currently depth - 1. 
- "sub w1, w12, #1\n" - - "ble 1b\n" - - // clang-format on - - : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), - [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), - [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) - : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), - [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) - : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", - "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", - "v26", "v27", "v28", "v29", "v30", "v31"); -} -#undef RUY_OFFSET_BIAS -#undef RUY_OFFSET_FLAGS -#undef RUY_OFFSET_LHS_BASE_PTR -#undef RUY_OFFSET_CLAMP_MIN -#undef RUY_OFFSET_CLAMP_MAX -#undef RUY_OFFSET_START_ROW -#undef RUY_OFFSET_LAST_ROW -#undef RUY_OFFSET_LAST_COL -#undef RUY_OFFSET_LHS_STRIDE -#undef RUY_OFFSET_RHS_STRIDE -#undef RUY_OFFSET_DST_STRIDE -#undef RUY_OFFSET_DEPTH -#undef RUY_OFFSET_START_COL -#undef RUY_OFFSET_RHS_BASE_PTR -#undef RUY_OFFSET_DST_BASE_PTR - -#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/kernel_avx2.cc b/kernel_avx2.cc deleted file mode 100644 index da660b4..0000000 --- a/kernel_avx2.cc +++ /dev/null @@ -1,1664 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "check_macros.h" -#include "kernel.h" -#include "opt_set.h" -#include "platform.h" -#include "profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. 
- RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -static constexpr int kAvx8bitBlockSize = 8; -static constexpr int kAvx8bitInnerSize = 4; - -namespace { -namespace intrin_utils { - -inline __m256 mm256_n_loadu_epi32(int n, const std::int32_t* src) { - switch (n) { - case 0: - return _mm256_setzero_si256(); - case 1: - return _mm256_setr_m128(_mm_setr_epi32(src[0], 0, 0, 0), - _mm_setzero_si128()); - case 2: - return _mm256_setr_m128(_mm_setr_epi32(src[0], src[1], 0, 0), - _mm_setzero_si128()); - case 3: - return _mm256_setr_m128(_mm_setr_epi32(src[0], src[1], src[2], 0), - _mm_setzero_si128()); - case 4: - return _mm256_castsi128_si256( - _mm_loadu_si128(reinterpret_cast<__m128i const*>(src))); - case 5: - return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], 0, 0, 0); - case 6: - return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], src[5], - 0, 0); - case 7: - return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], src[5], - src[6], 0); - case 8: - return _mm256_loadu_si256(reinterpret_cast<__m256i const*>(src)); - default: - RUY_DCHECK_LT(n, 9); - return _mm256_setzero_si256(); - } -} - -inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows, - const __m256 v) { - // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. - const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); - __m256i shuffled_v; - if (residual_rows > 1) { - // This selects 0, 4, 8, 12, 0, 4, 8, 12, ..., but we only use the first 4 - // in each 128-bit lane. - shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - } - switch (residual_rows) { - case 0: - break; - case 1: - dst[0] = _mm256_extract_epi8(v, 0); - break; - case 2: - _mm_storeu_si16(dst, _mm256_extracti128_si256(shuffled_v, 0)); - break; - case 3: { - __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 0); - _mm_storeu_si16(dst, trailing_packed); - dst[2] = _mm_extract_epi8(trailing_packed, 2); - break; - } - case 4: - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - break; - case 5: - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - dst[4] = _mm256_extract_epi8(shuffled_v, 16); - break; - case 6: - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si16(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - case 7: { - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 1); - _mm_storeu_si16(dst + 4, trailing_packed); - dst[6] = _mm_extract_epi8(trailing_packed, 2); - break; - } - case 8: - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - default: - RUY_DCHECK_LE(residual_rows, 8); - break; - } -} - -inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256 v) { - // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. 
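The n-element helpers above exist so that edge blocks with fewer than 8 valid rows never read or write past the matrix boundary. A scalar equivalent of the two operations, under hypothetical names and as a sketch only:

#include <cstdint>
#include <cstring>

// Hypothetical scalar equivalents of the helpers above: load n <= 8 int32
// values into a zero-padded 8-lane block, and store back only the first n
// lanes, narrowed to 8 bits (values are clamped to range before this point,
// so the cast just truncates, like the byte shuffle does).
void PartialLoadInt32x8(int n, const std::int32_t* src, std::int32_t lanes[8]) {
  std::memset(lanes, 0, 8 * sizeof(std::int32_t));
  std::memcpy(lanes, src, n * sizeof(std::int32_t));
}

void PartialStoreInt32x8AsInt8(std::int8_t* dst, int n,
                               const std::int32_t lanes[8]) {
  for (int i = 0; i < n; ++i) dst[i] = static_cast<std::int8_t>(lanes[i]);
}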
- const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); - const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); -} - -inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows, - const __m256 v) { - intrin_utils::mm256_n_storeu_cvtepi32_epi8( - reinterpret_cast(dst), residual_rows, v); -} - -inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256 v) { - // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. - const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); - const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); -} - -inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows, - const __m256 v) { - // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively - // truncating each 16-bit integer. - const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); - __m256i shuffled_v; - __m128i shuffled_v_low; - if (residual_rows > 1) { - shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - shuffled_v_low = _mm256_extracti128_si256(shuffled_v, 0); - } else { - shuffled_v_low = _mm256_extracti128_si256(v, 0); - } - switch (residual_rows) { - case 0: - break; - case 1: - _mm_storeu_si16(dst, shuffled_v_low); - break; - case 2: - _mm_storeu_si32(dst, shuffled_v_low); - break; - case 3: { - _mm_storeu_si32(dst, shuffled_v_low); - dst[2] = _mm_extract_epi16(shuffled_v_low, 2); - break; - } - case 4: - _mm_storeu_si64(dst, shuffled_v_low); - break; - case 5: - _mm_storeu_si64(dst, shuffled_v_low); - dst[4] = _mm256_extract_epi16(shuffled_v, 8); - break; - case 6: - _mm_storeu_si64(dst, shuffled_v_low); - _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - case 7: { - _mm_storeu_si64(dst, shuffled_v_low); - __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 1); - _mm_storeu_si32(dst + 4, trailing_packed); - dst[6] = _mm_extract_epi16(trailing_packed, 2); - break; - } - case 8: - _mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); - break; - default: - RUY_DCHECK_LE(residual_rows, 8); - break; - } -} - -inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256 v) { - // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively - // truncating each 16-bit integer. 
- const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); - const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); - _mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0)); - _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); -} - -inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows, - const __m256 v) { - const __m128i v_low = _mm256_extracti128_si256(v, 0); - switch (residual_rows) { - case 0: - break; - case 1: - _mm_storeu_si32(dst, v_low); - break; - case 2: - _mm_storeu_si64(dst, v_low); - break; - case 3: { - __m128i trailing_packed = v_low; - _mm_storeu_si64(dst, trailing_packed); - dst[2] = _mm_extract_epi32(trailing_packed, 2); - break; - } - case 4: - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - break; - case 5: - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - dst[4] = _mm256_extract_epi32(v, 4); - break; - case 6: - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(v, 1)); - break; - case 7: { - _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); - __m128i trailing_packed = _mm256_extracti128_si256(v, 1); - _mm_storeu_si64(dst + 4, trailing_packed); - dst[6] = _mm_extract_epi32(trailing_packed, 2); - break; - } - case 8: - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); - break; - default: - RUY_DCHECK_LE(residual_rows, 8); - break; - } -} - -inline void mm256_storeu_epi32(std::int32_t* dst, const __m256 v) { - _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); -} - -inline float mm256_get1_ps(const __m256 a, int i) { - __m256i ai = _mm256_castps_si256(a); - int float_val_as_int; - switch (i) { - case 0: - float_val_as_int = _mm256_extract_epi32(ai, 0); - break; - case 1: - float_val_as_int = _mm256_extract_epi32(ai, 1); - break; - case 2: - float_val_as_int = _mm256_extract_epi32(ai, 2); - break; - case 3: - float_val_as_int = _mm256_extract_epi32(ai, 3); - break; - case 4: - float_val_as_int = _mm256_extract_epi32(ai, 4); - break; - case 5: - float_val_as_int = _mm256_extract_epi32(ai, 5); - break; - case 6: - float_val_as_int = _mm256_extract_epi32(ai, 6); - break; - case 7: - float_val_as_int = _mm256_extract_epi32(ai, 7); - break; - default: - RUY_DCHECK_LT(i, 8); - return .0f; - } - return reinterpret_cast(float_val_as_int); -} - -inline __m256 mm256_n_loadu_ps(int i, const float* src) { - switch (i) { - case 0: - return _mm256_setzero_ps(); - case 1: - return _mm256_setr_m128(_mm_setr_ps(src[0], .0f, .0f, .0f), - _mm_setzero_ps()); - case 2: - return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], .0f, .0f), - _mm_setzero_ps()); - case 3: - return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], .0f), - _mm_setzero_ps()); - case 4: - return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], src[3]), - _mm_setzero_ps()); - case 5: - return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], .0f, .0f, - .0f); - case 6: - return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5], .0f, - .0f); - case 7: - return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5], - src[6], .0f); - case 8: - return _mm256_loadu_ps(src); - default: - RUY_DCHECK_LT(i, 9); - return _mm256_setzero_ps(); - } -} - -inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) { - for (int i = 0; i < residual_rows; ++i) { - dst[i] = intrin_utils::mm256_get1_ps(v, i); - } -} -} // namespace intrin_utils -} // namespace - -void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { - 
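mm256_get1_ps above extracts one lane of a float vector by pulling out its 32-bit pattern from the integer view and reinterpreting the bits. A hypothetical scalar counterpart (a sketch using std::memcpy, the portable way to express the same type pun):

#include <cstdint>
#include <cstring>

// Hypothetical scalar counterpart of the lane extraction in mm256_get1_ps:
// recover a float from the raw 32-bit pattern of the selected lane.
float FloatFromBits(std::int32_t float_val_as_int) {
  float f;
  std::memcpy(&f, &float_val_as_int, sizeof(f));
  return f;
}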
profiler::ScopeLabel label("Kernel kAvx2 8-bit"); - const std::int8_t splitter_idx_data[32] = { - 0, 1, 4, 5, 8, 9, 12, 13, // - 2, 3, 6, 7, 10, 11, 14, 15, // - 0, 1, 4, 5, 8, 9, 12, 13, // - 2, 3, 6, 7, 10, 11, 14, 15 // - }; - - std::int32_t dst_stride; - if ((params.dst_type_id == DstTypeId::kValue) || - (params.dst_type_id == DstTypeId::kValue)) { - dst_stride = params.dst_stride; - } else if (params.dst_type_id == DstTypeId::kValue) { - dst_stride = params.dst_stride / sizeof(std::int16_t); - } else if (params.dst_type_id == DstTypeId::kValue) { - dst_stride = params.dst_stride / sizeof(std::int32_t); - } else { - RUY_DCHECK(false); - } - - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; - - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - void* dst_col_ptr = params.dst_base_ptr; - const std::int32_t* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; - col += kAvx8bitBlockSize) { - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - void* dst_ptr = dst_col_ptr; - const std::int32_t* bias_ptr = bias_col_ptr; - - const std::int32_t lhs_zero_point = params.lhs_zero_point; - const bool has_rhs_sums_offsets = - (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; - std::int32_t rhs_sums_offsets[8]; - if (has_rhs_sums_offsets) { - const __m256i rhs_sums_offset_v = _mm256_mullo_epi32( - _mm256_set1_epi32(lhs_zero_point), - _mm256_loadu_si256( - reinterpret_cast<__m256i const*>(¶ms.rhs_sums[col]))); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), - rhs_sums_offset_v); - } - - for (int row = params.start_row; row <= params.last_row; - row += kAvx8bitBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvx8bitBlockSize); - const int residual_cols = - std::min(params.dst_cols - col, kAvx8bitBlockSize); - - const __m256i splitter_idx = _mm256_loadu_si256( - reinterpret_cast<__m256i const*>(splitter_idx_data)); - - __m256i accum_data_v0; - __m256i accum_data_v1; - __m256i accum_data_v2; - __m256i accum_data_v3; - __m256i accum_data_v4; - __m256i accum_data_v5; - __m256i accum_data_v6; - __m256i accum_data_v7; - - // Initialize with bias. - __m256i initial_accum_data = - intrin_utils::mm256_n_loadu_epi32(residual_rows, bias_ptr); - bias_ptr += bias_ptr_block_increment; - - // Adjustments common across columns. - const std::int32_t rhs_zero_point = params.rhs_zero_point; - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { - const __m256i lhs_sums_offset = _mm256_mullo_epi32( - _mm256_set1_epi32(rhs_zero_point), - _mm256_loadu_si256( - reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); - initial_accum_data = - _mm256_sub_epi32(initial_accum_data, lhs_sums_offset); - } - const std::int32_t prod_zp_depth = params.prod_zp_depth; - if (prod_zp_depth) { - initial_accum_data = _mm256_add_epi32(initial_accum_data, - _mm256_set1_epi32(prod_zp_depth)); - } - - // Adjustments differing across columns. 
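The lhs_sums/rhs_sums/prod_zp_depth adjustments above, together with the per-column subtraction that follows, implement the usual expansion of a zero-point-offset dot product. A hypothetical scalar version for a single destination element, given as a sketch:

#include <cstdint>

// sum_d (lhs - lhs_zp) * (rhs - rhs_zp) expands into the raw int8 dot product
// plus three correction terms, which the kernel folds in as lhs_sums_offset,
// rhs_sums_offsets[col] and prod_zp_depth.
std::int32_t QuantizedAccum(const std::int8_t* lhs_row, const std::int8_t* rhs_col,
                            int depth, std::int32_t lhs_zero_point,
                            std::int32_t rhs_zero_point, std::int32_t bias,
                            std::int32_t lhs_row_sum, std::int32_t rhs_col_sum) {
  std::int32_t acc = bias;
  for (int d = 0; d < depth; ++d) {
    acc += static_cast<std::int32_t>(lhs_row[d]) * rhs_col[d];
  }
  acc -= rhs_zero_point * lhs_row_sum;             // lhs_sums_offset
  acc -= lhs_zero_point * rhs_col_sum;             // rhs_sums_offsets[col]
  acc += lhs_zero_point * rhs_zero_point * depth;  // prod_zp_depth
  return acc;
}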
- if (has_rhs_sums_offsets) { - accum_data_v0 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0])); - accum_data_v1 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1])); - accum_data_v2 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2])); - accum_data_v3 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3])); - accum_data_v4 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4])); - accum_data_v5 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5])); - accum_data_v6 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6])); - accum_data_v7 = _mm256_sub_epi32( - initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7])); - } else { - accum_data_v0 = initial_accum_data; - accum_data_v1 = initial_accum_data; - accum_data_v2 = initial_accum_data; - accum_data_v3 = initial_accum_data; - accum_data_v4 = initial_accum_data; - accum_data_v5 = initial_accum_data; - accum_data_v6 = initial_accum_data; - accum_data_v7 = initial_accum_data; - } - - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { - const __m256i lhs_data = - _mm256_load_si256(reinterpret_cast(lhs_ptr)); - const __m256i rhs_data_8bit = - _mm256_load_si256(reinterpret_cast(rhs_ptr)); - - // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. - std::int32_t rhs_data[16]; - const __m128i rhs_data_bottom_lane = - _mm256_castsi256_si128(rhs_data_8bit); - const __m128i rhs_data_top_lane = - _mm256_extracti128_si256(rhs_data_8bit, 1); - const __m256i rhs_16_bit_dup_low = - _mm256_cvtepi8_epi16(rhs_data_bottom_lane); - const __m256i rhs_16_bit_dup_high = - _mm256_cvtepi8_epi16(rhs_data_top_lane); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data), - rhs_16_bit_dup_low); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8), - rhs_16_bit_dup_high); - - // NOTE: There may be opportunities for permuting the data in the - // packing code instead of here. - const __m256i lhs_data_split = - _mm256_shuffle_epi8(lhs_data, splitter_idx); - const __m256i lhs_data_split_expand_bottom = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0)); - const __m256i lhs_data_split_expand_top = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1)); - - // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. - const __m256i lhs_16_bit_low = _mm256_permute2x128_si256( - lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); - // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. - const __m256i lhs_16_bit_high = _mm256_permute2x128_si256( - lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); - // Accumulate for column 0. - { - const std::int32_t low_rhs_value = rhs_data[0]; - const std::int32_t high_rhs_value = rhs_data[1]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v0 = _mm256_add_epi32( - accum_data_v0, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v0 = _mm256_add_epi32( - accum_data_v0, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 1. 
- { - const std::int32_t low_rhs_value = rhs_data[2]; - const std::int32_t high_rhs_value = rhs_data[3]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v1 = _mm256_add_epi32( - accum_data_v1, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v1 = _mm256_add_epi32( - accum_data_v1, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 2. - { - const std::int32_t low_rhs_value = rhs_data[4]; - const std::int32_t high_rhs_value = rhs_data[5]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v2 = _mm256_add_epi32( - accum_data_v2, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v2 = _mm256_add_epi32( - accum_data_v2, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 3. - { - const std::int32_t low_rhs_value = rhs_data[6]; - const std::int32_t high_rhs_value = rhs_data[7]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v3 = _mm256_add_epi32( - accum_data_v3, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v3 = _mm256_add_epi32( - accum_data_v3, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 4. - { - const std::int32_t low_rhs_value = rhs_data[8]; - const std::int32_t high_rhs_value = rhs_data[9]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v4 = _mm256_add_epi32( - accum_data_v4, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v4 = _mm256_add_epi32( - accum_data_v4, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 5. - { - const std::int32_t low_rhs_value = rhs_data[10]; - const std::int32_t high_rhs_value = rhs_data[11]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v5 = _mm256_add_epi32( - accum_data_v5, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v5 = _mm256_add_epi32( - accum_data_v5, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 6. - { - const std::int32_t low_rhs_value = rhs_data[12]; - const std::int32_t high_rhs_value = rhs_data[13]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v6 = _mm256_add_epi32( - accum_data_v6, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v6 = _mm256_add_epi32( - accum_data_v6, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - // Accumulate for column 7. 
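Each of the per-column blocks above (and the column-7 block that follows) issues two _mm256_madd_epi16 calls per group of 4 depth levels: the splitter shuffle pairs depth levels (0,1) and (2,3), and each madd sums two widened 8-bit products into the 32-bit accumulator. A scalar model of one such depth group for a single row/column pair, as a hedged sketch:

#include <cstdint>

// Hypothetical scalar model of one kAvx8bitInnerSize == 4 depth group.
std::int32_t DepthGroupOf4(const std::int8_t lhs[4], const std::int8_t rhs[4],
                           std::int32_t acc) {
  acc += static_cast<std::int32_t>(lhs[0]) * rhs[0] +
         static_cast<std::int32_t>(lhs[1]) * rhs[1];  // madd on the "low" pair
  acc += static_cast<std::int32_t>(lhs[2]) * rhs[2] +
         static_cast<std::int32_t>(lhs[3]) * rhs[3];  // madd on the "high" pair
  return acc;
}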
- { - const std::int32_t low_rhs_value = rhs_data[14]; - const std::int32_t high_rhs_value = rhs_data[15]; - - const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); - const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); - - accum_data_v7 = _mm256_add_epi32( - accum_data_v7, - _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_data_v7 = _mm256_add_epi32( - accum_data_v7, - _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - } - - lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - } - - if (params.dst_type_id != DstTypeId::kValue) { - __m256i m_vector; - __m256i e_vector; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. - if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - m_vector = intrin_utils::mm256_n_loadu_epi32( - residual_rows, ¶ms.multiplier_fixedpoint[row]); - e_vector = intrin_utils::mm256_n_loadu_epi32( - residual_rows, ¶ms.multiplier_exponent[row]); - } else { - // These arrays have size LhsCols, and are pre-filled. - m_vector = _mm256_set1_epi32(params.multiplier_fixedpoint[0]); - e_vector = _mm256_set1_epi32(params.multiplier_exponent[0]); - } - - const __m256i m_64bit_low = - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0)); - const __m256i m_64bit_high = - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1)); - - const __m256i zero_vector = _mm256_setzero_si256(); - const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector); - const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector); - const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector); - const __m256i final_right_shift = - _mm256_add_epi32(right_shift, _mm256_set1_epi32(31)); - const __m256i final_right_shift_low = _mm256_cvtepi32_epi64( - _mm256_extracti128_si256(final_right_shift, 0)); - const __m256i final_right_shift_high = _mm256_cvtepi32_epi64( - _mm256_extracti128_si256(final_right_shift, 1)); - // Really we want 0x100000000, but use half to avoid overflowing. - const __m256i convert_to_signed_halved = - _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift); - const __m256i convert_to_unsigned_64 = - _mm256_set1_epi64x(0x8000000000000000); - - __m256i post_scaling_offset = _mm256_add_epi32( - convert_to_signed_halved, convert_to_signed_halved); - - const __m256i offset_vector = - _mm256_slli_epi64(_mm256_set1_epi64x(1), 30); - // Really these should be shifted by neg_e_vector, but tests pass when - // using right_shift. - const __m256i offset_vector_low = _mm256_add_epi64( - _mm256_sllv_epi64(offset_vector, - _mm256_cvtepi32_epi64( - _mm256_extracti128_si256(right_shift, 0))), - convert_to_unsigned_64); - const __m256i offset_vector_high = _mm256_add_epi64( - _mm256_sllv_epi64(offset_vector, - _mm256_cvtepi32_epi64( - _mm256_extracti128_si256(right_shift, 1))), - convert_to_unsigned_64); - - if (params.dst_zero_point) { - const __m256i dst_zero_point = - _mm256_set1_epi32(params.dst_zero_point); - // The post-scaling offset is subtracted later, so this has the effect - // of adding the zero point. 
- post_scaling_offset = - _mm256_sub_epi32(post_scaling_offset, dst_zero_point); - } - -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - RUY_DCHECK(false); -#endif - const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); - - // We cannot do - // - // scaled_v_low = - // _mm256_srav_epi64(scaled_v_low, final_right_shift_low); - // scaled_v_high = - // _mm256_srav_epi64(scaled_v_high, final_right_shift_high); - // - // since this instruction is not in AVX2. Instead we use - // _mm256_srlv_epi64, but this is an unsigned shift, so we applied - // offsets before (convert_to_unsigned_64) and after - // (convert_to_signed_halved). - // - // The overall process is, for 64-bit scaled accumulator: - // unsigned_accum = signed_accum + 1 << 63; - // unsigned_accum = (unsigned_accum >> right_shift) >> 31; - // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2; - - // There are various ways to repack the results, in the absence of - // _mm256_cvtepi64_epi32() or anything like it. - // A. - // accum_data_v[j] = - // _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6), - // _mm256_extract_epi32(scaled_v_high, 4), - // _mm256_extract_epi32(scaled_v_high, 2), - // _mm256_extract_epi32(scaled_v_high, 0), - // _mm256_extract_epi32(scaled_v_low, 6), - // _mm256_extract_epi32(scaled_v_low, 4), - // _mm256_extract_epi32(scaled_v_low, 2), - // _mm256_extract_epi32(scaled_v_low, 0)); - // B. - // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8); - // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8); - // accum_data_v[j] = - // _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2), - // _mm256_extract_epi64(scaled_v_high, 0), - // _mm256_extract_epi64(scaled_v_low, 2), - // _mm256_extract_epi64(scaled_v_low, 0)); - // C. - // scaled_v_low = - // _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm); - // scaled_v_high = - // _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm); - // accum_data_v[j] = - // _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20); - // - // However, we choose the following because it uses two lighter - // instructions. The permutation does have a longer latency, but this - // loop can be unrolled. - // D. - // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - // __m256i results = - // _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - // results = _mm256_permutevar8x32_epi32(results, repack_perm); - // accum_data_v[j] = _mm256_sub_epi32(results, post_scaling_offset); - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift); - // Apply the fixed-point part of the multiplier. 
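Stripped of the unsigned-shift workaround described above, the per-accumulator requantization reduces to a Q31 fixed-point multiply plus a rounded shift by the exponent. A hypothetical scalar reference, ignoring 32-bit overflow corner cases and assuming arithmetic right shift on signed values:

#include <cstdint>

// Left-shift by max(exponent, 0), multiply by the Q31 fixed-point multiplier,
// add the rounding bias ("offset_vector"), shift right by 31 + max(-exponent, 0)
// ("final_right_shift"), then re-add the destination zero point.
std::int32_t Requantize(std::int32_t acc, std::int32_t multiplier_fixedpoint,
                        int multiplier_exponent, std::int32_t dst_zero_point) {
  const int left_shift = multiplier_exponent > 0 ? multiplier_exponent : 0;
  const int right_shift = multiplier_exponent > 0 ? 0 : -multiplier_exponent;
  std::int64_t x = static_cast<std::int64_t>(acc) << left_shift;
  x *= multiplier_fixedpoint;                  // 64-bit product
  x += std::int64_t{1} << (30 + right_shift);  // rounding bias
  x >>= (31 + right_shift);
  return static_cast<std::int32_t>(x) + dst_zero_point;
}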
- __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v1, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v1 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v2, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v2 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v3, left_shift); - // Apply the fixed-point part of the multiplier. 
- __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v3 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v4, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v4 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v5, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v5 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v6, left_shift); - // Apply the fixed-point part of the multiplier. 
- __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v6 = _mm256_sub_epi32(results, post_scaling_offset); - } - { - __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v7, left_shift); - // Apply the fixed-point part of the multiplier. - __m256i scaled_v_low = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), - m_64bit_low); - __m256i scaled_v_high = _mm256_mul_epi32( - _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), - m_64bit_high); - - scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); - - scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); - __m256i results = - _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); - results = _mm256_permutevar8x32_epi32(results, repack_perm); - - accum_data_v7 = _mm256_sub_epi32(results, post_scaling_offset); - } - } - const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); - const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); - const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && - (residual_cols == kAvx8bitBlockSize); - - __m256i accum_data_v[kAvx8bitBlockSize]; - if (!store_full_block) { - accum_data_v[0] = accum_data_v0; - accum_data_v[1] = accum_data_v1; - accum_data_v[2] = accum_data_v2; - accum_data_v[3] = accum_data_v3; - accum_data_v[4] = accum_data_v4; - accum_data_v[5] = accum_data_v5; - accum_data_v[6] = accum_data_v6; - accum_data_v[7] = accum_data_v7; - } - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = static_cast(dst_ptr); - if (store_full_block) { - accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); - accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); - accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); - accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); - accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); - accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); - accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); - accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); - 
intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[0 * dst_stride], - accum_data_v0); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[1 * dst_stride], - accum_data_v1); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[4 * dst_stride], - accum_data_v4); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[7 * dst_stride], - accum_data_v7); - } else { - for (int j = 0; j < residual_cols; ++j) { - __m256 result = accum_data_v[j]; - result = _mm256_min_epi32(result, clamp_max_v); - result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, - result); - tmp_ptr += dst_stride; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = static_cast(dst_ptr); - if (store_full_block) { - accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); - accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); - accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); - accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); - accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); - accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); - accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); - accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[0], accum_data_v0); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[dst_stride], - accum_data_v1); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[4 * dst_stride], - accum_data_v4); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[7 * dst_stride], - accum_data_v7); - } else { - for (int j = 0; j < residual_cols; ++j) { - __m256 result = accum_data_v[j]; - result = _mm256_min_epi32(result, clamp_max_v); - result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, - result); - tmp_ptr += dst_stride; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - if (store_full_block) { - accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); - accum_data_v1 = 
_mm256_min_epi32(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); - accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); - accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); - accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); - accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); - accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); - accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[0], accum_data_v0); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[dst_stride], - accum_data_v1); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[4 * dst_stride], - accum_data_v4); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[7 * dst_stride], - accum_data_v7); - } else { - for (int j = 0; j < residual_cols; ++j) { - __m256 result = accum_data_v[j]; - result = _mm256_min_epi32(result, clamp_max_v); - result = _mm256_max_epi32(result, clamp_min_v); - intrin_utils::mm256_n_storeu_cvtepi32_epi16(tmp_ptr, residual_rows, - result); - tmp_ptr += dst_stride; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[0], accum_data_v0); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[dst_stride], accum_data_v1); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[2 * dst_stride], - accum_data_v2); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[3 * dst_stride], - accum_data_v3); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[4 * dst_stride], - accum_data_v4); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[5 * dst_stride], - accum_data_v5); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[6 * dst_stride], - accum_data_v6); - intrin_utils::mm256_storeu_epi32(&tmp_ptr[7 * dst_stride], - accum_data_v7); - } else { - std::int32_t* dst_block_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - intrin_utils::mm256_n_storeu_epi32(dst_block_ptr, residual_rows, - accum_data_v[j]); - dst_block_ptr += dst_stride; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; - } // End row-block loop. - - dst_col_ptr = static_cast(static_cast(dst_col_ptr) + - kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; - } // End col-block loop. 
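The four dst_type_id branches above differ only in the narrowing cast (and the raw int32 case, which stores the accumulators without clamping); the arithmetic is otherwise identical. A hypothetical scalar outline of one stored column, offered as a sketch:

#include <algorithm>
#include <cstdint>

// Same clamp for every destination type, then a narrowing cast; used here
// with DstScalar = std::int8_t, std::uint8_t or std::int16_t.
template <typename DstScalar>
void ClampAndStoreColumn(const std::int32_t accum[8], int residual_rows,
                         std::int32_t clamp_min, std::int32_t clamp_max,
                         DstScalar* dst) {
  for (int r = 0; r < residual_rows; ++r) {
    const std::int32_t clamped =
        std::min(std::max(accum[r], clamp_min), clamp_max);
    dst[r] = static_cast<DstScalar>(clamped);
  }
}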
-}  // NOLINT(readability/fn_size)
-
-void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
-  profiler::ScopeLabel label("Kernel kAvx2 8-bit GEMV");
-
-  RUY_DCHECK_EQ(params.dst_cols, 1);
-  RUY_DCHECK_EQ(params.last_col, 0);
-  RUY_DCHECK_EQ(params.start_col, 0);
-
-  const std::int8_t splitter_idx_data[32] = {
-      0, 1, 4, 5, 8, 9, 12, 13,  //
-      2, 3, 6, 7, 10, 11, 14, 15,  //
-      0, 1, 4, 5, 8, 9, 12, 13,  //
-      2, 3, 6, 7, 10, 11, 14, 15  //
-  };
-
-  int bias_ptr_block_increment =
-      params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
-
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
-  void* dst_col_ptr = params.dst_base_ptr;
-  const std::int32_t* bias_col_ptr = params.bias;
-  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
-    bias_col_ptr += params.start_row;
-  }
-
-  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
-  void* dst_ptr = dst_col_ptr;
-  const std::int32_t* bias_ptr = bias_col_ptr;
-
-  const std::int32_t lhs_zero_point = params.lhs_zero_point;
-  const bool has_rhs_sums_offsets =
-      (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
-  std::int32_t rhs_sums_offsets[8];
-  if (has_rhs_sums_offsets) {
-    const __m256i rhs_sums_offset_v = _mm256_mullo_epi32(
-        _mm256_set1_epi32(lhs_zero_point),
-        _mm256_loadu_si256(
-            reinterpret_cast<__m256i const*>(&params.rhs_sums[0])));
-    _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets),
-                        rhs_sums_offset_v);
-  }
-
-  for (int row = params.start_row; row <= params.last_row;
-       row += kAvx8bitBlockSize) {
-    const int residual_rows =
-        std::min(params.dst_rows - row, kAvx8bitBlockSize);
-
-    const __m256i splitter_idx =
-        _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data));
-
-    __m256i accum_data_v0;
-
-    // Initialize with bias.
-    __m256i initial_accum_data =
-        intrin_utils::mm256_n_loadu_epi32(residual_rows, bias_ptr);
-    bias_ptr += bias_ptr_block_increment;
-
-    // Adjustments common across columns.
-    const std::int32_t rhs_zero_point = params.rhs_zero_point;
-    if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) {
-      const __m256i lhs_sums_offset = _mm256_mullo_epi32(
-          _mm256_set1_epi32(rhs_zero_point),
-          _mm256_loadu_si256(
-              reinterpret_cast<__m256i const*>(&params.lhs_sums[row])));
-      initial_accum_data =
-          _mm256_sub_epi32(initial_accum_data, lhs_sums_offset);
-    }
-    const std::int32_t prod_zp_depth = params.prod_zp_depth;
-    if (prod_zp_depth) {
-      initial_accum_data = _mm256_add_epi32(initial_accum_data,
-                                            _mm256_set1_epi32(prod_zp_depth));
-    }
-
-    // Adjustments differing across columns.
-    if (has_rhs_sums_offsets) {
-      accum_data_v0 = _mm256_sub_epi32(initial_accum_data,
-                                       _mm256_set1_epi32(rhs_sums_offsets[0]));
-    } else {
-      accum_data_v0 = initial_accum_data;
-    }
-
-    const std::int8_t* lhs_ptr = lhs_col_ptr;
-    const std::int8_t* rhs_ptr = rhs_col_ptr;
-    for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
-      const __m256i lhs_data =
-          _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
-      const __m128i rhs_data_8bit = _mm_loadu_si32(rhs_ptr);
-
-      // Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
-      // For simplicity we load 4x the data that we need and process twice the
-      // data that we need and store only the data we need.
-      std::int32_t rhs_data[2];
-      const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
-      // Now that we have cast the RHS data, we store it so that each value
-      // can be separately loaded in the accumulation loop.
-      _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
-
-      // NOTE: There may be opportunities for permuting the data in the packing
-      // code instead of here.
-      const __m256i lhs_data_split =
-          _mm256_shuffle_epi8(lhs_data, splitter_idx);
-      const __m256i lhs_data_split_expand_bottom =
-          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0));
-      const __m256i lhs_data_split_expand_top =
-          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1));
-
-      // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit.
-      const __m256i lhs_16_bit_low = _mm256_permute2x128_si256(
-          lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20);
-      // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
-      const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
-          lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
-      // Accumulate for column 0.
-      const std::int32_t low_rhs_value = rhs_data[0];
-      const std::int32_t high_rhs_value = rhs_data[1];
-
-      const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
-      const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
-      accum_data_v0 = _mm256_add_epi32(
-          accum_data_v0, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-      accum_data_v0 = _mm256_add_epi32(
-          accum_data_v0,
-          _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-
-      lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
-      rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
-    }
-
-    if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
-      __m256i m_vector;
-      __m256i e_vector;
-      // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
-      if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
-        m_vector = intrin_utils::mm256_n_loadu_epi32(
-            residual_rows, &params.multiplier_fixedpoint[row]);
-        e_vector = intrin_utils::mm256_n_loadu_epi32(
-            residual_rows, &params.multiplier_exponent[row]);
-      } else {
-        // These arrays have size LhsCols, and are pre-filled.
-        m_vector = _mm256_set1_epi32(params.multiplier_fixedpoint[0]);
-        e_vector = _mm256_set1_epi32(params.multiplier_exponent[0]);
-      }
-
-      const __m256i m_64bit_low =
-          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0));
-      const __m256i m_64bit_high =
-          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1));
-
-      const __m256i zero_vector = _mm256_setzero_si256();
-      const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
-      const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
-      const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
-      const __m256i final_right_shift =
-          _mm256_add_epi32(right_shift, _mm256_set1_epi32(31));
-      const __m256i final_right_shift_low =
-          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0));
-      const __m256i final_right_shift_high =
-          _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1));
-      // Really we want 0x100000000, but use half to avoid overflowing.
-      const __m256i convert_to_signed_halved =
-          _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift);
-      const __m256i convert_to_unsigned_64 =
-          _mm256_set1_epi64x(0x8000000000000000);
-
-      __m256i post_scaling_offset =
-          _mm256_add_epi32(convert_to_signed_halved, convert_to_signed_halved);
-
-      const __m256i offset_vector =
-          _mm256_slli_epi64(_mm256_set1_epi64x(1), 30);
-      // Really these should be shifted by neg_e_vector, but tests pass when
-      // using right_shift.
-      const __m256i offset_vector_low = _mm256_add_epi64(
-          _mm256_sllv_epi64(
-              offset_vector,
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 0))),
-          convert_to_unsigned_64);
-      const __m256i offset_vector_high = _mm256_add_epi64(
-          _mm256_sllv_epi64(
-              offset_vector,
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 1))),
-          convert_to_unsigned_64);
-
-      if (params.dst_zero_point) {
-        const __m256i dst_zero_point = _mm256_set1_epi32(params.dst_zero_point);
-        // The post-scaling offset is subtracted later, so this has the effect
-        // of adding the zero point.
-        post_scaling_offset =
-            _mm256_sub_epi32(post_scaling_offset, dst_zero_point);
-      }
-
-#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
-      RUY_DCHECK(false);
-#endif
-      const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
-
-      // See GEMM version for details of this process.
-      {
-        __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift);
-        // Apply the fixed-point part of the multiplier.
-        __m256i scaled_v_low = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
-            m_64bit_low);
-        __m256i scaled_v_high = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
-            m_64bit_high);
-
-        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
-        scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
-        scaled_v_high =
-            _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
-        scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
-        __m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
-        results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-        accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset);
-      }
-    }
-    const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
-    const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
-
-    if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
-      std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
-      __m256i result = accum_data_v0;
-      result = _mm256_min_epi32(result, clamp_max_v);
-      result = _mm256_max_epi32(result, clamp_min_v);
-      intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows,
-                                                 result);
-      dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
-                                   kAvx8bitBlockSize);
-    } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
-      std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
-      __m256i result = accum_data_v0;
-      result = _mm256_min_epi32(result, clamp_max_v);
-      result = _mm256_max_epi32(result, clamp_min_v);
-      intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows,
-                                                 result);
-      dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
-                                   kAvx8bitBlockSize);
-    } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
-      std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
-      __m256i result = accum_data_v0;
-      result = _mm256_min_epi32(result, clamp_max_v);
-      result = _mm256_max_epi32(result, clamp_min_v);
-      intrin_utils::mm256_n_storeu_cvtepi32_epi16(tmp_ptr, residual_rows,
-                                                  result);
-      dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
-                                   kAvx8bitBlockSize);
-    } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
-      std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
-      intrin_utils::mm256_n_storeu_epi32(dst_block_ptr, residual_rows,
-                                         accum_data_v0);
-      dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
-                                   kAvx8bitBlockSize);
-    } else {
-      RUY_DCHECK(false);
-    }
-
-    lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
-  }  // End row-block loop.
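// Aside: a scalar sketch of the requantization that the vector code above
// performs on each 32-bit accumulator lane. The names below are illustrative
// only (they do not exist in this file); the sketch assumes the kernel's
// convention that multiplier_fixedpoint is a Q31 value and
// multiplier_exponent is a signed power-of-two exponent. The AVX2 code above
// reaches the same result through the unsigned-shift / post_scaling_offset
// trick; the AVX512 kernel later in this patch uses an arithmetic shift
// directly, which is what this sketch does.
std::int32_t RescaleLaneSketch(std::int32_t accum, std::int32_t m_fixedpoint,
                               int exponent) {
  const int left_shift = exponent > 0 ? exponent : 0;
  const int right_shift = exponent > 0 ? 0 : -exponent;
  // The kernel performs the left shift in 32-bit lanes before widening.
  const std::int64_t shifted = static_cast<std::int64_t>(accum) << left_shift;
  std::int64_t scaled = shifted * static_cast<std::int64_t>(m_fixedpoint);
  // Rounding offset, mirroring offset_vector = (1 << 30) << right_shift.
  scaled += std::int64_t{1} << (30 + right_shift);
  // Mirrors final_right_shift = right_shift + 31.
  return static_cast<std::int32_t>(scaled >> (31 + right_shift));
}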
- - dst_col_ptr = static_cast(static_cast(dst_col_ptr) + - kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; -} // NOLINT(readability/fn_size) - -void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label("Kernel kAvx2 float"); - - // As parameters are defined, we need to scale by sizeof(float). - const std::int64_t lhs_stride = params.lhs_stride >> 2; - const std::int64_t dst_stride = params.dst_stride >> 2; - const std::int64_t rhs_stride = params.rhs_stride >> 2; - // - int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; - // AVX2 float block size = 8. - const int end_row = std::min(params.dst_rows, params.last_row + 8); - const int end_col = std::min(params.dst_cols, params.last_col + 8); - // - const float* adj_rhs_col_ptr = - params.rhs_base_ptr - params.start_col * rhs_stride; - float* adj_dst_col_ptr = - params.dst_base_ptr - params.start_col * dst_stride - params.start_row; - const float* adj_lhs_col_ptr = - params.lhs_base_ptr - params.start_row * lhs_stride; - const float* bias_col_ptr = params.bias; - - const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max); - const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min); - - int col = params.start_col; - // Loop through cols by float block size, leaving incomplete remainder - for (; col <= end_col - 8; col += 8) { - __m256 accum_data_v[8]; - - const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; - float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; - - for (int row = params.start_row; row < end_row; row += 8) { - const int residual_rows = std::min(end_row - row, 8); - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __m256 initial_accum_data = - intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr); - - for (int j = 0; j < 8; ++j) { - accum_data_v[j] = initial_accum_data; - } - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); - // In this version RHS values are loaded individually rather than first - // loading together and then extract with broadcasting. This is because - // AVX flavours and instrinsics and compilers in combination do not - // handle this pattern of extraction very well. - const float* rhs_data = rhs_ptr; - - for (int j = 0; j < 8; ++j) { - const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]); - accum_data_v[j] = - _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); - } - lhs_ptr += 8; - rhs_ptr += 8; - } - - if (residual_rows == 8) { - for (int j = 0; j < 8; ++j) { - float* block_ptr = dst_ptr + j * dst_stride; - accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); - _mm256_storeu_ps(block_ptr, accum_data_v[j]); - } - } else { - for (int j = 0; j < 8; ++j) { - float* block_ptr = dst_ptr + j * dst_stride; - accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); - intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows, - accum_data_v[j]); - } - } - } // End row-block loop. - } // End col-block loop. - - if (col < end_col) { - // Remaining cols in [0, float block size). 
- RUY_DCHECK_GE(end_col - col, 0); - RUY_DCHECK_LT(end_col - col, 8); - - __m256 accum_data_v[8]; - - const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; - float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; - const int residual_cols = std::min(end_col - col, 8); - - for (int row = params.start_row; row < end_row; row += 8) { - const int residual_rows = std::min(end_row - row, 8); - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __m256 initial_accum_data = - intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr); - - for (int j = 0; j < 8; ++j) { - accum_data_v[j] = initial_accum_data; - } - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - - for (int j = 0; j < 8; ++j) { - const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]); - accum_data_v[j] = - _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); - } - lhs_ptr += 8; - rhs_ptr += 8; - } - - for (int j = 0; j < residual_cols; ++j) { - float* block_ptr = dst_ptr + j * dst_stride; - accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); - intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows, - accum_data_v[j]); - } - } // End row-block loop. - } // End col-block terminal conditional. -} - -void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label("Kernel kAvx2 float GEMV"); - - RUY_DCHECK_EQ(params.dst_cols, 1); - RUY_DCHECK_EQ(params.last_col, 0); - RUY_DCHECK_EQ(params.start_col, 0); - - // As parameters are defined, we need to scale by sizeof(float). - const std::int64_t lhs_stride = params.lhs_stride >> 2; - // - int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; - // AVX2 float block size = 8. - const int end_row = std::min(params.dst_rows, params.last_row + 8); - - float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row; - const float* adj_lhs_col_ptr = - params.lhs_base_ptr - params.start_row * lhs_stride; - const float* bias_col_ptr = params.bias; - - const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max); - const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min); - - __m256 accum_data_v; - - const float* rhs_col_ptr = params.rhs_base_ptr; - float* dst_col_ptr = adj_dst_col_ptr; - - int row = params.start_row; - for (; row <= end_row - 8; row += 8) { - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. 
- accum_data_v = _mm256_loadu_ps(bias_ptr); - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - int d = 0; - for (; d <= params.depth - 4; d += 4) { - const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr); - const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]); - accum_data_v = - _mm256_fmadd_ps(lhs_data_0, dup_rhs_element_0, accum_data_v); - const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]); - const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8); - accum_data_v = - _mm256_fmadd_ps(lhs_data_1, dup_rhs_element_1, accum_data_v); - - const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16); - const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]); - accum_data_v = - _mm256_fmadd_ps(lhs_data_2, dup_rhs_element_2, accum_data_v); - const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]); - const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24); - accum_data_v = - _mm256_fmadd_ps(lhs_data_3, dup_rhs_element_3, accum_data_v); - lhs_ptr += 32; // Loaded 8 * 4 floats. - rhs_ptr += 32; - } - for (; d < params.depth; ++d) { - const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - - const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]); - accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); - lhs_ptr += 8; - rhs_ptr += 8; - } - - accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v); - accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v); - _mm256_storeu_ps(dst_ptr, accum_data_v); - } // End row-block loop. - - if (row < end_row) { - const int residual_rows = end_row - row; - RUY_CHECK_GE(residual_rows, 1); - RUY_CHECK_LT(residual_rows, 8); - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - accum_data_v = intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr); - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - - const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]); - accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); - lhs_ptr += 8; - rhs_ptr += 8; - } - - accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v); - accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v); - intrin_utils::mm256_n_storeu_ps(dst_ptr, residual_rows, accum_data_v); - } // End handling of residual rows. -} - -#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/kernel_avx512.cc b/kernel_avx512.cc deleted file mode 100644 index 202b347..0000000 --- a/kernel_avx512.cc +++ /dev/null @@ -1,1820 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/
-
-#include <algorithm>
-#include <cstdint>
-
-#include "check_macros.h"
-#include "kernel.h"
-#include "opt_set.h"
-#include "platform.h"
-#include "profiler/instrumentation.h"
-
-#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
-#include <immintrin.h>  // IWYU pragma: keep
-#endif
-
-namespace ruy {
-
-#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM))
-
-void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
-  // CPU-ID-based checks should disable the path that would reach this point.
-  RUY_DCHECK(false);
-}
-
-void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
-  // CPU-ID-based checks should disable the path that would reach this point.
-  RUY_DCHECK(false);
-}
-
-void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
-  // CPU-ID-based checks should disable the path that would reach this point.
-  RUY_DCHECK(false);
-}
-
-void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) {
-  // CPU-ID-based checks should disable the path that would reach this point.
-  RUY_DCHECK(false);
-}
-
-#else  // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
-
-void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
-  profiler::ScopeLabel label("Kernel kAvx512 8-bit");
-
-  std::int32_t dst_stride;
-  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
-      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
-    dst_stride = params.dst_stride;
-  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
-    dst_stride = params.dst_stride / sizeof(std::int16_t);
-  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
-    dst_stride = params.dst_stride / sizeof(std::int32_t);
-  } else {
-    RUY_DCHECK(false);
-  }
-
-  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
-
-  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
-  void* dst_col_ptr = params.dst_base_ptr;
-  const std::int32_t* bias_col_ptr = params.bias;
-  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
-    bias_col_ptr += params.start_row;
-  }
-
-  for (int col = params.start_col; col <= params.last_col; col += 16) {
-    const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
-    void* dst_ptr = dst_col_ptr;
-    const std::int32_t* bias_ptr = bias_col_ptr;
-
-    const std::int32_t lhs_zero_point = params.lhs_zero_point;
-    const bool has_rhs_sums_offsets =
-        (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
-    std::int32_t rhs_sums_offsets[16];
-    if (has_rhs_sums_offsets) {
-      const __m512i rhs_sums_offset_v =
-          _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
-                             _mm512_loadu_epi32(&params.rhs_sums[col]));
-      _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
-                          rhs_sums_offset_v);
-    }
-
-    for (int row = params.start_row; row <= params.last_row; row += 16) {
-      const int residual_rows = std::min(params.dst_rows - row, 16);
-      const int residual_cols = std::min(params.dst_cols - col, 16);
-
-      __m512i accum_data_v0;
-      __m512i accum_data_v1;
-      __m512i accum_data_v2;
-      __m512i accum_data_v3;
-      __m512i accum_data_v4;
-      __m512i accum_data_v5;
-      __m512i accum_data_v6;
-      __m512i accum_data_v7;
-      __m512i accum_data_v8;
-      __m512i accum_data_v9;
-      __m512i accum_data_va;
-      __m512i accum_data_vb;
-      __m512i accum_data_vc;
-      __m512i accum_data_vd;
-      __m512i accum_data_ve;
-      __m512i accum_data_vf;
-
-      // Initialize with bias.
- const __mmask16 row_mask = - (static_cast(1) << residual_rows) - 1; - __m512i initial_accum_data = _mm512_maskz_loadu_epi32(row_mask, bias_ptr); - bias_ptr += bias_ptr_block_increment; - - const std::int32_t rhs_zero_point = params.rhs_zero_point; - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { - const __m512i lhs_sums_offset = - _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point), - _mm512_loadu_epi32(¶ms.lhs_sums[row])); - initial_accum_data = - _mm512_sub_epi32(initial_accum_data, lhs_sums_offset); - } - - const std::int32_t prod_zp_depth = params.prod_zp_depth; - if (prod_zp_depth != 0) { - initial_accum_data = _mm512_add_epi32(initial_accum_data, - _mm512_set1_epi32(prod_zp_depth)); - } - - // Adjustments differing across columns. - if (has_rhs_sums_offsets) { - accum_data_v0 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[0])); - accum_data_v1 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[1])); - accum_data_v2 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[2])); - accum_data_v3 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[3])); - accum_data_v4 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[4])); - accum_data_v5 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[5])); - accum_data_v6 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[6])); - accum_data_v7 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[7])); - accum_data_v8 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[8])); - accum_data_v9 = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[9])); - accum_data_va = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[10])); - accum_data_vb = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[11])); - accum_data_vc = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[12])); - accum_data_vd = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[13])); - accum_data_ve = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[14])); - accum_data_vf = _mm512_sub_epi32( - initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[15])); - } else { - accum_data_v0 = initial_accum_data; - accum_data_v1 = initial_accum_data; - accum_data_v2 = initial_accum_data; - accum_data_v3 = initial_accum_data; - accum_data_v4 = initial_accum_data; - accum_data_v5 = initial_accum_data; - accum_data_v6 = initial_accum_data; - accum_data_v7 = initial_accum_data; - accum_data_v8 = initial_accum_data; - accum_data_v9 = initial_accum_data; - accum_data_va = initial_accum_data; - accum_data_vb = initial_accum_data; - accum_data_vc = initial_accum_data; - accum_data_vd = initial_accum_data; - accum_data_ve = initial_accum_data; - accum_data_vf = initial_accum_data; - } - - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += 4) { - const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr); - __m512i rhs_data_8bit = _mm512_loadu_epi8(rhs_ptr); - - // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. 
- std::int32_t rhs_data[32]; - const __m256i rhs_data_bottom_lane = - _mm512_castsi512_si256(rhs_data_8bit); - const __m256i rhs_data_top_lane = - _mm512_extracti32x8_epi32(rhs_data_8bit, 1); - const __m512i rhs_16_bit_dup_low = - _mm512_cvtepi8_epi16(rhs_data_bottom_lane); - const __m512i rhs_16_bit_dup_high = - _mm512_cvtepi8_epi16(rhs_data_top_lane); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data), - rhs_16_bit_dup_low); - _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data + 16), - rhs_16_bit_dup_high); - - // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. - const __m512i lhs_16_bit_low = - _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data)); - // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit. - const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16( - _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16))); - - // Process column 0. - { - __m512i accum_v = accum_data_v0; - constexpr int index = 0; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v0 = accum_v; - } - // Process column 1. - { - __m512i accum_v = accum_data_v1; - constexpr int index = 2; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v1 = accum_v; - } - // Process column 2. - { - __m512i accum_v = accum_data_v2; - constexpr int index = 4; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v2 = accum_v; - } - // Process column 3. - { - __m512i accum_v = accum_data_v3; - constexpr int index = 6; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v3 = accum_v; - } - // Process column 4. - { - __m512i accum_v = accum_data_v4; - constexpr int index = 8; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v4 = accum_v; - } - // Process column 5. 
- { - __m512i accum_v = accum_data_v5; - constexpr int index = 10; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v5 = accum_v; - } - // Process column 6. - { - __m512i accum_v = accum_data_v6; - constexpr int index = 12; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v6 = accum_v; - } - // Process column 7. - { - __m512i accum_v = accum_data_v7; - constexpr int index = 14; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v7 = accum_v; - } - // Process column 8. - { - __m512i accum_v = accum_data_v8; - constexpr int index = 16; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v8 = accum_v; - } - // Process column 9. - { - __m512i accum_v = accum_data_v9; - constexpr int index = 18; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_v9 = accum_v; - } - // Process column 10. - { - __m512i accum_v = accum_data_va; - constexpr int index = 20; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_va = accum_v; - } - // Process column 11. - { - __m512i accum_v = accum_data_vb; - constexpr int index = 22; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_vb = accum_v; - } - // Process column 12. 
- { - __m512i accum_v = accum_data_vc; - constexpr int index = 24; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_vc = accum_v; - } - // Process column 13. - { - __m512i accum_v = accum_data_vd; - constexpr int index = 26; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_vd = accum_v; - } - // Process column 14. - { - __m512i accum_v = accum_data_ve; - constexpr int index = 28; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_ve = accum_v; - } - // Process column 15. - { - __m512i accum_v = accum_data_vf; - constexpr int index = 30; - - const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); - const __m512i rhs_16_bit_dup_high = - _mm512_set1_epi32(rhs_data[index + 1]); - - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); - accum_v = _mm512_add_epi32( - accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); - accum_data_vf = accum_v; - } - - lhs_ptr += 16 * 4; - rhs_ptr += 16 * 4; - } - - if (params.dst_type_id != DstTypeId::kValue) { - __m512i m_vector; - __m512i e_vector; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. - if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - m_vector = _mm512_maskz_loadu_epi32( - row_mask, ¶ms.multiplier_fixedpoint[row]); - e_vector = _mm512_maskz_loadu_epi32(row_mask, - ¶ms.multiplier_exponent[row]); - } else { - // These arrays have size LhsCols, and are pre-filled. - m_vector = _mm512_set1_epi32(params.multiplier_fixedpoint[0]); - e_vector = _mm512_set1_epi32(params.multiplier_exponent[0]); - } - - const __m512i m_64bit_low = - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0)); - const __m512i m_64bit_high = - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1)); - - const __m512i zero_vector = _mm512_setzero_epi32(); - const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector); - const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector); - const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector); - const __m512i final_right_shift = - _mm512_add_epi32(right_shift, _mm512_set1_epi32(31)); - const __m512i final_right_shift_low = _mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(final_right_shift, 0)); - const __m512i final_right_shift_high = _mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(final_right_shift, 1)); - - const __m512i offset_vector = - _mm512_slli_epi64(_mm512_set1_epi64(1), 30); - // Really these should be shifted by neg_e_vector, but tests pass when - // using right_shift. 
- const __m512i offset_vector_low = _mm512_sllv_epi64( - offset_vector, - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0))); - const __m512i offset_vector_high = _mm512_sllv_epi64( - offset_vector, - _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1))); - - // Shift and round column 0. - { - accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v0, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v0, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v0 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v0 = _mm512_inserti32x8( - accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 1. - { - accum_data_v1 = _mm512_sllv_epi32(accum_data_v1, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v1, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v1, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v1 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v1 = _mm512_inserti32x8( - accum_data_v1, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 2. - { - accum_data_v2 = _mm512_sllv_epi32(accum_data_v2, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v2, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v2, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v2 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v2 = _mm512_inserti32x8( - accum_data_v2, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 3. - { - accum_data_v3 = _mm512_sllv_epi32(accum_data_v3, left_shift); - // Apply the fixed-point part of the multiplier. 
- __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v3, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v3, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v3 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v3 = _mm512_inserti32x8( - accum_data_v3, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 4. - { - accum_data_v4 = _mm512_sllv_epi32(accum_data_v4, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v4, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v4, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v4 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v4 = _mm512_inserti32x8( - accum_data_v4, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 5. - { - accum_data_v5 = _mm512_sllv_epi32(accum_data_v5, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v5, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v5, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v5 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v5 = _mm512_inserti32x8( - accum_data_v5, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 6. - { - accum_data_v6 = _mm512_sllv_epi32(accum_data_v6, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v6, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v6, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v6 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v6 = _mm512_inserti32x8( - accum_data_v6, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 7. 
- { - accum_data_v7 = _mm512_sllv_epi32(accum_data_v7, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v7, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v7, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v7 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v7 = _mm512_inserti32x8( - accum_data_v7, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 8. - { - accum_data_v8 = _mm512_sllv_epi32(accum_data_v8, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v8, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v8, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v8 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v8 = _mm512_inserti32x8( - accum_data_v8, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 9. - { - accum_data_v9 = _mm512_sllv_epi32(accum_data_v9, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v9, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_v9, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_v9 = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_v9 = _mm512_inserti32x8( - accum_data_v9, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 10. - { - accum_data_va = _mm512_sllv_epi32(accum_data_va, left_shift); - // Apply the fixed-point part of the multiplier. 
- __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_va, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_va, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_va = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_va = _mm512_inserti32x8( - accum_data_va, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 11. - { - accum_data_vb = _mm512_sllv_epi32(accum_data_vb, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vb, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vb, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_vb = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_vb = _mm512_inserti32x8( - accum_data_vb, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 12. - { - accum_data_vc = _mm512_sllv_epi32(accum_data_vc, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vc, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vc, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_vc = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_vc = _mm512_inserti32x8( - accum_data_vc, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 13. - { - accum_data_vd = _mm512_sllv_epi32(accum_data_vd, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vd, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vd, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_vd = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_vd = _mm512_inserti32x8( - accum_data_vd, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 14. 
- { - accum_data_ve = _mm512_sllv_epi32(accum_data_ve, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_ve, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_ve, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_ve = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_ve = _mm512_inserti32x8( - accum_data_ve, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } - // Shift and round column 15. - { - accum_data_vf = _mm512_sllv_epi32(accum_data_vf, left_shift); - // Apply the fixed-point part of the multiplier. - __m512i scaled_v_low = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vf, 0)), - m_64bit_low); - __m512i scaled_v_high = - _mm512_mul_epi32(_mm512_cvtepi32_epi64( - _mm512_extracti32x8_epi32(accum_data_vf, 1)), - m_64bit_high); - - scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); - scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); - - scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); - scaled_v_high = - _mm512_srav_epi64(scaled_v_high, final_right_shift_high); - - accum_data_vf = - _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); - accum_data_vf = _mm512_inserti32x8( - accum_data_vf, _mm512_cvtepi64_epi32(scaled_v_high), 1); - } -#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) - RUY_DCHECK(false); -#endif - - if (params.dst_zero_point != 0) { - __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point); - accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point); - accum_data_v1 = _mm512_add_epi32(accum_data_v1, dst_zero_point); - accum_data_v2 = _mm512_add_epi32(accum_data_v2, dst_zero_point); - accum_data_v3 = _mm512_add_epi32(accum_data_v3, dst_zero_point); - accum_data_v4 = _mm512_add_epi32(accum_data_v4, dst_zero_point); - accum_data_v5 = _mm512_add_epi32(accum_data_v5, dst_zero_point); - accum_data_v6 = _mm512_add_epi32(accum_data_v6, dst_zero_point); - accum_data_v7 = _mm512_add_epi32(accum_data_v7, dst_zero_point); - accum_data_v8 = _mm512_add_epi32(accum_data_v8, dst_zero_point); - accum_data_v9 = _mm512_add_epi32(accum_data_v9, dst_zero_point); - accum_data_va = _mm512_add_epi32(accum_data_va, dst_zero_point); - accum_data_vb = _mm512_add_epi32(accum_data_vb, dst_zero_point); - accum_data_vc = _mm512_add_epi32(accum_data_vc, dst_zero_point); - accum_data_vd = _mm512_add_epi32(accum_data_vd, dst_zero_point); - accum_data_ve = _mm512_add_epi32(accum_data_ve, dst_zero_point); - accum_data_vf = _mm512_add_epi32(accum_data_vf, dst_zero_point); - } - } - - const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max); - const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min); - - const bool store_full_block = - (residual_rows == 16) && (residual_cols == 16); - - __m512i accum_data_v[16]; - - // In most cases we would make this conditional on (!store_full_block) and - // unwind the clamp-and-store loop, but the benefit appears small. 
- { - accum_data_v[0] = accum_data_v0; - accum_data_v[1] = accum_data_v1; - accum_data_v[2] = accum_data_v2; - accum_data_v[3] = accum_data_v3; - accum_data_v[4] = accum_data_v4; - accum_data_v[5] = accum_data_v5; - accum_data_v[6] = accum_data_v6; - accum_data_v[7] = accum_data_v7; - accum_data_v[8] = accum_data_v8; - accum_data_v[9] = accum_data_v9; - accum_data_v[10] = accum_data_va; - accum_data_v[11] = accum_data_vb; - accum_data_v[12] = accum_data_vc; - accum_data_v[13] = accum_data_vd; - accum_data_v[14] = accum_data_ve; - accum_data_v[15] = accum_data_vf; - } - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = dst_stride; - if (store_full_block) { - for (int j = 0; j < 16; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_storeu_epi8(tmp_ptr + j * block_col_offset, - _mm512_cvtepi32_epi8(result)); - } - } else { - for (int j = 0; j < residual_cols; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask, - _mm512_cvtepi32_epi8(result)); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = dst_stride; - if (store_full_block) { - for (int j = 0; j < residual_cols; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_storeu_epi8(tmp_ptr + j * block_col_offset, - _mm512_cvtepi32_epi8(result)); - } - } else { - for (int j = 0; j < residual_cols; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask, - _mm512_cvtepi32_epi8(result)); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = dst_stride; - if (store_full_block) { - for (int j = 0; j < 16; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm256_storeu_epi16(tmp_ptr + j * block_col_offset, - _mm512_cvtepi32_epi16(result)); - } - } else { - for (int j = 0; j < residual_cols; ++j) { - __m512i result = accum_data_v[j]; - result = _mm512_min_epi32(result, clamp_max_v); - result = _mm512_max_epi32(result, clamp_min_v); - _mm256_mask_storeu_epi16(tmp_ptr + j * block_col_offset, row_mask, - _mm512_cvtepi32_epi16(result)); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - for (int j = 0; j < 16; ++j) { - _mm512_storeu_epi32(tmp_ptr + j * dst_stride, accum_data_v[j]); - } - } else { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - _mm512_mask_storeu_epi32(tmp_ptr + j * dst_stride, row_mask, - accum_data_v[j]); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + 16); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += 16 * params.lhs_stride; - } // End row-block loop. 
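// Aside: a scalar sketch of what one depth-4 step of the inner loop of the
// kernel above accumulates. The function name and the explicit [row][depth] /
// [col][depth] layouts are illustrative assumptions matching the packed loads
// above (16x16 block, inner depth of 4); this is not code from this file.
void Int8BlockAccumulateSketch(const std::int8_t* lhs_block,  // 16 rows x 4
                               const std::int8_t* rhs_block,  // 16 cols x 4
                               std::int32_t accum[16][16]) {
  for (int col = 0; col < 16; ++col) {
    for (int row = 0; row < 16; ++row) {
      for (int d = 0; d < 4; ++d) {
        // The pair of _mm512_madd_epi16 calls folds the products for depths
        // d and d+1 in one step; the plain sum below is the equivalent result.
        accum[row][col] += static_cast<std::int32_t>(lhs_block[row * 4 + d]) *
                           static_cast<std::int32_t>(rhs_block[col * 4 + d]);
      }
    }
  }
}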
- - dst_col_ptr = static_cast(static_cast(dst_col_ptr) + - 16 * params.dst_stride); - rhs_col_ptr += 16 * params.rhs_stride; - } // End col-block loop. -} // NOLINT(readability/fn_size) - -void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { - profiler::ScopeLabel label("Kernel kAvx512 8-bit GEMV"); - - RUY_DCHECK_EQ(params.dst_cols, 1); - RUY_DCHECK_EQ(params.last_col, 0); - RUY_DCHECK_EQ(params.start_col, 0); - - std::int32_t dst_stride; - if ((params.dst_type_id == DstTypeId::kValue) || - (params.dst_type_id == DstTypeId::kValue)) { - dst_stride = params.dst_stride; - } else if (params.dst_type_id == DstTypeId::kValue) { - dst_stride = params.dst_stride / sizeof(std::int16_t); - } else if (params.dst_type_id == DstTypeId::kValue) { - dst_stride = params.dst_stride / sizeof(std::int32_t); - } else { - RUY_DCHECK(false); - } - - int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0; - - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - void* dst_col_ptr = params.dst_base_ptr; - const std::int32_t* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - void* dst_ptr = dst_col_ptr; - const std::int32_t* bias_ptr = bias_col_ptr; - - const std::int32_t lhs_zero_point = params.lhs_zero_point; - const bool has_rhs_sums_offsets = - (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; - std::int32_t rhs_sums_offsets[16]; - if (has_rhs_sums_offsets) { - const __m512i rhs_sums_offset_v = - _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point), - _mm512_loadu_epi32(¶ms.rhs_sums[0])); - _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets), - rhs_sums_offset_v); - } - - for (int row = params.start_row; row <= params.last_row; row += 16) { - const int residual_rows = std::min(params.dst_rows - row, 16); - - __m512i accum_data_v0; - - // Initialize with bias. - const __mmask16 row_mask = - (static_cast(1) << residual_rows) - 1; - __m512i initial_accum_data = _mm512_maskz_loadu_epi32(row_mask, bias_ptr); - bias_ptr += bias_ptr_block_increment; - - const std::int32_t rhs_zero_point = params.rhs_zero_point; - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { - const __m512i lhs_sums_offset = - _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point), - _mm512_loadu_epi32(¶ms.lhs_sums[row])); - initial_accum_data = - _mm512_sub_epi32(initial_accum_data, lhs_sums_offset); - } - - const std::int32_t prod_zp_depth = params.prod_zp_depth; - if (prod_zp_depth != 0) { - initial_accum_data = _mm512_add_epi32(initial_accum_data, - _mm512_set1_epi32(prod_zp_depth)); - } - - // Adjustments differing across columns. - if (has_rhs_sums_offsets) { - accum_data_v0 = _mm512_sub_epi32(initial_accum_data, - _mm512_set1_epi32(rhs_sums_offsets[0])); - } else { - accum_data_v0 = initial_accum_data; - } - - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += 4) { - const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr); - const __m128i rhs_data_8bit = _mm_loadu_epi8(rhs_ptr); - - // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. - // For simplicity we load 4x the data that we need and process twice the - // data that we need and store only the data we need. 
-      std::int32_t rhs_data[2];
-      const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
-      // Now that we have cast the RHS data, we store it so that each value
-      // can be separately loaded in the accumulation loop.
-      _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
-
-      // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
-      const __m512i lhs_16_bit_low =
-          _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data));
-      // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit.
-      const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16(
-          _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16)));
-
-      // Process column 0.
-      __m512i accum_v = accum_data_v0;
-      constexpr int index = 0;
-
-      const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]);
-      const __m512i rhs_16_bit_dup_high =
-          _mm512_set1_epi32(rhs_data[index + 1]);
-
-      accum_v = _mm512_add_epi32(
-          accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-      accum_v = _mm512_add_epi32(
-          accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-      accum_data_v0 = accum_v;
-
-      lhs_ptr += 16 * 4;
-      rhs_ptr += 16 * 4;
-    }
-
-    if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
-      __m512i m_vector;
-      __m512i e_vector;
-      // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
-      if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
-        m_vector = _mm512_maskz_loadu_epi32(row_mask,
-                                            &params.multiplier_fixedpoint[row]);
-        e_vector = _mm512_maskz_loadu_epi32(row_mask,
-                                            &params.multiplier_exponent[row]);
-      } else {
-        // These arrays have size LhsCols, and are pre-filled.
-        m_vector = _mm512_set1_epi32(params.multiplier_fixedpoint[0]);
-        e_vector = _mm512_set1_epi32(params.multiplier_exponent[0]);
-      }
-
-      const __m512i m_64bit_low =
-          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0));
-      const __m512i m_64bit_high =
-          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1));
-
-      const __m512i zero_vector = _mm512_setzero_epi32();
-      const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
-      const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
-      const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
-      const __m512i final_right_shift =
-          _mm512_add_epi32(right_shift, _mm512_set1_epi32(31));
-      const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
-          _mm512_extracti32x8_epi32(final_right_shift, 0));
-      const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
-          _mm512_extracti32x8_epi32(final_right_shift, 1));
-
-      const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30);
-      // Really these should be shifted by neg_e_vector, but tests pass when
-      // using right_shift.
-      const __m512i offset_vector_low = _mm512_sllv_epi64(
-          offset_vector,
-          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0)));
-      const __m512i offset_vector_high = _mm512_sllv_epi64(
-          offset_vector,
-          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)));
-
-      // Shift and round column 0.
-      accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift);
-      // Apply the fixed-point part of the multiplier.
-      __m512i scaled_v_low = _mm512_mul_epi32(
-          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 0)),
-          m_64bit_low);
-      __m512i scaled_v_high = _mm512_mul_epi32(
-          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)),
-          m_64bit_high);
-
-      scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
-      scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
-
-      scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
-      scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
-
-      accum_data_v0 =
-          _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
-      accum_data_v0 = _mm512_inserti32x8(
-          accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1);
-#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
-      RUY_DCHECK(false);
-#endif
-
-      if (params.dst_zero_point != 0) {
-        __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point);
-        accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point);
-      }
-    }
-
-    const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max);
-    const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min);
-
-    if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
-      std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
-      __m512i result = accum_data_v0;
-      result = _mm512_min_epi32(result, clamp_max_v);
-      result = _mm512_max_epi32(result, clamp_min_v);
-      _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
-      dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16);
-    } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
-      std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
-      __m512i result = accum_data_v0;
-      result = _mm512_min_epi32(result, clamp_max_v);
-      result = _mm512_max_epi32(result, clamp_min_v);
-      _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result));
-      dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16);
-    } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
-      std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
-      __m512i result = accum_data_v0;
-      result = _mm512_min_epi32(result, clamp_max_v);
-      result = _mm512_max_epi32(result, clamp_min_v);
-      _mm256_mask_storeu_epi16(tmp_ptr, row_mask,
-                               _mm512_cvtepi32_epi16(result));
-      dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16);
-    } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
-      std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
-      _mm512_mask_storeu_epi32(tmp_ptr, row_mask, accum_data_v0);
-      dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16);
-    } else {
-      RUY_DCHECK(false);
-    }
-
-    lhs_col_ptr += 16 * params.lhs_stride;
-  }  // End row-block loop.
-}  // NOLINT(readability/fn_size)
-
-void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
-  profiler::ScopeLabel label("Kernel kAvx512 float");
-
-  // As parameters are defined, we need to scale by sizeof(float).
-  const std::int64_t lhs_stride = params.lhs_stride >> 2;
-  const std::int64_t dst_stride = params.dst_stride >> 2;
-  const std::int64_t rhs_stride = params.rhs_stride >> 2;
-
-  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ?
-      1 : 0;
-  const int end_row = std::min(params.dst_rows, params.last_row + 16);
-  const int end_col = std::min(params.dst_cols, params.last_col + 16);
-
-  const float* adj_rhs_col_ptr =
-      params.rhs_base_ptr - params.start_col * rhs_stride;
-  float* adj_dst_col_ptr =
-      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
-  const float* adj_lhs_col_ptr =
-      params.lhs_base_ptr - params.start_row * lhs_stride;
-  const float* bias_col_ptr = params.bias;
-
-  const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
-  const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
-
-  int col = params.start_col;
-  for (; col <= end_col - 16; col += 16) {
-    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
-    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
-
-    int row = params.start_row;
-    for (; row <= end_row - 16; row += 16) {
-      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
-      float* dst_ptr = dst_col_ptr + row;
-      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
-
-      // Initialize with bias.
-      const __m512 initial_accum_data = _mm512_loadu_ps(bias_ptr);
-
-      // Process block in two halves, split by columns.
-      {
-        constexpr int mmm = 0;
-
-        __m512 accum_data_v0 = initial_accum_data;
-        __m512 accum_data_v1 = initial_accum_data;
-        __m512 accum_data_v2 = initial_accum_data;
-        __m512 accum_data_v3 = initial_accum_data;
-        __m512 accum_data_v4 = initial_accum_data;
-        __m512 accum_data_v5 = initial_accum_data;
-        __m512 accum_data_v6 = initial_accum_data;
-        __m512 accum_data_v7 = initial_accum_data;
-
-        const float* lhs_ptr = lhs_col_ptr;
-        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
-        for (int d = 0; d < (params.depth - 1); ++d) {
-          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
-          // In this version, RHS values are loaded individually rather than
-          // loaded together and then extracted with broadcasting. This is
-          // because AVX flavours, intrinsics, and compilers in combination
-          // do not handle this pattern of extraction very well.
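// Illustrative contrast of the two RHS-broadcast strategies mentioned in the
// comment above; neither function is part of this file, and the names are made
// up. The kernel takes the first route: re-read the scalar from memory and
// splat it, instead of extracting a lane from an RHS vector already held in a
// register.
#include <immintrin.h>

inline __m512 BroadcastRhsFromMemory(const float* rhs_data, int j) {
  // What the loop below does for each column j.
  return _mm512_set1_ps(rhs_data[j]);
}

inline __m512 BroadcastRhsFromRegister(__m512 rhs_vec, int j) {
  // The avoided pattern: lane extraction plus broadcast from a loaded vector;
  // its lowering tends to be less predictable across AVX flavours/compilers.
  return _mm512_permutexvar_ps(_mm512_set1_epi32(j), rhs_vec);
}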
- const float* rhs_data = rhs_ptr; - lhs_ptr += 16; - rhs_ptr += 16; - - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - } - { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - { - float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; - accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); - _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0); - accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); - _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1); - accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); - _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2); - accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); - _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3); - accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); - _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4); - accum_data_v5 = 
_mm512_min_ps(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); - _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5); - accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); - _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6); - accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); - _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); - } - } - } // Inner half-block loop, unrolled, first iteration. - { - constexpr int mmm = 1; - - __m512 accum_data_v0 = initial_accum_data; - __m512 accum_data_v1 = initial_accum_data; - __m512 accum_data_v2 = initial_accum_data; - __m512 accum_data_v3 = initial_accum_data; - __m512 accum_data_v4 = initial_accum_data; - __m512 accum_data_v5 = initial_accum_data; - __m512 accum_data_v6 = initial_accum_data; - __m512 accum_data_v7 = initial_accum_data; - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr + 8 * mmm; - for (int d = 0; d < (params.depth - 1); ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - lhs_ptr += 16; - rhs_ptr += 16; - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - } - { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, 
accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - { - float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; - accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); - _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0); - accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); - _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1); - accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); - _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2); - accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); - _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3); - accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); - _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4); - accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); - _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5); - accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); - _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6); - accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); - _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); - } - } - } // Inner half-block loop, unrolled, second iteration. - } // End row-block loop. - - // The unrolling within this conditional may be somewhat pointless. It - // depends on the kinds of models. - if (row < end_row) { - const int residual_rows = end_row - row; - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __mmask16 row_mask = - (static_cast(1) << residual_rows) - 1; - const __m512 initial_accum_data = - _mm512_maskz_loadu_ps(row_mask, bias_ptr); - - // Process block in two halves, split by columns. 
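// Sketch of the __mmask16 residual-row handling set up just above (names are
// illustrative, not from this file): with residual_rows = 5 the mask is
// 0b0000000000011111, so masked loads zero-fill lanes 5..15 and masked stores
// leave the corresponding destination rows untouched.
#include <immintrin.h>
#include <cstdint>

inline void CopyResidualRows(const float* src, float* dst, int residual_rows) {
  const __mmask16 row_mask =
      (static_cast<std::uint16_t>(1) << residual_rows) - 1;
  const __m512 v = _mm512_maskz_loadu_ps(row_mask, src);  // tail lanes read as 0.0f
  _mm512_mask_storeu_ps(dst, row_mask, v);                // tail lanes are not written
}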
- for (int mmm = 0; mmm < 2; ++mmm) { - __m512 accum_data_v0 = initial_accum_data; - __m512 accum_data_v1 = initial_accum_data; - __m512 accum_data_v2 = initial_accum_data; - __m512 accum_data_v3 = initial_accum_data; - __m512 accum_data_v4 = initial_accum_data; - __m512 accum_data_v5 = initial_accum_data; - __m512 accum_data_v6 = initial_accum_data; - __m512 accum_data_v7 = initial_accum_data; - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr + 8 * mmm; - for (int d = 0; d < (params.depth - 1); ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - lhs_ptr += 16; - rhs_ptr += 16; - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - } - { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - { - const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); - accum_data_v0 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); - const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); - accum_data_v1 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); - const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); - accum_data_v2 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); - const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); - accum_data_v3 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); - const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); - accum_data_v4 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); - const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); - accum_data_v5 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); - const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); - accum_data_v6 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); - const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); - accum_data_v7 = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); - } - { - float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; - accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); - accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask, - accum_data_v0); - accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); - accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 1 * 
dst_stride, row_mask, - accum_data_v1); - accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); - accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask, - accum_data_v2); - accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); - accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask, - accum_data_v3); - accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); - accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask, - accum_data_v4); - accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); - accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask, - accum_data_v5); - accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); - accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask, - accum_data_v6); - accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); - accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); - _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask, - accum_data_v7); - } - } - } // Inner half-block loop. - } // Residual rows, main col-block loop. - } // End col-block loop. - - if (col < end_col) { - RUY_DCHECK_GE(end_col - col, 0); - RUY_DCHECK_LT(end_col - col, 16); - - __m512 accum_data_v[8]; - - const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; - float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; - - for (int row = params.start_row; row < end_row; row += 16) { - const int residual_rows = std::min(end_row - row, 16); - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __mmask16 row_mask = - (static_cast(1) << residual_rows) - 1; - const __m512 initial_accum_data = - _mm512_maskz_loadu_ps(row_mask, bias_ptr); - - // Process block in two halves, split by columns. 
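// The "two halves" note above refers to the column split used throughout this
// kernel: a 16x16 destination tile is produced as two passes of 16 rows by 8
// columns (mmm = 0, 1). Illustrative index helper, not part of this file:
#include <cstdint>

inline float* HalfBlockColumnPtr(float* dst_tile, std::int64_t dst_stride,
                                 int mmm, int j) {
  // Column j (0..7) of half-block mmm, matching rhs_ptr = rhs_col_ptr + 8 * mmm
  // and the (mmm * 8 + j) * dst_stride offsets used in the stores.
  return dst_tile + (mmm * 8 + j) * dst_stride;
}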
- for (int mmm = 0; mmm < 2; ++mmm) { - for (int j = 0; j < 8; ++j) { - accum_data_v[j] = initial_accum_data; - } - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr + 8 * mmm; - for (int d = 0; d < params.depth; ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float* rhs_data = rhs_ptr; - - for (int j = 0; j < 8; ++j) { - const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]); - accum_data_v[j] = - _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); - } - lhs_ptr += 16; - rhs_ptr += 16; - } - - const int residual_cols = std::min(end_col - col - 8 * mmm, 8); - - if (residual_rows == 16) { - if (residual_cols == 8) { - for (int j = 0; j < 8; ++j) { - float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; - accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); - _mm512_storeu_ps(block_ptr, accum_data_v[j]); - } - } else { - for (int j = 0; j < residual_cols; ++j) { - float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; - accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); - _mm512_storeu_ps(block_ptr, accum_data_v[j]); - } - } - } else { - for (int j = 0; j < residual_cols; ++j) { - float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; - accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); - accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); - _mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]); - } - } - } // Inner half-block loop. - } // End row-block loop. - } // Residual cols. -} - -void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) { - profiler::ScopeLabel label("Kernel kAvx512 float GEMV"); - - RUY_DCHECK_EQ(params.dst_cols, 1); - RUY_DCHECK_EQ(params.last_col, 0); - RUY_DCHECK_EQ(params.start_col, 0); - - // As parameters are defined, we need to scale by sizeof(float). - const std::int64_t lhs_stride = params.lhs_stride >> 2; - - int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; - const int end_row = std::min(params.dst_rows, params.last_row + 16); - - float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row; - const float* adj_lhs_col_ptr = - params.lhs_base_ptr - params.start_row * lhs_stride; - const float* bias_col_ptr = params.bias; - - const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max); - const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min); - - __m512 accum_data_v; - - const float* rhs_col_ptr = params.rhs_base_ptr; - float* dst_col_ptr = adj_dst_col_ptr; - - int row = params.start_row; - for (; row <= end_row - 16; row += 16) { - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - accum_data_v = _mm512_loadu_ps(bias_ptr); - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float rhs_data = *rhs_ptr; - - const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); - accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); - lhs_ptr += 16; - rhs_ptr += 16; - } - - accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); - accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); - _mm512_storeu_ps(dst_ptr, accum_data_v); - } // End row-block loop. 
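// Plain-C++ reference for one 16-row block of the GEMV loop above: bias init,
// a depth loop that multiplies a packed block of 16 LHS rows by one RHS scalar,
// then clamping. The d * 16 strides reflect the packed layout (both lhs_ptr and
// rhs_ptr advance by 16 per depth step). Illustrative only; names are made up.
#include <algorithm>

inline void GemvBlockReference(const float* lhs, const float* rhs,
                               const float* bias, int depth, float clamp_min,
                               float clamp_max, float* dst) {
  float acc[16];
  for (int r = 0; r < 16; ++r) acc[r] = bias[r];
  for (int d = 0; d < depth; ++d) {
    const float rhs_value = rhs[d * 16];  // matches *rhs_ptr with rhs_ptr += 16
    for (int r = 0; r < 16; ++r) acc[r] += lhs[d * 16 + r] * rhs_value;
  }
  for (int r = 0; r < 16; ++r) {
    dst[r] = std::min(clamp_max, std::max(clamp_min, acc[r]));
  }
}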
- - if (row < end_row) { - const int residual_rows = end_row - row; - RUY_CHECK_GE(residual_rows, 1); - RUY_CHECK_LT(residual_rows, 16); - - const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; - float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __mmask16 row_mask = - (static_cast(1) << residual_rows) - 1; - accum_data_v = _mm512_maskz_loadu_ps(row_mask, bias_ptr); - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); - const float rhs_data = *rhs_ptr; - - const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); - accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); - lhs_ptr += 16; - rhs_ptr += 16; - } - - accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); - accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); - _mm512_mask_storeu_ps(dst_ptr, row_mask, accum_data_v); - } // End handling of residual rows. -} - -#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/kernel_avxvnni.cc b/kernel_avxvnni.cc deleted file mode 100644 index b7b8c9e..0000000 --- a/kernel_avxvnni.cc +++ /dev/null @@ -1,435 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "check_macros.h" -#include "kernel.h" -#include "opt_set.h" -#include "platform.h" -#include "profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -static constexpr int kAvxFloatBlockSize = 16; -static constexpr int kAvx8bitBlockSize = 16; -static constexpr int kAvx8bitInnerSize = 4; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params) { - profiler::ScopeLabel label("Kernel kAvxVnni 8-bit (UNFINISHED)"); - - std::int32_t accum_data[kAvx8bitBlockSize][kAvx8bitBlockSize]; - - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? 
kAvx8bitBlockSize : 0; - - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - void* dst_col_ptr = params.dst_base_ptr; - const std::int32_t* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; - col += kAvx8bitBlockSize) { - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - void* dst_ptr = dst_col_ptr; - const std::int32_t* bias_ptr = bias_col_ptr; - - for (int row = params.start_row; row <= params.last_row; - row += kAvx8bitBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvx8bitBlockSize); - const int residual_cols = - std::min(params.dst_cols - col, kAvx8bitBlockSize); - - // Initialize with bias. - std::int32_t initial_accum_data[kAvx8bitBlockSize]; - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - initial_accum_data[i] = 0; - } - for (int i = 0; i < residual_rows; ++i) { - initial_accum_data[i] = bias_ptr[i]; - } - - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = initial_accum_data[i]; - } - } - bias_ptr += bias_ptr_block_increment; - - std::int8_t lhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; - std::int8_t rhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - for (int x = 0; x < kAvx8bitInnerSize; ++x) { - lhs_data[i][x] = lhs_ptr[i * kAvx8bitInnerSize + x]; - rhs_data[i][x] = rhs_ptr[i * kAvx8bitInnerSize + x]; - } - } - - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - for (int x = 0; x < kAvx8bitInnerSize; ++x) { - accum_data[j][i] += lhs_data[i][x] * rhs_data[j][x]; - } - } - } - - lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - } - - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] -= - params.rhs_zero_point * params.lhs_sums[row + i]; - } - } - } - if ((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && params.lhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] -= - params.lhs_zero_point * params.rhs_sums[col + j]; - } - } - } - if (params.lhs_zero_point && params.rhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] += params.prod_zp_depth; - } - } - } - - if (params.dst_type_id != DstTypeId::kValue) { - std::int32_t m_vector[kAvx8bitBlockSize]; - std::int32_t e_vector[kAvx8bitBlockSize]; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. - if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - int i = 0; - for (; i < residual_rows; ++i) { - m_vector[i] = params.multiplier_fixedpoint[row + i]; - e_vector[i] = params.multiplier_exponent[row + i]; - } - for (; i < kAvx8bitBlockSize; ++i) { - m_vector[i] = m_vector[0]; - e_vector[i] = e_vector[0]; - } - } else { - // These arrays have size LhsCols, and are pre-filled. 
- for (int i = 0; i < kAvx8bitBlockSize; ++i) { - m_vector[i] = params.multiplier_fixedpoint[i]; - e_vector[i] = params.multiplier_exponent[i]; - } - } - - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = MultiplyByQuantizedMultiplier( - accum_data[j][i], m_vector[i], e_vector[i]); - } - } - - if (params.dst_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] += params.dst_zero_point; - } - } - } - - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = - std::min(accum_data[j][i], params.clamp_max); - accum_data[j][i] = - std::max(accum_data[j][i], params.clamp_min); - } - } - } - - const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && - (residual_cols == kAvx8bitBlockSize); - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = - store_full_block - ? static_cast(dst_ptr) - : const_cast( - reinterpret_cast(params.dst_tmp_buf)); - const int block_col_offset = - store_full_block ? params.dst_stride / sizeof(std::int8_t) - : kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - - if (!store_full_block) { - const std::int8_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - static_cast( - dst_ptr)[j * params.dst_stride / sizeof(std::int8_t) + i] = - block_ptr[i]; - } - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = store_full_block - ? static_cast(dst_ptr) - : const_cast( - reinterpret_cast( - params.dst_tmp_buf)); - const int block_col_offset = - store_full_block ? 
params.dst_stride : kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - - if (!store_full_block) { - const std::uint8_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - static_cast( - dst_ptr)[j * params.dst_stride / sizeof(std::uint8_t) + i] = - block_ptr[i]; - } - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = params.dst_stride / sizeof(std::int16_t); - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - } else { - std::int16_t* tmp_ptr = const_cast( - reinterpret_cast(params.dst_tmp_buf)); - const int block_col_offset = kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - const std::int16_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - std::int16_t* dst_block_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_block_ptr[i] = block_ptr[i]; - } - dst_block_ptr += params.dst_stride / sizeof(std::int16_t); - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = params.dst_stride / sizeof(std::int32_t); - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - } else { - std::int32_t* dst_block_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_block_ptr[i] = accum_data[j][i]; - } - dst_block_ptr += params.dst_stride / sizeof(std::int32_t); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; - } // End row-block loop. - - dst_col_ptr = static_cast(static_cast(dst_col_ptr) + - kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; - } // End col-block loop. -} // NOLINT(readability/fn_size) - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params) { - profiler::ScopeLabel label("Kernel kAvxVnni float (UNFINISHED)"); - - float lhs_data[kAvxFloatBlockSize]; - float rhs_data[kAvxFloatBlockSize]; - float accum_data[kAvxFloatBlockSize][kAvxFloatBlockSize]; - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? 
kAvxFloatBlockSize : 0; - - const float* rhs_col_ptr = params.rhs_base_ptr; - float* dst_col_ptr = params.dst_base_ptr; - const float* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; - col += kAvxFloatBlockSize) { - const float* lhs_col_ptr = params.lhs_base_ptr; - float* dst_ptr = dst_col_ptr; - const float* bias_ptr = bias_col_ptr; - - for (int row = params.start_row; row <= params.last_row; - row += kAvxFloatBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvxFloatBlockSize); - const int residual_cols = - std::min(params.dst_cols - col, kAvxFloatBlockSize); - - // Initialize with bias. - float initial_accum_data[kAvxFloatBlockSize]; - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - initial_accum_data[i] = 0.0f; - } - for (int i = 0; i < residual_rows; ++i) { - initial_accum_data[i] = bias_ptr[i]; - } - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] = initial_accum_data[i]; - } - } - bias_ptr += bias_ptr_block_increment; - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - lhs_data[i] = lhs_ptr[i]; - rhs_data[i] = rhs_ptr[i]; - } - - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] += lhs_data[i] * rhs_data[j]; - } - } - - lhs_ptr += kAvxFloatBlockSize; - rhs_ptr += kAvxFloatBlockSize; - } - - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] = - std::min(accum_data[j][i], params.clamp_max); - accum_data[j][i] = - std::max(accum_data[j][i], params.clamp_min); - } - } - - const bool store_full_block = (residual_rows == kAvxFloatBlockSize) && - (residual_cols == kAvxFloatBlockSize); - - { - float* block_ptr = - store_full_block ? dst_ptr : const_cast(params.dst_tmp_buf); - const int block_col_offset = store_full_block - ? params.dst_stride / sizeof(float) - : kAvxFloatBlockSize; - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - block_ptr[i] = accum_data[j][i]; - } - block_ptr += block_col_offset; - } - } - if (!store_full_block) { - const float* block_ptr = params.dst_tmp_buf; - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_ptr[j * params.dst_stride / sizeof(float) + i] = block_ptr[i]; - } - block_ptr += kAvxFloatBlockSize; - } - } - - lhs_col_ptr += kAvxFloatBlockSize * params.lhs_stride / sizeof(float); - dst_ptr += kAvxFloatBlockSize; - } // End row-block loop. - - dst_col_ptr += kAvxFloatBlockSize * params.dst_stride / sizeof(float); - rhs_col_ptr += kAvxFloatBlockSize * params.rhs_stride / sizeof(float); - } // End col-block loop. -} - -#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/kernel_common.h b/kernel_common.h deleted file mode 100644 index f20bd16..0000000 --- a/kernel_common.h +++ /dev/null @@ -1,481 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_COMMON_H_ - -#include -#include -#include - -#include "check_macros.h" -#include "common.h" -#include "internal_matrix.h" -#include "matrix.h" -#include "opt_set.h" -#include "path.h" -#include "platform.h" -#include "profiler/instrumentation.h" -#include "side_pair.h" -#include "size_util.h" -#include "spec.h" -#include "tune.h" - -namespace ruy { - -template -struct Kernel {}; - -template -void RunKernelTyped(Tuning tuning, const PackedMatrix& lhs, - const PackedMatrix& rhs, const Spec& spec, - int start_row, int start_col, int end_row, int end_col, - Matrix* dst) { - using Kernel = Kernel; - Kernel kernel(tuning); -#if !defined(NDEBUG) || !RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL) - using LhsLayout = typename Kernel::LhsLayout; - using RhsLayout = typename Kernel::RhsLayout; -#endif - // end_row and end_col may be larger than dst dimensions. - // that is because kernels write directly to the destination matrix, whose - // dimensions may not be a multiple of the kernel dimensions, and we try to - // keep this annoyance localized as an implementation detail in kernels, - // by allowing to pass rounded-up values down as far as possible. - // These assertions encode the contract. - RUY_DCHECK_LE(0, start_row); - RUY_DCHECK_LE(start_row, end_row); - RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols); - RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0); - RUY_DCHECK_LE(0, start_col); - RUY_DCHECK_LE(start_col, end_col); - RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols); - RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0); -#if RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL) - kernel.Run(lhs, rhs, spec, start_row, start_col, end_row, end_col, dst); -#else - for (int col = start_col; col < end_col; col += RhsLayout::kCols) { - int block_end_col = std::min(col + RhsLayout::kCols, end_col); - for (int row = start_row; row < end_row; row += LhsLayout::kCols) { - int block_end_row = std::min(row + LhsLayout::kCols, end_row); - kernel.Run(lhs, rhs, spec, row, col, block_end_row, block_end_col, dst); - } - } -#endif -} - -// Main entry point for kernels. -template -void RunKernel(Tuning tuning, const SidePair& src, void* spec, - const SidePair& start, const SidePair& end, - DMatrix* dst) { - Matrix mdst = ToMatrix(*dst); - RunKernelTyped( - tuning, ToPackedMatrix(src[Side::kLhs]), - ToPackedMatrix(src[Side::kRhs]), - *static_cast(spec), start[Side::kLhs], start[Side::kRhs], - end[Side::kLhs], end[Side::kRhs], &mdst); -} - -// Copied from gemmlowp/fixedpoint. -inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, - std::int32_t b) { - bool overflow = a == b && a == std::numeric_limits::min(); - std::int64_t a_64(a); - std::int64_t b_64(b); - std::int64_t ab_64 = a_64 * b_64; - std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); - std::int32_t ab_x2_high32 = - static_cast((ab_64 + nudge) / (1ll << 31)); - return overflow ? 
std::numeric_limits::max() : ab_x2_high32; -} - -inline std::int32_t RoundingDivideByPOT(std::int32_t numerator, int exponent) { - std::int32_t sign = numerator >= 0 ? 1 : -1; - std::int32_t abs_numerator = std::abs(numerator); - std::int32_t mask = (1LL << exponent) - 1; - std::int32_t remainder = abs_numerator & mask; - std::int32_t threshold = mask >> 1; - std::int32_t abs_result = - (abs_numerator >> exponent) + (remainder > threshold ? 1 : 0); - return sign * abs_result; -} - -// Copied from TF Lite code. -inline std::int32_t MultiplyByQuantizedMultiplier( - std::int32_t x, std::int32_t quantized_multiplier, int shift) { - int left_shift = shift > 0 ? shift : 0; - int right_shift = shift > 0 ? 0 : -shift; - return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( - x * (1 << left_shift), quantized_multiplier), - right_shift); -} - -// Helper to apply a fixed-point multiplier. Only 'applicable' if AccumScalar -// is int32 (i.e. in all cases except floating-point) and if the destination is -// not int32 (i.e. unless the user wants to get raw accumulators). -template ::value && - !std::is_same::value> -struct ApplyMultiplierImpl {}; - -// Specialization in non-applicable case: do nothing, just check that values -// are default. -template -struct ApplyMultiplierImpl { - using AccumScalar = typename Spec::AccumScalar; - using DstScalar = typename Spec::DstScalar; - static void Run(const Spec& spec, int row, AccumScalar* accum) { - RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0); - RUY_DCHECK_EQ(spec.multiplier_exponent, 0); - } -}; - -template -struct ApplyMultiplierImpl { - using AccumScalar = typename Spec::AccumScalar; - using DstScalar = typename Spec::DstScalar; - static void Run(const Spec& spec, int row, AccumScalar* accum) { - AccumScalar m = spec.multiplier_fixedpoint_perchannel - ? spec.multiplier_fixedpoint_perchannel[row] - : spec.multiplier_fixedpoint; - int e = spec.multiplier_exponent_perchannel - ? spec.multiplier_exponent_perchannel[row] - : spec.multiplier_exponent; - *accum = MultiplyByQuantizedMultiplier(*accum, m, e); - } -}; - -template -void ApplyMultiplier(const Spec& spec, int row, - typename Spec::AccumScalar* accum) { - ApplyMultiplierImpl::Run(spec, row, accum); -} - -template -struct Kernel { - using AccumScalar = typename Spec::AccumScalar; - using LhsLayout = typename Spec::StandardCppKernelLhsLayout; - using RhsLayout = typename Spec::StandardCppKernelRhsLayout; - explicit Kernel(Tuning) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, const Spec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - // See the comment in RunKernelTyped. end_row may be larger than - // dst->layout.rows. It's the responsibility of the kernel to avoid - // overrunning dst boundaries, which we do here by computing - // clamped_end_row. 
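// The zero-point and sums fix-ups applied in the reference kernel just below
// follow from the expansion
//   sum_d (lhs[d] - lhs_zp) * (rhs[d] - rhs_zp)
//     == sum_d lhs[d]*rhs[d] - rhs_zp*sum(lhs) - lhs_zp*sum(rhs)
//        + lhs_zp*rhs_zp*depth
// which is why kernels can accumulate raw products and correct them afterwards
// using the precomputed row/column sums. Minimal scalar check (illustrative,
// not part of this header):
#include <cstdint>

inline std::int32_t ZeroPointCorrectedAccum(const std::int8_t* lhs,
                                            const std::int8_t* rhs, int depth,
                                            std::int32_t lhs_zp,
                                            std::int32_t rhs_zp) {
  std::int32_t raw = 0, lhs_sum = 0, rhs_sum = 0;
  for (int d = 0; d < depth; ++d) {
    raw += static_cast<std::int32_t>(lhs[d]) * rhs[d];
    lhs_sum += lhs[d];
    rhs_sum += rhs[d];
  }
  return raw - rhs_zp * lhs_sum - lhs_zp * rhs_sum + lhs_zp * rhs_zp * depth;
}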
- int clamped_end_row = std::min(end_row, dst->layout.rows); - int clamped_end_col = std::min(end_col, dst->layout.cols); - RUY_DCHECK_LE(0, start_row); - RUY_DCHECK_LE(start_row, clamped_end_row); - RUY_DCHECK_LE(clamped_end_row, dst->layout.rows); - RUY_DCHECK_LE(clamped_end_row, end_row); - RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols); - RUY_DCHECK_LE(0, start_col); - RUY_DCHECK_LE(start_col, clamped_end_col); - RUY_DCHECK_LE(clamped_end_col, dst->layout.cols); - RUY_DCHECK_LE(clamped_end_col, end_col); - RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols); - profiler::ScopeLabel label("Kernel (Standard Cpp)"); - const int depth = lhs.layout.rows; - for (int i = start_row; i < clamped_end_row; i++) { - for (int j = start_col; j < clamped_end_col; j++) { - using AccumScalar = typename Spec::AccumScalar; - AccumScalar accum = 0; - for (int k = 0; k < depth; k++) { - AccumScalar lhs_val = Element(lhs, k, i); - AccumScalar rhs_val = Element(rhs, k, j); - accum += lhs_val * rhs_val; - } - if (spec.bias) { - accum += spec.bias[i]; - } - if (lhs.zero_point) { - accum -= lhs.zero_point * rhs.sums[j]; - } - if (rhs.zero_point) { - accum -= rhs.zero_point * lhs.sums[i]; - } - if (lhs.zero_point && rhs.zero_point) { - accum += lhs.zero_point * rhs.zero_point * depth; - } - ApplyMultiplier(spec, i, &accum); - accum += dst->zero_point; - accum = std::min(accum, spec.clamp_max); - accum = std::max(accum, spec.clamp_min); - *ElementPtr(dst, i, j) = static_cast(accum); - } - } - } -}; - -#define RUY_INHERIT_KERNEL(PARENT, CHILD) \ - template \ - struct Kernel \ - : Kernel { \ - explicit Kernel(Tuning tuning) \ - : Kernel(tuning) {} \ - }; - -#if RUY_PLATFORM(NEON) -RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon) -RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod) -#elif RUY_PLATFORM(X86) -RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kSse42) -RUY_INHERIT_KERNEL(Path::kSse42, Path::kAvx2) -RUY_INHERIT_KERNEL(Path::kAvx2, Path::kAvx512) -RUY_INHERIT_KERNEL(Path::kAvx512, Path::kAvxVnni) -#endif - -// KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code. -// -// In other cases, we still define (empty) versions, so that dummy kernels -// can use the classes in function signatures. 
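// Worked example (illustrative, not part of this header) of how the fixed-point
// helpers defined earlier compose. A real-valued multiplier of 0.375 is encoded
// as fixedpoint = round(0.75 * 2^31) = 1610612736 together with exponent = -1:
//   SaturatingRoundingDoublingHighMul(1000, 1610612736) == 750
//   RoundingDivideByPOT(750, 1)                         == 375
// so MultiplyByQuantizedMultiplier(1000, 1610612736, -1) == 375, i.e.
// round(1000 * 0.375).
#include <cassert>

inline void MultiplyByQuantizedMultiplierExample() {
  const std::int32_t fixedpoint = 1610612736;  // round(0.75 * 2^31)
  const int exponent = -1;                     // overall scale: 0.75 * 2^-1 = 0.375
  assert(MultiplyByQuantizedMultiplier(1000, fixedpoint, exponent) == 375);
}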
-#if ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \ - RUY_OPT_ENABLED(RUY_OPT_ASM)) || \ - RUY_PLATFORM(X86) - -#define RUY_ASM_FLAG_HAS_BIAS 0x1 -#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2 -#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4 -#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8 -#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10 - -#define RUY_ASM_TYPE_ID_UINT8 1 -#define RUY_ASM_TYPE_ID_INT8 2 -#define RUY_ASM_TYPE_ID_INT16 3 -#define RUY_ASM_TYPE_ID_INT32 4 - -template -struct DstTypeId {}; - -template <> -struct DstTypeId { - static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8; -}; - -template <> -struct DstTypeId { - static constexpr int kValue = RUY_ASM_TYPE_ID_INT8; -}; - -template <> -struct DstTypeId { - static constexpr int kValue = RUY_ASM_TYPE_ID_INT16; -}; - -template <> -struct DstTypeId { - static constexpr int kValue = RUY_ASM_TYPE_ID_INT32; -}; - -template -struct KernelParams8bit { - static constexpr int kMaxDstTypeSize = 4; - - const std::int32_t* bias; - const std::int32_t* lhs_sums; - const std::int32_t* rhs_sums; - const std::int8_t* lhs_base_ptr; - const std::int32_t* multiplier_fixedpoint; - const std::int32_t* multiplier_exponent; - const std::int8_t* rhs_base_ptr; - void* dst_base_ptr; - std::int32_t lhs_zero_point; - std::int32_t rhs_zero_point; - std::int32_t dst_zero_point; - std::int32_t prod_zp_depth; - std::int32_t start_row; - std::int32_t start_col; - std::int32_t last_row; - std::int32_t last_col; - std::int32_t dst_rows; - std::int32_t dst_cols; - std::int32_t lhs_stride; - std::int32_t rhs_stride; - std::int32_t dst_stride; - std::int32_t depth; - std::int32_t clamp_min; - std::int32_t clamp_max; - std::uint8_t flags; - std::uint8_t dst_type_id; - const std::int32_t zero_data[LhsCols] = {0}; - std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize]; - std::int32_t multiplier_fixedpoint_buf[LhsCols]; - std::int32_t multiplier_exponent_buf[LhsCols]; -}; - -template -void MakeKernelParams8bit(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, - int start_row, int start_col, int end_row, - int end_col, Matrix* dst, - KernelParams8bit* params) { - using Params = KernelParams8bit; - - static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, ""); - - const int depth = lhs.layout.rows; - RUY_DCHECK_EQ(start_row % LhsCols, 0); - RUY_DCHECK_EQ(start_col % RhsCols, 0); - RUY_DCHECK_EQ(end_row % LhsCols, 0); - RUY_DCHECK_EQ(end_col % RhsCols, 0); - - params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; - params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; - params->flags = 0; - params->bias = params->zero_data; - if (spec.bias) { - params->bias = spec.bias; - params->flags |= RUY_ASM_FLAG_HAS_BIAS; - } - if (lhs.sums) { - params->lhs_sums = lhs.sums; - params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS; - } - if (rhs.sums) { - params->rhs_sums = rhs.sums; - params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS; - } - params->start_row = start_row; - params->start_col = start_col; - params->last_row = end_row - LhsCols; - params->last_col = end_col - RhsCols; - params->lhs_stride = lhs.layout.stride; - params->rhs_stride = rhs.layout.stride; - params->dst_stride = sizeof(DstScalar) * dst->layout.stride; - params->lhs_zero_point = lhs.zero_point; - params->rhs_zero_point = rhs.zero_point; - params->dst_zero_point = dst->zero_point; - params->depth = depth; - params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth; - if (spec.multiplier_fixedpoint_perchannel) { - params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT; - params->flags |= 
RUY_ASM_FLAG_HAS_PERCHANNEL; - params->multiplier_fixedpoint = spec.multiplier_fixedpoint_perchannel; - params->multiplier_exponent = spec.multiplier_exponent_perchannel; - } else { - if (spec.multiplier_exponent > 0) { - params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT; - } - params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf; - params->multiplier_exponent = params->multiplier_exponent_buf; - for (int i = 0; i < LhsCols; i++) { - params->multiplier_fixedpoint_buf[i] = spec.multiplier_fixedpoint; - params->multiplier_exponent_buf[i] = spec.multiplier_exponent; - } - } - params->clamp_min = spec.clamp_min; - params->clamp_max = spec.clamp_max; - params->dst_rows = dst->layout.rows; - params->dst_cols = dst->layout.cols; - - RUY_DCHECK_LT(params->last_row, params->dst_rows); - RUY_DCHECK_LT(params->last_col, params->dst_cols); - - params->dst_type_id = DstTypeId::kValue; - params->dst_base_ptr = - dst->data.get() + start_col * dst->layout.stride + start_row; -} - -template -struct KernelParamsFloat { - const float* lhs_base_ptr; - const float* rhs_base_ptr; - float* dst_base_ptr; - const float* bias; - std::int32_t start_row; - std::int32_t start_col; - std::int32_t last_row; - std::int32_t last_col; - std::int32_t dst_rows; - std::int32_t dst_cols; - std::int32_t lhs_stride; - std::int32_t rhs_stride; - std::int32_t dst_stride; - std::int32_t depth; - float clamp_min; - float clamp_max; - std::uint8_t flags; - const float zero_data[LhsCols] = {0}; - float dst_tmp_buf[LhsCols * RhsCols]; -}; - -template -inline void MakeKernelParamsFloat(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, - int start_row, int start_col, int end_row, - int end_col, Matrix* dst, - KernelParamsFloat* params) { - const int depth = lhs.layout.rows; - RUY_DCHECK_EQ(start_row % LhsCols, 0); - RUY_DCHECK_EQ(start_col % RhsCols, 0); - RUY_DCHECK_EQ(end_row % LhsCols, 0); - RUY_DCHECK_EQ(end_col % RhsCols, 0); - - params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; - params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; - params->dst_base_ptr = - dst->data.get() + start_col * dst->layout.stride + start_row; - - std::uint8_t flags = 0; - params->bias = params->zero_data; - if (spec.bias) { - params->bias = spec.bias; - flags |= RUY_ASM_FLAG_HAS_BIAS; - } - params->flags = flags; - params->start_row = start_row; - params->start_col = start_col; - params->last_row = end_row - LhsCols; - params->last_col = end_col - RhsCols; - params->lhs_stride = sizeof(float) * lhs.layout.stride; - params->rhs_stride = sizeof(float) * rhs.layout.stride; - params->dst_stride = sizeof(float) * dst->layout.stride; - params->depth = depth; - params->clamp_min = spec.clamp_min; - params->clamp_max = spec.clamp_max; - params->dst_rows = dst->layout.rows; - params->dst_cols = dst->layout.cols; - - RUY_DCHECK_LT(params->last_row, params->dst_rows); - RUY_DCHECK_LT(params->last_col, params->dst_cols); -} - -#else // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && - // RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86) - -template -struct KernelParams8bit {}; - -template -struct KernelParamsFloat {}; - -#endif // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && - // RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_COMMON_H_ diff --git a/kernel_sse42.cc b/kernel_sse42.cc deleted file mode 100644 index 37196a6..0000000 --- a/kernel_sse42.cc +++ /dev/null @@ -1,428 +0,0 @@ -/* Copyright 2019 Google LLC. 
All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "check_macros.h" -#include "kernel.h" -#include "opt_set.h" -#include "platform.h" -#include "profiler/instrumentation.h" - -#if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Kernel8bitSse42(const KernelParams8bit<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void KernelFloatSse42(const KernelParamsFloat<8, 8>& params) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -static constexpr int kAvxFloatBlockSize = 8; -static constexpr int kAvx8bitBlockSize = 8; -static constexpr int kAvx8bitInnerSize = 4; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void Kernel8bitSse42(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label("Kernel kSse42 8-bit (UNFINISHED)"); - std::int32_t accum_data[kAvx8bitBlockSize][kAvx8bitBlockSize]; - - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; - - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; - void* dst_col_ptr = params.dst_base_ptr; - const std::int32_t* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; - col += kAvx8bitBlockSize) { - const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - void* dst_ptr = dst_col_ptr; - const std::int32_t* bias_ptr = bias_col_ptr; - - for (int row = params.start_row; row <= params.last_row; - row += kAvx8bitBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvx8bitBlockSize); - const int residual_cols = - std::min(params.dst_cols - col, kAvx8bitBlockSize); - - // Initialize with bias. 
- std::int32_t initial_accum_data[kAvx8bitBlockSize]; - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - initial_accum_data[i] = 0; - } - for (int i = 0; i < residual_rows; ++i) { - initial_accum_data[i] = bias_ptr[i]; - } - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = initial_accum_data[i]; - } - } - bias_ptr += bias_ptr_block_increment; - - std::int8_t lhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; - std::int8_t rhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; - const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - for (int x = 0; x < kAvx8bitInnerSize; ++x) { - lhs_data[i][x] = lhs_ptr[i * kAvx8bitInnerSize + x]; - rhs_data[i][x] = rhs_ptr[i * kAvx8bitInnerSize + x]; - } - } - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - for (int x = 0; x < kAvx8bitInnerSize; ++x) { - accum_data[j][i] += lhs_data[i][x] * rhs_data[j][x]; - } - } - } - lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - } - - if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] -= - params.rhs_zero_point * params.lhs_sums[row + i]; - } - } - } - if ((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && params.lhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] -= - params.lhs_zero_point * params.rhs_sums[col + j]; - } - } - } - if (params.lhs_zero_point && params.rhs_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] += params.prod_zp_depth; - } - } - } - - if (params.dst_type_id != DstTypeId::kValue) { - std::int32_t m_vector[kAvx8bitBlockSize]; - std::int32_t e_vector[kAvx8bitBlockSize]; - // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. - if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { - int i = 0; - for (; i < residual_rows; ++i) { - m_vector[i] = params.multiplier_fixedpoint[row + i]; - e_vector[i] = params.multiplier_exponent[row + i]; - } - for (; i < kAvx8bitBlockSize; ++i) { - m_vector[i] = m_vector[0]; - e_vector[i] = e_vector[0]; - } - } else { - // These arrays have size LhsCols, and are pre-filled. - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - m_vector[i] = params.multiplier_fixedpoint[i]; - e_vector[i] = params.multiplier_exponent[i]; - } - } - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = MultiplyByQuantizedMultiplier( - accum_data[j][i], m_vector[i], e_vector[i]); - } - } - - if (params.dst_zero_point) { - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] += params.dst_zero_point; - } - } - } - - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - accum_data[j][i] = - std::min(accum_data[j][i], params.clamp_max); - accum_data[j][i] = - std::max(accum_data[j][i], params.clamp_min); - } - } - } - - const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && - (residual_cols == kAvx8bitBlockSize); - - if (params.dst_type_id == DstTypeId::kValue) { - std::int8_t* tmp_ptr = - store_full_block - ? 
static_cast(dst_ptr) - : const_cast( - reinterpret_cast(params.dst_tmp_buf)); - const int block_col_offset = - store_full_block ? params.dst_stride / sizeof(std::int8_t) - : kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - - if (!store_full_block) { - const std::int8_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - static_cast( - dst_ptr)[j * params.dst_stride / sizeof(std::int8_t) + i] = - block_ptr[i]; - } - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - std::uint8_t* tmp_ptr = store_full_block - ? static_cast(dst_ptr) - : const_cast( - reinterpret_cast( - params.dst_tmp_buf)); - const int block_col_offset = - store_full_block ? params.dst_stride : kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - - if (!store_full_block) { - const std::uint8_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - static_cast( - dst_ptr)[j * params.dst_stride / sizeof(std::uint8_t) + i] = - block_ptr[i]; - } - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int16_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = params.dst_stride / sizeof(std::int16_t); - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - } else { - std::int16_t* tmp_ptr = const_cast( - reinterpret_cast(params.dst_tmp_buf)); - const int block_col_offset = kAvx8bitBlockSize; - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - const std::int16_t* block_ptr = - reinterpret_cast(params.dst_tmp_buf); - std::int16_t* dst_block_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_block_ptr[i] = block_ptr[i]; - } - dst_block_ptr += params.dst_stride / sizeof(std::int16_t); - block_ptr += kAvx8bitBlockSize; - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else if (params.dst_type_id == DstTypeId::kValue) { - if (store_full_block) { - std::int32_t* tmp_ptr = static_cast(dst_ptr); - const int block_col_offset = params.dst_stride / sizeof(std::int32_t); - for (int j = 0; j < kAvx8bitBlockSize; ++j) { - for (int i = 0; i < kAvx8bitBlockSize; ++i) { - tmp_ptr[i] = accum_data[j][i]; - } - tmp_ptr += block_col_offset; - } - } else { - std::int32_t* dst_block_ptr = static_cast(dst_ptr); - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_block_ptr[i] = accum_data[j][i]; - } - dst_block_ptr += params.dst_stride / sizeof(std::int32_t); - } - } - dst_ptr = static_cast(static_cast(dst_ptr) + - kAvx8bitBlockSize); - } else { - RUY_DCHECK(false); - } - - lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; - } // End row-block loop. 
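The zero-point corrections and the prod_zp_depth term applied in the block loop above follow from expanding sum_k (lhs[k] - lhs_zero_point) * (rhs[k] - rhs_zero_point): accumulate the raw int8 products, then subtract rhs_zero_point times the LHS sums and lhs_zero_point times the RHS sums, and add lhs_zero_point * rhs_zero_point * depth. A minimal standalone sketch of that identity for a single output element (illustrative code, not part of ruy):

    #include <cassert>
    #include <cstdint>

    int main() {
      const int depth = 4;
      const std::int8_t lhs[depth] = {1, -2, 3, 4};  // one LHS row
      const std::int8_t rhs[depth] = {5, 6, -7, 8};  // one RHS column
      const std::int32_t lhs_zp = 3, rhs_zp = -5;

      std::int32_t direct = 0, raw = 0, lhs_sum = 0, rhs_sum = 0;
      for (int k = 0; k < depth; ++k) {
        direct += (lhs[k] - lhs_zp) * (rhs[k] - rhs_zp);  // the desired result
        raw += lhs[k] * rhs[k];                           // what the kernel accumulates
        lhs_sum += lhs[k];
        rhs_sum += rhs[k];
      }
      // Correction applied after the depth loop, as in the kernel above.
      const std::int32_t corrected =
          raw - rhs_zp * lhs_sum - lhs_zp * rhs_sum + lhs_zp * rhs_zp * depth;
      assert(direct == corrected);
      return 0;
    }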
- - dst_col_ptr = static_cast(static_cast(dst_col_ptr) + - kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; - } // End col-block loop. -} // NOLINT(readability/fn_size) - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void KernelFloatSse42(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label("Kernel kSse42 float (UNFINISHED)"); - - float lhs_data[kAvxFloatBlockSize]; - float rhs_data[kAvxFloatBlockSize]; - float accum_data[kAvxFloatBlockSize][kAvxFloatBlockSize]; - int bias_ptr_block_increment = - params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvxFloatBlockSize : 0; - - const float* rhs_col_ptr = params.rhs_base_ptr; - float* dst_col_ptr = params.dst_base_ptr; - const float* bias_col_ptr = params.bias; - if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { - bias_col_ptr += params.start_row; - } - - for (int col = params.start_col; col <= params.last_col; - col += kAvxFloatBlockSize) { - const float* lhs_col_ptr = params.lhs_base_ptr; - float* dst_ptr = dst_col_ptr; - const float* bias_ptr = bias_col_ptr; - - for (int row = params.start_row; row <= params.last_row; - row += kAvxFloatBlockSize) { - const int residual_rows = - std::min(params.dst_rows - row, kAvxFloatBlockSize); - const int residual_cols = - std::min(params.dst_cols - col, kAvxFloatBlockSize); - - // Initialize with bias. - float initial_accum_data[kAvxFloatBlockSize]; - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - initial_accum_data[i] = 0.0f; - } - for (int i = 0; i < residual_rows; ++i) { - initial_accum_data[i] = bias_ptr[i]; - } - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] = initial_accum_data[i]; - } - } - bias_ptr += bias_ptr_block_increment; - - const float* lhs_ptr = lhs_col_ptr; - const float* rhs_ptr = rhs_col_ptr; - for (int d = 0; d < params.depth; ++d) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - lhs_data[i] = lhs_ptr[i]; - rhs_data[i] = rhs_ptr[i]; - } - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] += lhs_data[i] * rhs_data[j]; - } - } - lhs_ptr += kAvxFloatBlockSize; - rhs_ptr += kAvxFloatBlockSize; - } - - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - accum_data[j][i] = - std::min(accum_data[j][i], params.clamp_max); - accum_data[j][i] = - std::max(accum_data[j][i], params.clamp_min); - } - } - - const bool store_full_block = (residual_rows == kAvxFloatBlockSize) && - (residual_cols == kAvxFloatBlockSize); - - { - float* block_ptr = - store_full_block ? dst_ptr : const_cast(params.dst_tmp_buf); - const int block_col_offset = store_full_block - ? 
params.dst_stride / sizeof(float) - : kAvxFloatBlockSize; - for (int j = 0; j < kAvxFloatBlockSize; ++j) { - for (int i = 0; i < kAvxFloatBlockSize; ++i) { - block_ptr[i] = accum_data[j][i]; - } - block_ptr += block_col_offset; - } - } - if (!store_full_block) { - const float* block_ptr = params.dst_tmp_buf; - for (int j = 0; j < residual_cols; ++j) { - for (int i = 0; i < residual_rows; ++i) { - dst_ptr[j * params.dst_stride / sizeof(float) + i] = block_ptr[i]; - } - block_ptr += kAvxFloatBlockSize; - } - } - - lhs_col_ptr += kAvxFloatBlockSize * params.lhs_stride / sizeof(float); - dst_ptr += kAvxFloatBlockSize; - } // End row-block loop. - - dst_col_ptr += kAvxFloatBlockSize * params.dst_stride / sizeof(float); - rhs_col_ptr += kAvxFloatBlockSize * params.rhs_stride / sizeof(float); - } // End col-block loop. -} - -#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/kernel_x86.h b/kernel_x86.h deleted file mode 100644 index d6ce72a..0000000 --- a/kernel_x86.h +++ /dev/null @@ -1,222 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_X86_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_X86_H_ - -#include - -#include "common.h" -#include "internal_matrix.h" -#include "kernel_common.h" -#include "matrix.h" -#include "opt_set.h" -#include "path.h" -#include "platform.h" -#include "spec.h" -#include "tune.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. 
-// -void Kernel8bitSse42(const KernelParams8bit<8, 8>& params); - -template -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - Kernel8bitSse42(params); - } -}; - -void KernelFloatSse42(const KernelParamsFloat<8, 8>& params); - -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - KernelFloatSse42(params); - } -}; - -void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params); -void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params); - -template -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - if (dst->layout.cols == 1) { - Kernel8bitAvx512SingleCol(params); - } else { - Kernel8bitAvx512(params); - } - } -}; - -void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params); -void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& param); - -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - if (dst->layout.cols == 1) { - KernelFloatAvx512SingleCol(params); - } else { - KernelFloatAvx512(params); - } - } -}; - -void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params); -void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params); - -template -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - if (dst->layout.cols == 1) { - Kernel8bitAvx2SingleCol(params); - } else { - Kernel8bitAvx2(params); - } - } -}; - -void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params); -void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params); - -template <> -struct 
Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - if (dst->layout.cols == 1) { - KernelFloatAvx2SingleCol(params); - } else { - KernelFloatAvx2(params); - } - } -}; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params); - -template -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, - const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, - int start_col, int end_row, int end_col, - Matrix* dst) const { - KernelParams8bit params; - MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, - dst, ¶ms); - Kernel8bitAvxVnni(params); - } -}; - -void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params); - -template <> -struct Kernel> { - Tuning tuning = Tuning::kAuto; - using LhsLayout = FixedKernelLayout; - using RhsLayout = FixedKernelLayout; - explicit Kernel(Tuning tuning_) : tuning(tuning_) {} - void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, - const BasicSpec& spec, int start_row, int start_col, - int end_row, int end_col, Matrix* dst) const { - KernelParamsFloat params; - MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, - end_col, dst, ¶ms); - KernelFloatAvxVnni(params); - } -}; - -#endif // RUY_PLATFORM(X86) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_KERNEL_X86_H_ diff --git a/matrix.h b/matrix.h deleted file mode 100644 index 8c5fbf9..0000000 --- a/matrix.h +++ /dev/null @@ -1,182 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_MATRIX_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_MATRIX_H_ - -#include -#include // IWYU pragma: keep -#include - -#include "check_macros.h" - -namespace ruy { - -// Layout storage order. Here and elsewhere, 'col' is short for 'column'. -// 'column-major' means that each column is contiguous in memory. -enum class Order : std::uint8_t { kColMajor, kRowMajor }; - -// Describes the shape and storage layout of a matrix. -struct Layout final { - std::int32_t rows = 0; - std::int32_t cols = 0; - // Stride is the offset between two adjacent matrix elements - // in the non-contiguous direction. 
- std::int32_t stride = 0; - Order order = Order::kColMajor; -}; - -namespace detail { - -// Thin wrapper around a pointer that tracks its constness dynamically. -// -// This is our take on the C++ problem of enforcing constness of data -// wrapped in a containers class: it's not worth the hassle of trying to -// make it fully work at compile-time. -// Instead, we only enforce constness at runtime, and to make it -// zero-overhead, we only enforce it in debug builds. -template -class ConstCheckingPtr final { - public: - using element_type = T; - - // Convenience methods. Most `set` calls go through these. - ConstCheckingPtr& operator=(T* ptr) { - set(ptr); - return *this; - } - ConstCheckingPtr& operator=(const T* ptr) { - set(ptr); - return *this; - } - ConstCheckingPtr& operator=(std::nullptr_t) { - set(static_cast(nullptr)); - return *this; - } - - // Core accessors. These encapsulate the main logic: - // - for `set`, the constness of the argument determines whether internal - // pointer should be tracked as const/mutable. - // - for `get`, the constness of `this` determines whether the call - // counts as a const or mutable use of the internal pointer. - void set(T* ptr) { - ptr_ = ptr; - set_mutable(true); - } - void set(const T* ptr) { - ptr_ = ptr; - set_mutable(false); - } - T* get() /* NOT const */ { - assert_mutable(); - return const_cast(ptr_); - } - const T* get() const { return ptr_; } - - private: - static_assert(!std::is_const::value, ""); - const T* ptr_ = nullptr; -#ifndef NDEBUG - bool is_mutable_ = true; - void set_mutable(bool val) { is_mutable_ = val; } - void assert_mutable() { RUY_DCHECK(is_mutable_); } -#else - void set_mutable(bool) {} - void assert_mutable() {} -#endif -}; - -} // namespace detail - -// A Matrix is really what Eigen and gemmlowp would have called a 'matrix map': -// it merely wraps existing data as a matrix. It doesn't own any buffer. -// Scalar may be any floating-point or integral type. When integral, it may be -// signed or unsigned. -template -struct Matrix final { - Matrix& operator=(const Matrix& other) { - data = other.data; - cacheable = other.cacheable; - layout = other.layout; - zero_point = other.zero_point; - return *this; - } - - // The underlying buffer wrapped by this matrix. - detail::ConstCheckingPtr data; - // The shape and data layout of this matrix. - Layout layout; - // The zero_point, i.e. which Scalar value is to be interpreted as zero. - // When Scalar is floating-point, this must be 0. - Scalar zero_point = 0; - // Clients of Ruy must set this flag to enable any caching behavior. Doesn't - // impact numerical results, but caching can impact observable metrics like - // latency, memory usage, power, etc. - bool cacheable = false; -}; - -inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) { - layout->rows = rows; - layout->cols = cols; - layout->order = order; - layout->stride = order == Order::kColMajor ? rows : cols; -} - -// Opaque data structure representing a pre-packed matrix, as obtained from -// Ruy's advanced API. 
-struct PrepackedMatrix { - void* data = nullptr; - std::size_t data_size = 0; - void* sums = nullptr; - std::size_t sums_size = 0; -}; - -template -StreamType& operator<<(StreamType& stream, const Matrix& mat) { - for (int row = 0; row < mat.layout.rows; row++) { - for (int col = 0; col < mat.layout.cols; col++) { - stream << static_cast(Element(mat, row, col)) << " "; - } - stream << "\n"; - } - return stream; -} - -// Compile-time version of KernelLayout, used to declare kernel layouts in a -// way that can be consumed by compile-time logic. -// See how partial specializations of Kernel use it to declare their layouts. -// The only reason why this is currently part of the public API is to -// allow testing various layouts for the Path::kStandardCpp kernel, as a -// testing-only feature. See Spec::StandardCppKernelLhsLayout. -template -struct FixedKernelLayout { - static constexpr Order kOrder = tOrder; - static constexpr int kRows = tRows; - static constexpr int kCols = tCols; -}; - -#if (__cplusplus < 201703L) -// A static constexpr data member is automatically inline and should not require -// redeclaration without an initializer. This is actually deprecated from C++17 -// onwards. Clang with -O0 without this can fail to link. -template -constexpr int FixedKernelLayout::kCols; -template -constexpr int FixedKernelLayout::kRows; -#endif - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_MATRIX_H_ diff --git a/opt_set.h b/opt_set.h deleted file mode 100644 index d082ade..0000000 --- a/opt_set.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_OPT_SET_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_OPT_SET_H_ - -// RUY_OPT_SET is a compile-time API that Ruy provides for enabling/disabling -// certain optimizations. It should be used by defining that macro on the -// compiler command line. -// -// Each bit in RUY_OPT_SET controls a particular optimization done in Ruy. -#define RUY_OPT_INTRINSICS 0x1 -#define RUY_OPT_ASM 0x2 -#define RUY_OPT_TUNING 0x4 -#define RUY_OPT_FAT_KERNEL 0x8 -#define RUY_OPT_NATIVE_ROUNDING 0x10 -#define RUY_OPT_AVOID_ALIASING 0x20 -#define RUY_OPT_MAX_STREAMING 0x40 -#define RUY_OPT_PACK_AHEAD 0x80 -#define RUY_OPT_PREFETCH_LOAD 0x100 -#define RUY_OPT_PREFETCH_STORE 0x200 -#define RUY_OPT_FRACTAL_Z 0x400 -#define RUY_OPT_FRACTAL_U 0x800 -#define RUY_OPT_FRACTAL_HILBERT 0x1000 - -#if !defined(RUY_OPT_SET) -#ifdef RUY_OPTIMIZE_FOR_MATMUL_BENCHMARK -// Load prefetching is detrimental in matrix multiplication benchmarks. -// Store prefetching is not. -#define RUY_OPT_SET (~RUY_OPT_PREFETCH_LOAD) -#else -// Default to all optimizations. 
-#define RUY_OPT_SET (~0) -#endif -#endif - -#define RUY_OPT_ENABLED(ruy_opt) ((RUY_OPT_SET & ruy_opt) != 0) - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_OPT_SET_H_ diff --git a/pack.h b/pack.h deleted file mode 100644 index 4aaec2e..0000000 --- a/pack.h +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// # What is "packing"? -// -// Before feeding data to the gemm kernels (the parts of Ruy that do lots -// of multiply-add operations), Ruy first performs a data transformation (which -// we call "packing") on the input matrices. This transformation has two main -// goals: -// - rearrange data into blocks that are a convenient size/layout for the gemm -// kernels to consume. This helps make the memory access pattern of the gemm -// kernel simpler and more contiguous, and puts the data in a layout most -// convenient for specific arithmetic instructions in the gemm kernel. -// - compute row/column sums needed for handling quantization with non-symmetric -// zero points. -// -// # Simplified algorithmic analysis of packing -// -// Packing is a relatively simple transformation which does a small constant -// amount of work on each element of an input matrix, and hence for an NxM -// matrix performs O(N*M) work. If N and M are of the same order, then this is -// O(N^2) work. -// -// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. -// Note that if N, K, and M are all the same order, then the number of -// multiply-accumulate operations is O(N^3). -// -// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the -// case of all dimensions being roughly the same order. -// -// # Packing cost can be significant -// -// When matrix * matrix multiplications begin to look more like matrix * vector -// multiplications, packing cost can become significant. We sometimes call these -// cases "gemv-like". -// -// Continuing the algorithmic analysis above, if we consider a case where an -// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the -// situation is different. In this case, the multiply-accumulate work is only -// quadratic, so the quadratic cost of packing can be come significant. -// -// Another way to say this is that the cost of packing an input matrix (either -// the LHS or RHS) is amortized across the non-depth dimension of the opposite -// input matrix. Thus, when the LHS has very few rows or the RHS has very few -// columns, the cost of packing the opposite input matrix can become -// significant. -// -// As a rough rule of thumb, the cost of packing starts to become significant -// when either N or M is below 32 (and other dimensions are hundreds), with very -// significant packing costs at 8 or below. This varies by data type, Path, and -// tuning, so these numbers are only rough guides. 
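To make the rule of thumb above concrete, here is a quick back-of-envelope comparison (standalone illustrative code, not part of ruy) of the elements touched when packing both inputs versus the multiply-accumulate count, for a roughly square shape and for a gemv-like shape with M = 8:

    #include <cstdio>

    int main() {
      const long long shapes[][3] = {{512, 512, 512}, {512, 512, 8}};  // {N, K, M}
      for (const auto& s : shapes) {
        const long long N = s[0], K = s[1], M = s[2];
        const long long pack_work = N * K + K * M;  // elements packed (LHS + RHS)
        const long long mac_work = N * K * M;       // multiply-accumulates
        std::printf("N=%lld K=%lld M=%lld  packing/mac ratio = %.4f\n", N, K, M,
                    static_cast<double>(pack_work) / static_cast<double>(mac_work));
      }
      return 0;
    }

For the square shape the packed-element count is about 0.4% of the multiply-accumulate count; for the gemv-like shape it is already about 13%.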
-// -// One practical use case that is affected by this is inference of -// fully connected neural network layers with a low batch size. The weight -// matrix (which is a constant for inference) is the one affected by significant -// packing cost. -// -// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack -// input matrices that are affected by significant packing costs. -// -// # Implementation notes -// -// Ruy's packing routines always operate on a range of columns and can be -// applied to either the LHS or RHS. This is possible because Ruy internally -// implements a TrMul, so the accumulation along depth is done along columns of -// both the LHS and RHS (whereas for a normal Mul the accumulation along depth -// for the LHS is along rows). As another example, we are always computing -// column sums for quantization (and never row sums, since the LHS is -// transposed). - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_H_ - -#include "platform.h" - -// IWYU pragma: begin_exports -#if RUY_PLATFORM(NEON) -#include "pack_arm.h" -#elif RUY_PLATFORM(X86) -#include "pack_x86.h" -#else -#include "pack_common.h" -#endif -// IWYU pragma: end_exports - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_H_ diff --git a/pack_arm.cc b/pack_arm.cc deleted file mode 100644 index 549e615..0000000 --- a/pack_arm.cc +++ /dev/null @@ -1,1936 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include - -#include "common.h" -#include "opt_set.h" -#include "pack.h" -#include "platform.h" -#include "profiler/instrumentation.h" - -namespace ruy { - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor) { - profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); - asm volatile( - // clang-format off - "dup v26.16b, %w[input_xor]\n" - "mov w1, #0\n" - "dup v28.4s, wzr\n" - "dup v29.4s, wzr\n" - "dup v30.4s, wzr\n" - "dup v31.4s, wzr\n" - - "and w2, %w[rows], #-16\n" - "cmp w1, w2\n" - "beq 3f\n" - - "add w1, w1, #16\n" - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "cmp w1, w2\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - "beq 2f\n" - - "1:\n" - - "add w1, w1, #16\n" - "eor v4.16b, v0.16b, v26.16b\n" - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "eor v5.16b, v1.16b, v26.16b\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "eor v6.16b, v2.16b, v26.16b\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "eor v7.16b, v3.16b, v26.16b\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - - "saddlp v16.8h, v4.16b\n" - "str q4, [%[packed_ptr], #0]\n" - "saddlp v17.8h, v5.16b\n" - "str q5, [%[packed_ptr], #16]\n" - "saddlp v18.8h, v6.16b\n" - "str q6, [%[packed_ptr], #32]\n" - "saddlp v19.8h, v7.16b\n" - "str q7, [%[packed_ptr], #48]\n" - "sadalp v28.4s, v16.8h\n" - "cmp w1, w2\n" - "sadalp v29.4s, v17.8h\n" - "add %[packed_ptr], %[packed_ptr], #64\n" - "sadalp v30.4s, v18.8h\n" - "sadalp v31.4s, v19.8h\n" - - "bne 1b\n" - - "2:\n" - - "eor v4.16b, v0.16b, v26.16b\n" - "eor v5.16b, v1.16b, v26.16b\n" - "eor v6.16b, v2.16b, v26.16b\n" - "eor v7.16b, v3.16b, v26.16b\n" - - "saddlp v16.8h, v4.16b\n" - "str q4, [%[packed_ptr], #0]\n" - "saddlp v17.8h, v5.16b\n" - "str q5, [%[packed_ptr], #16]\n" - "saddlp v18.8h, v6.16b\n" - "str q6, [%[packed_ptr], #32]\n" - "saddlp v19.8h, v7.16b\n" - "str q7, [%[packed_ptr], #48]\n" - "sadalp v28.4s, v16.8h\n" - "sadalp v29.4s, v17.8h\n" - "sadalp v30.4s, v18.8h\n" - "sadalp v31.4s, v19.8h\n" - - "add %[packed_ptr], %[packed_ptr], #64\n" - - "3:\n" - - "ands w2, %w[rows], #15\n" - "beq 4f\n" - "dup v0.16b, %w[src_zero_point]\n" - "dup v1.16b, %w[src_zero_point]\n" - "dup v2.16b, %w[src_zero_point]\n" - "dup v3.16b, %w[src_zero_point]\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ - "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ - "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ - "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) - RUY_LOAD_ONE_ROW(4) - RUY_LOAD_ONE_ROW(5) - RUY_LOAD_ONE_ROW(6) - RUY_LOAD_ONE_ROW(7) - RUY_LOAD_ONE_ROW(8) - RUY_LOAD_ONE_ROW(9) - RUY_LOAD_ONE_ROW(10) - RUY_LOAD_ONE_ROW(11) - RUY_LOAD_ONE_ROW(12) - RUY_LOAD_ONE_ROW(13) - RUY_LOAD_ONE_ROW(14) - RUY_LOAD_ONE_ROW(15) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "eor v4.16b, v0.16b, v26.16b\n" - "eor v5.16b, v1.16b, v26.16b\n" - "eor v6.16b, v2.16b, v26.16b\n" - "eor v7.16b, v3.16b, v26.16b\n" - - "saddlp v16.8h, v4.16b\n" 
- "saddlp v17.8h, v5.16b\n" - "saddlp v18.8h, v6.16b\n" - "saddlp v19.8h, v7.16b\n" - "sadalp v28.4s, v16.8h\n" - "sadalp v29.4s, v17.8h\n" - "sadalp v30.4s, v18.8h\n" - "sadalp v31.4s, v19.8h\n" - - "str q4, [%[packed_ptr], #0]\n" - "str q5, [%[packed_ptr], #16]\n" - "str q6, [%[packed_ptr], #32]\n" - "str q7, [%[packed_ptr], #48]\n" - "add %[packed_ptr], %[packed_ptr], #64\n" - - "4:\n" - - "addp v28.4s, v28.4s, v29.4s\n" - "addp v30.4s, v30.4s, v31.4s\n" - "addp v28.4s, v28.4s, v30.4s\n" - - "cmp %[sums_ptr], #0\n" - "beq 6f\n" - "st1 {v28.4s}, [%[sums_ptr]], #16\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), - [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), - [ src_inc1 ] "r"(static_cast(src_inc1)), - [ src_inc2 ] "r"(static_cast(src_inc2)), - [ src_inc3 ] "r"(static_cast(src_inc3)), - [ rows ] "r"(src_rows), [ src_zero_point ] "r"(src_zero_point), - [ input_xor ] "r"(input_xor) - : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6", - "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", - "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", - "v27", "v28", "v29", "v30", "v31"); -} -#endif - -#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#define RUY_OFFSET_SRC_PTR0 0 -#define RUY_OFFSET_SRC_PTR1 4 -#define RUY_OFFSET_SRC_PTR2 8 -#define RUY_OFFSET_SRC_PTR3 12 -#define RUY_OFFSET_SUMS_PTR 16 -#define RUY_OFFSET_PACKED_PTR 20 -#define RUY_OFFSET_SRC_INC0 24 -#define RUY_OFFSET_SRC_INC1 28 -#define RUY_OFFSET_SRC_INC2 32 -#define RUY_OFFSET_SRC_INC3 36 -#define RUY_OFFSET_SRC_ROWS 40 -#define RUY_OFFSET_SRC_ZERO_POINT 44 -#define RUY_OFFSET_INPUT_XOR 48 - -template -void CheckOffsetsInPackParams8bit(const Params&) { - static_assert(offsetof(Params, src_ptr0) == RUY_OFFSET_SRC_PTR0, ""); - static_assert(offsetof(Params, src_ptr1) == RUY_OFFSET_SRC_PTR1, ""); - static_assert(offsetof(Params, src_ptr2) == RUY_OFFSET_SRC_PTR2, ""); - static_assert(offsetof(Params, src_ptr3) == RUY_OFFSET_SRC_PTR3, ""); - static_assert(offsetof(Params, sums_ptr) == RUY_OFFSET_SUMS_PTR, ""); - static_assert(offsetof(Params, packed_ptr) == RUY_OFFSET_PACKED_PTR, ""); - static_assert(offsetof(Params, src_inc0) == RUY_OFFSET_SRC_INC0, ""); - static_assert(offsetof(Params, src_inc1) == RUY_OFFSET_SRC_INC1, ""); - static_assert(offsetof(Params, src_inc2) == RUY_OFFSET_SRC_INC2, ""); - static_assert(offsetof(Params, src_inc3) == RUY_OFFSET_SRC_INC3, ""); - static_assert(offsetof(Params, src_rows) == RUY_OFFSET_SRC_ROWS, ""); - static_assert(offsetof(Params, src_zero_point) == RUY_OFFSET_SRC_ZERO_POINT, - ""); - static_assert(offsetof(Params, input_xor) == RUY_OFFSET_INPUT_XOR, ""); -} - -// Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9. -// No attempt made at making this code efficient on in-order cores yet. 
-void Pack8bitNeonOutOfOrder4Cols(const PackParams8bit& params) { - CheckOffsetsInPackParams8bit(params); - profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); - const void* src_ptr0 = params.src_ptr0; - const void* src_ptr1 = params.src_ptr1; - const void* src_ptr2 = params.src_ptr2; - const void* src_ptr3 = params.src_ptr3; - const int src_inc0 = params.src_inc0; - const int src_inc1 = params.src_inc1; - const int src_inc2 = params.src_inc2; - const int src_inc3 = params.src_inc3; - const std::int8_t* packed_ptr = params.packed_ptr; - - asm volatile( - // clang-format off - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n" - "vdup.8 q11, r2\n" - "mov r1, #0\n" - // Zero-out the accumulators - "vmov.i32 q12, #0\n" - "vmov.i32 q13, #0\n" - "vmov.i32 q14, #0\n" - "vmov.i32 q15, #0\n" - - // Round down src_rows to nearest multiple of 16. - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" - "and r2, r3, #-16\n" - "cmp r1, r2\n" - "beq 3f\n" - - "1:\n" - "add r1, r1, #16\n" - /* Load q0 */ - "vld1.8 {d0, d1}, [%[src_ptr0]]\n" - "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n" - RUY_PREFETCH_LOAD("pld [%[src_ptr0]]\n") - - /* Load q1 */ - "vld1.8 {d2, d3}, [%[src_ptr1]]\n" - "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n" - RUY_PREFETCH_LOAD("pld [%[src_ptr1]]\n") - - "veor.8 q4, q0, q11\n" - "veor.8 q5, q1, q11\n" - - // Pairwise add in to 16b accumulators. - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - // Pairwise add accumulate into 32b accumulators. - // q12 and q13 contain 4x32b accumulators - "vpadal.s16 q12, q8\n" - "vpadal.s16 q13, q9\n" - - // Now do the same for src_ptr2 and src_ptr3. - "vld1.8 {d0, d1}, [%[src_ptr2]]\n" - "add %[src_ptr2], %[src_ptr2], %[src_inc2]\n" - RUY_PREFETCH_LOAD("pld [%[src_ptr2]]\n") - - "vld1.8 {d2, d3}, [%[src_ptr3]]\n" - "add %[src_ptr3], %[src_ptr3], %[src_inc3]\n" - RUY_PREFETCH_LOAD("pld [%[src_ptr3]]\n") - - "veor.8 q4, q0, q11\n" - "veor.8 q5, q1, q11\n" - - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - // Pairwise add accumulate into 32b accumulators. - // q14 and q15 contain 4x32b accumulators - "vpadal.s16 q14, q8\n" - "vpadal.s16 q15, q9\n" - - "cmp r1, r2\n" - "bne 1b\n" - - "3:\n" - - // Now pack the last (num_rows % 16) rows. - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" - "ands r2, r3, #15\n" - "beq 4f\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n" - "vdup.8 q0, r3\n" - "vdup.8 q1, r3\n" - -// First, read/accumulate/write for src_ptr0 and src_ptr1. 
-#define RUY_LOAD_ONE_ROW1(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \ - "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \ - - RUY_LOAD_ONE_ROW1(0, 0) - RUY_LOAD_ONE_ROW1(1, 1) - RUY_LOAD_ONE_ROW1(2, 2) - RUY_LOAD_ONE_ROW1(3, 3) - RUY_LOAD_ONE_ROW1(4, 4) - RUY_LOAD_ONE_ROW1(5, 5) - RUY_LOAD_ONE_ROW1(6, 6) - RUY_LOAD_ONE_ROW1(7, 7) -#undef RUY_LOAD_ONE_ROW1 - -#define RUY_LOAD_ONE_ROW2(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \ - "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \ - - RUY_LOAD_ONE_ROW2(8, 0) - RUY_LOAD_ONE_ROW2(9, 1) - RUY_LOAD_ONE_ROW2(10, 2) - RUY_LOAD_ONE_ROW2(11, 3) - RUY_LOAD_ONE_ROW2(12, 4) - RUY_LOAD_ONE_ROW2(13, 5) - RUY_LOAD_ONE_ROW2(14, 6) - RUY_LOAD_ONE_ROW2(15, 7) -#undef RUY_LOAD_ONE_ROW2 - - "5:\n" - - "veor.16 q4, q0, q11\n" - "veor.16 q5, q1, q11\n" - - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - // Pairwise add accumulate to 4x32b accumulators. - "vpadal.s16 q12, q8\n" - "vpadal.s16 q13, q9\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - // Reset to src_zero for src_ptr2 and src_ptr3. - "vdup.8 q0, r3\n" - "vdup.8 q1, r3\n" - -// Next, read/accumulate/write for src_ptr2 and src_ptr3. -#define RUY_LOAD_ONE_ROW1(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d0[" #R "]}, [%[src_ptr2]]!\n" \ - "vld1.8 { d2[" #R "]}, [%[src_ptr3]]!\n" \ - - RUY_LOAD_ONE_ROW1(0, 0) - RUY_LOAD_ONE_ROW1(1, 1) - RUY_LOAD_ONE_ROW1(2, 2) - RUY_LOAD_ONE_ROW1(3, 3) - RUY_LOAD_ONE_ROW1(4, 4) - RUY_LOAD_ONE_ROW1(5, 5) - RUY_LOAD_ONE_ROW1(6, 6) - RUY_LOAD_ONE_ROW1(7, 7) -#undef RUY_LOAD_ONE_ROW1 - -#define RUY_LOAD_ONE_ROW2(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d1[" #R "]}, [%[src_ptr2]]!\n" \ - "vld1.8 { d3[" #R "]}, [%[src_ptr3]]!\n" \ - - RUY_LOAD_ONE_ROW2(8, 0) - RUY_LOAD_ONE_ROW2(9, 1) - RUY_LOAD_ONE_ROW2(10, 2) - RUY_LOAD_ONE_ROW2(11, 3) - RUY_LOAD_ONE_ROW2(12, 4) - RUY_LOAD_ONE_ROW2(13, 5) - RUY_LOAD_ONE_ROW2(14, 6) - RUY_LOAD_ONE_ROW2(15, 7) -#undef RUY_LOAD_ONE_ROW2 - - "5:\n" - - "veor.16 q4, q0, q11\n" - "veor.16 q5, q1, q11\n" - - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - // Pairwise add accumulate to 4x32b accumulators. - "vpadal.s16 q14, q8\n" - "vpadal.s16 q15, q9\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - "4:\n" - // Pairwise add 32-bit accumulators - "vpadd.i32 d24, d24, d25\n" - "vpadd.i32 d26, d26, d27\n" - "vpadd.i32 d28, d28, d29\n" - "vpadd.i32 d30, d30, d31\n" - // Final 32-bit values per row - "vpadd.i32 d25, d24, d26\n" - "vpadd.i32 d27, d28, d30\n" - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n" - "cmp r3, #0\n" - "beq 6f\n" - "vst1.32 {d25}, [r3]!\n" - "vst1.32 {d27}, [r3]!\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3) - : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1), - [ src_inc2 ] "r"(src_inc2), [ src_inc3 ] "r"(src_inc3), - [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(¶ms) - : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3", - "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13"); -} - -// Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9. -// No attempt made at making this code efficient on in-order cores yet. -// This version differs from the above in that we only handle two columns -// at a time. 
-void Pack8bitNeonOutOfOrder2Cols(const PackParams8bit& params) { - CheckOffsetsInPackParams8bit(params); - profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); - const void* src_ptr0 = params.src_ptr0; - const void* src_ptr1 = params.src_ptr1; - const int src_inc0 = params.src_inc0; - const int src_inc1 = params.src_inc1; - const std::int8_t* packed_ptr = params.packed_ptr; - - asm volatile( - // clang-format off - - "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n" - "vdup.8 q11, r2\n" - "mov r1, #0\n" - // Zero-out the accumulators - "vmov.i32 q12, #0\n" - "vmov.i32 q13, #0\n" - - // Round down src_rows to nearest multiple of 16. - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" - "and r2, r3, #-16\n" - "cmp r1, r2\n" - "beq 3f\n" - - "1:\n" - "add r1, r1, #16\n" - /* Load q0 */ - "vld1.8 {d0, d1}, [%[src_ptr0]]\n" - "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n" - - /* Load q1 */ - "vld1.8 {d2, d3}, [%[src_ptr1]]\n" - "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n" - - "veor.8 q4, q0, q11\n" - "veor.8 q5, q1, q11\n" - - // Pairwise add in to 16b accumulators. - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - // Pairwise add accumulate into 32b accumulators. - // q12 and q13 contain 4x32b accumulators - "vpadal.s16 q12, q8\n" - "vpadal.s16 q13, q9\n" - - "cmp r1, r2\n" - - "bne 1b\n" - - "3:\n" - - // Now pack the last (num_rows % 16) rows. - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" - "ands r2, r3, #15\n" - "beq 4f\n" - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n" - "vdup.8 q0, r3\n" - "vdup.8 q1, r3\n" - -// Read/accumulate/write for src_ptr0 and src_ptr1. -#define RUY_LOAD_ONE_ROW1(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \ - "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \ - - RUY_LOAD_ONE_ROW1(0, 0) - RUY_LOAD_ONE_ROW1(1, 1) - RUY_LOAD_ONE_ROW1(2, 2) - RUY_LOAD_ONE_ROW1(3, 3) - RUY_LOAD_ONE_ROW1(4, 4) - RUY_LOAD_ONE_ROW1(5, 5) - RUY_LOAD_ONE_ROW1(6, 6) - RUY_LOAD_ONE_ROW1(7, 7) -#undef RUY_LOAD_ONE_ROW1 - -#define RUY_LOAD_ONE_ROW2(I, R) \ - "cmp r2, #" #I "\n" \ - "beq 5f\n" \ - "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \ - "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \ - - RUY_LOAD_ONE_ROW2(8, 0) - RUY_LOAD_ONE_ROW2(9, 1) - RUY_LOAD_ONE_ROW2(10, 2) - RUY_LOAD_ONE_ROW2(11, 3) - RUY_LOAD_ONE_ROW2(12, 4) - RUY_LOAD_ONE_ROW2(13, 5) - RUY_LOAD_ONE_ROW2(14, 6) - RUY_LOAD_ONE_ROW2(15, 7) -#undef RUY_LOAD_ONE_ROW2 - - "5:\n" - - "veor.16 q4, q0, q11\n" - "veor.16 q5, q1, q11\n" - - "vpaddl.s8 q8, q4\n" - "vpaddl.s8 q9, q5\n" - - - // Pairwise add accumulate to 4x32b accumulators. 
- "vpadal.s16 q12, q8\n" - "vpadal.s16 q13, q9\n" - - "vst1.32 {q4}, [%[packed_ptr]]!\n" - "vst1.32 {q5}, [%[packed_ptr]]!\n" - - "4:\n" - - // Pairwise add 32-bit accumulators - "vpadd.i32 d24, d24, d25\n" - "vpadd.i32 d26, d26, d27\n" - // Final 32-bit values per row - "vpadd.i32 d25, d24, d26\n" - - "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n" - "cmp r3, #0\n" - "beq 6f\n" - "vst1.32 {d25}, [r3]!\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1) - : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1), - [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(¶ms) - : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3", - "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13"); -} - -#undef RUY_OFFSET_SRC_PTR0 -#undef RUY_OFFSET_SRC_PTR1 -#undef RUY_OFFSET_SRC_PTR2 -#undef RUY_OFFSET_SRC_PTR32 -#undef RUY_OFFSET_SUMS_PTR -#undef RUY_OFFSET_PACKED_PTR0 -#undef RUY_OFFSET_SRC_INC0 -#undef RUY_OFFSET_SRC_INC1 -#undef RUY_OFFSET_SRC_INC2 -#undef RUY_OFFSET_SRC_INC3 -#undef RUY_OFFSET_SRC_ROWS -#undef RUY_OFFSET_SRC_ZERO_POINT -#undef RUY_OFFSET_INPUT_XOR - -#endif // RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, int src_inc3, - int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor) { - profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)"); - asm volatile( - // clang-format off - "dup v26.16b, %w[input_xor]\n" - "mov w1, #0\n" - "dup v28.4s, wzr\n" - "dup v29.4s, wzr\n" - "dup v30.4s, wzr\n" - "dup v31.4s, wzr\n" - - "and w2, %w[rows], #-16\n" - "cmp w1, w2\n" - "beq 3f\n" - "ldr x10, [%[src_ptr0], #8]\n" - "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" - "ldr x11, [%[src_ptr1], #8]\n" - "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" - "ldr x12, [%[src_ptr2], #8]\n" - "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" - "ldr x13, [%[src_ptr3], #8]\n" - "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") - "add w1, w1, #16\n" - "cmp w1, w2\n" - - "beq 2f\n" - - "1:\n" - "add w1, w1, #16\n" - "ins v0.d[1], x10\n" - "ldr x10, [%[src_ptr0], #8]\n" - "ins v1.d[1], x11\n" - "ldr x11, [%[src_ptr1], #8]\n" - "ins v2.d[1], x12\n" - "ldr x12, [%[src_ptr2], #8]\n" - "ins v3.d[1], x13\n" - "ldr x13, [%[src_ptr3], #8]\n" - "eor v4.16b, v0.16b, v26.16b\n" - "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" - "eor v5.16b, v1.16b, v26.16b\n" - "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" - "eor v6.16b, v2.16b, v26.16b\n" - "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" - "eor v7.16b, v3.16b, v26.16b\n" - "ld1 {v3.8b}, [%[src_ptr3]], 
%[src_inc3]\n" - "saddlp v16.8h, v4.16b\n" - "str q4, [%[packed_ptr], #0]\n" - "saddlp v17.8h, v5.16b\n" - "str q5, [%[packed_ptr], #16]\n" - "saddlp v18.8h, v6.16b\n" - "str q6, [%[packed_ptr], #32]\n" - "saddlp v19.8h, v7.16b\n" - "str q7, [%[packed_ptr], #48]\n" - "sadalp v28.4s, v16.8h\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") - "cmp w1, w2\n" - "sadalp v29.4s, v17.8h\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") - "add %[packed_ptr], %[packed_ptr], #64\n" - "sadalp v30.4s, v18.8h\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") - "sadalp v31.4s, v19.8h\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") - - "bne 1b\n" - - "2:\n" - "ins v0.d[1], x10\n" - "ins v1.d[1], x11\n" - "ins v2.d[1], x12\n" - "ins v3.d[1], x13\n" - "eor v4.16b, v0.16b, v26.16b\n" - "eor v5.16b, v1.16b, v26.16b\n" - "eor v6.16b, v2.16b, v26.16b\n" - "eor v7.16b, v3.16b, v26.16b\n" - - "saddlp v16.8h, v4.16b\n" - "str q4, [%[packed_ptr], #0]\n" - "saddlp v17.8h, v5.16b\n" - "str q5, [%[packed_ptr], #16]\n" - "saddlp v18.8h, v6.16b\n" - "str q6, [%[packed_ptr], #32]\n" - "saddlp v19.8h, v7.16b\n" - "str q7, [%[packed_ptr], #48]\n" - "sadalp v28.4s, v16.8h\n" - "sadalp v29.4s, v17.8h\n" - "sadalp v30.4s, v18.8h\n" - "sadalp v31.4s, v19.8h\n" - - "add %[packed_ptr], %[packed_ptr], #64\n" - - "3:\n" - - "ands w2, %w[rows], #15\n" - "beq 4f\n" - "dup v0.16b, %w[src_zero_point]\n" - "dup v1.16b, %w[src_zero_point]\n" - "dup v2.16b, %w[src_zero_point]\n" - "dup v3.16b, %w[src_zero_point]\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ - "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ - "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ - "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) - RUY_LOAD_ONE_ROW(4) - RUY_LOAD_ONE_ROW(5) - RUY_LOAD_ONE_ROW(6) - RUY_LOAD_ONE_ROW(7) - RUY_LOAD_ONE_ROW(8) - RUY_LOAD_ONE_ROW(9) - RUY_LOAD_ONE_ROW(10) - RUY_LOAD_ONE_ROW(11) - RUY_LOAD_ONE_ROW(12) - RUY_LOAD_ONE_ROW(13) - RUY_LOAD_ONE_ROW(14) - RUY_LOAD_ONE_ROW(15) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "eor v4.16b, v0.16b, v26.16b\n" - "eor v5.16b, v1.16b, v26.16b\n" - "eor v6.16b, v2.16b, v26.16b\n" - "eor v7.16b, v3.16b, v26.16b\n" - - "saddlp v16.8h, v4.16b\n" - "saddlp v17.8h, v5.16b\n" - "saddlp v18.8h, v6.16b\n" - "saddlp v19.8h, v7.16b\n" - "sadalp v28.4s, v16.8h\n" - "sadalp v29.4s, v17.8h\n" - "sadalp v30.4s, v18.8h\n" - "sadalp v31.4s, v19.8h\n" - - "str q4, [%[packed_ptr], #0]\n" - "str q5, [%[packed_ptr], #16]\n" - "str q6, [%[packed_ptr], #32]\n" - "str q7, [%[packed_ptr], #48]\n" - "add %[packed_ptr], %[packed_ptr], #64\n" - - "4:\n" - - "addp v28.4s, v28.4s, v29.4s\n" - "addp v30.4s, v30.4s, v31.4s\n" - "addp v28.4s, v28.4s, v30.4s\n" - - "cmp %[sums_ptr], #0\n" - "beq 6f\n" - "st1 {v28.4s}, [%[sums_ptr]], #16\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), - [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), [ src_inc1 ] "r"(static_cast(src_inc1)), - [ src_inc2 ] "r"(static_cast(src_inc2)), [ src_inc3 ] "r"(static_cast(src_inc3)), - [ rows ] "r"(src_rows), - [ src_zero_point ] "r"(src_zero_point), - [input_xor] "r"(input_xor) - : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", - "v6", "v7", "v8", "v9", 
"v10", "v11", "v12", "v13", "v14", "v15", - "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", - "v25", "v26", "v27", "v28", "v29", "v30", "v31"); -} - -void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, - int end_col, std::int32_t* sums_ptr, - int input_xor) { - profiler::ScopeLabel label( - "Pack (kNeonDotprod, optimized for in-order cores)"); - asm volatile( - // clang-format off - "dup v26.16b, %w[input_xor]\n" - "mov w1, #1\n" - "dup v27.16b, w1\n" - "mov w1, #0\n" - "dup v28.4s, wzr\n" - "dup v29.4s, wzr\n" - "dup v30.4s, wzr\n" - "dup v31.4s, wzr\n" - - "and w2, %w[rows], #-16\n" - "cmp w1, w2\n" - "beq 3f\n" - "ldr x10, [%[src_ptr0], #8]\n" - "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" - "ldr x11, [%[src_ptr1], #8]\n" - "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" - "ldr x12, [%[src_ptr2], #8]\n" - "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" - "ldr x13, [%[src_ptr3], #8]\n" - "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") - "add w1, w1, #16\n" - "cmp w1, w2\n" - - "beq 2f\n" - - "1:\n" - "add w1, w1, #16\n" - "ins v0.d[1], x10\n" - "ldr x10, [%[src_ptr0], #8]\n" - "ins v1.d[1], x11\n" - "ldr x11, [%[src_ptr1], #8]\n" - "ins v2.d[1], x12\n" - "ldr x12, [%[src_ptr2], #8]\n" - "ins v3.d[1], x13\n" - "ldr x13, [%[src_ptr3], #8]\n" - - "eor v4.16b, v0.16b, v26.16b\n" - "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" - "eor v5.16b, v1.16b, v26.16b\n" - "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" - "eor v6.16b, v2.16b, v26.16b\n" - "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" - "eor v7.16b, v3.16b, v26.16b\n" - "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" - - "trn1 v16.4s, v4.4s, v5.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") - "trn2 v17.4s, v4.4s, v5.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") - "trn1 v18.4s, v6.4s, v7.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") - "trn2 v19.4s, v6.4s, v7.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - "cmp w1, w2\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - "str q20, [%[packed_ptr], #0]\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - "str q21, [%[packed_ptr], #32]\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - "str q22, [%[packed_ptr], #64]\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - "str q23, [%[packed_ptr], #96]\n" - - "add %[packed_ptr], %[packed_ptr], #128\n" - - "bne 1b\n" - - "2:\n" - "ins v0.d[1], x10\n" - "ins v1.d[1], x11\n" - "ins v2.d[1], 
x12\n" - "ins v3.d[1], x13\n" - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - "str q20, [%[packed_ptr], #0]\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - "str q21, [%[packed_ptr], #32]\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - "str q22, [%[packed_ptr], #64]\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "3:\n" - - "ands w2, %w[rows], #15\n" - "beq 4f\n" - "dup v0.16b, %w[src_zero_point]\n" - "dup v1.16b, %w[src_zero_point]\n" - "dup v2.16b, %w[src_zero_point]\n" - "dup v3.16b, %w[src_zero_point]\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ - "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ - "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ - "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) - RUY_LOAD_ONE_ROW(4) - RUY_LOAD_ONE_ROW(5) - RUY_LOAD_ONE_ROW(6) - RUY_LOAD_ONE_ROW(7) - RUY_LOAD_ONE_ROW(8) - RUY_LOAD_ONE_ROW(9) - RUY_LOAD_ONE_ROW(10) - RUY_LOAD_ONE_ROW(11) - RUY_LOAD_ONE_ROW(12) - RUY_LOAD_ONE_ROW(13) - RUY_LOAD_ONE_ROW(14) - RUY_LOAD_ONE_ROW(15) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - "str q20, [%[packed_ptr], #0]\n" - "cmp w2, #4\n" - "ble 4f\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - "str q21, [%[packed_ptr], #32]\n" - "cmp w2, #8\n" - "ble 4f\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - "str q22, [%[packed_ptr], #64]\n" - "cmp w2, #12\n" - "ble 4f\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "4:\n" - - "add v28.4s, v28.4s, v29.4s\n" - "add v30.4s, v30.4s, v31.4s\n" - "add v28.4s, v28.4s, v30.4s\n" - - "cmp %[sums_ptr], #0\n" - "beq 6f\n" - "st1 {v28.4s}, [%[sums_ptr]], #16\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2), - [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), [ src_inc1 ] "r"(static_cast(src_inc1)), - [ src_inc2 ] "r"(static_cast(src_inc2)), [ src_inc3 ] "r"(static_cast(src_inc3)), - [rows] "r"(src_rows), - [src_zero_point] "r"(static_cast(src_zero_point)), - [input_xor] "r"(input_xor) - : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", 
"v30", "v31"); -} - -void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, - int src_zero_point, std::int8_t* packed_ptr, - int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor) { - profiler::ScopeLabel label( - "Pack (kNeonDotprod, optimized for out-of-order cores)"); - asm volatile( - // clang-format off - "dup v26.16b, %w[input_xor]\n" - "mov w1, #1\n" - "dup v27.16b, w1\n" - "mov w1, #0\n" - "dup v28.4s, wzr\n" - "dup v29.4s, wzr\n" - "dup v30.4s, wzr\n" - "dup v31.4s, wzr\n" - -#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) - "and w2, %w[rows], #-64\n" - "cmp w1, w2\n" - "beq 9f\n" - - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n" - "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n" - "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #64\n" - "cmp w1, w2\n" - "beq 8f\n" - - "7:\n" - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v4.16b, v4.16b, v26.16b\n" - "eor v5.16b, v5.16b, v26.16b\n" - "eor v6.16b, v6.16b, v26.16b\n" - "eor v7.16b, v7.16b, v26.16b\n" - - "trn1 v16.4s, v4.4s, v5.4s\n" - "trn2 v17.4s, v4.4s, v5.4s\n" - "trn1 v18.4s, v6.4s, v7.4s\n" - "trn2 v19.4s, v6.4s, v7.4s\n" - - "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], 
#64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v8.16b, v8.16b, v26.16b\n" - "eor v9.16b, v9.16b, v26.16b\n" - "eor v10.16b, v10.16b, v26.16b\n" - "eor v11.16b, v11.16b, v26.16b\n" - - "trn1 v16.4s, v8.4s, v9.4s\n" - "trn2 v17.4s, v8.4s, v9.4s\n" - "trn1 v18.4s, v10.4s, v11.4s\n" - "trn2 v19.4s, v10.4s, v11.4s\n" - - "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v12.16b, v12.16b, v26.16b\n" - "eor v13.16b, v13.16b, v26.16b\n" - "eor v14.16b, v14.16b, v26.16b\n" - "eor v15.16b, v15.16b, v26.16b\n" - - "trn1 v16.4s, v12.4s, v13.4s\n" - "trn2 v17.4s, v12.4s, v13.4s\n" - "trn1 v18.4s, v14.4s, v15.4s\n" - "trn2 v19.4s, v14.4s, v15.4s\n" - - "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "cmp w1, w2\n" - "bne 7b\n" - - "8:\n" - - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v4.16b, v4.16b, v26.16b\n" - "eor v5.16b, v5.16b, v26.16b\n" - "eor v6.16b, v6.16b, v26.16b\n" - "eor v7.16b, v7.16b, v26.16b\n" - - "trn1 v16.4s, v4.4s, v5.4s\n" - "trn2 v17.4s, v4.4s, v5.4s\n" - "trn1 v18.4s, v6.4s, v7.4s\n" - "trn2 v19.4s, v6.4s, v7.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, 
v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v8.16b, v8.16b, v26.16b\n" - "eor v9.16b, v9.16b, v26.16b\n" - "eor v10.16b, v10.16b, v26.16b\n" - "eor v11.16b, v11.16b, v26.16b\n" - - "trn1 v16.4s, v8.4s, v9.4s\n" - "trn2 v17.4s, v8.4s, v9.4s\n" - "trn1 v18.4s, v10.4s, v11.4s\n" - "trn2 v19.4s, v10.4s, v11.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "eor v12.16b, v12.16b, v26.16b\n" - "eor v13.16b, v13.16b, v26.16b\n" - "eor v14.16b, v14.16b, v26.16b\n" - "eor v15.16b, v15.16b, v26.16b\n" - - "trn1 v16.4s, v12.4s, v13.4s\n" - "trn2 v17.4s, v12.4s, v13.4s\n" - "trn1 v18.4s, v14.4s, v15.4s\n" - "trn2 v19.4s, v14.4s, v15.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "9:\n" -#endif // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) - "and w2, %w[rows], #-16\n" - "cmp w1, w2\n" - "beq 3f\n" - - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - "cmp w1, w2\n" - "beq 2f\n" - - "1:\n" - - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #16\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "cmp w1, w2\n" - "bne 1b\n" - - "2:\n" - - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor 
v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "3:\n" - - "ands w2, %w[rows], #15\n" - "beq 4f\n" - "dup v0.16b, %w[src_zero_point]\n" - "dup v1.16b, %w[src_zero_point]\n" - "dup v2.16b, %w[src_zero_point]\n" - "dup v3.16b, %w[src_zero_point]\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ - "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ - "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ - "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) - RUY_LOAD_ONE_ROW(4) - RUY_LOAD_ONE_ROW(5) - RUY_LOAD_ONE_ROW(6) - RUY_LOAD_ONE_ROW(7) - RUY_LOAD_ONE_ROW(8) - RUY_LOAD_ONE_ROW(9) - RUY_LOAD_ONE_ROW(10) - RUY_LOAD_ONE_ROW(11) - RUY_LOAD_ONE_ROW(12) - RUY_LOAD_ONE_ROW(13) - RUY_LOAD_ONE_ROW(14) - RUY_LOAD_ONE_ROW(15) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "eor v0.16b, v0.16b, v26.16b\n" - "eor v1.16b, v1.16b, v26.16b\n" - "eor v2.16b, v2.16b, v26.16b\n" - "eor v3.16b, v3.16b, v26.16b\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" - "str q20, [%[packed_ptr], #0]\n" - "cmp w2, #4\n" - "ble 4f\n" - ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" - "str q21, [%[packed_ptr], #32]\n" - "cmp w2, #8\n" - "ble 4f\n" - ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" - "str q22, [%[packed_ptr], #64]\n" - "cmp w2, #12\n" - "ble 4f\n" - ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "4:\n" - - "add v28.4s, v28.4s, v29.4s\n" - "add v30.4s, v30.4s, v31.4s\n" - "add v28.4s, v28.4s, v30.4s\n" - - "cmp %[sums_ptr], #0\n" - "beq 6f\n" - "st1 {v28.4s}, [%[sums_ptr]], #16\n" - "6:\n" - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), - [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), - [ src_inc1 ] "r"(static_cast(src_inc1)), - [ src_inc2 ] "r"(static_cast(src_inc2)), - [ src_inc3 ] "r"(static_cast(src_inc3)), - [ rows ] "r"(src_rows), - [ src_zero_point ] "r"(static_cast(src_zero_point)), - [ input_xor ] "r"(input_xor) - : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6", - "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", - "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", - "v27", "v28", "v29", "v30", "v31"); -} - -#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_64) && 
RUY_OPT_ENABLED(RUY_OPT_ASM) -void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col) { - profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); - asm volatile( - // clang-format off - "mov w1, #0\n" - - "and w2, %w[rows], #-4\n" - "cmp w1, w2\n" - "beq 3f\n" - "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" - "add w1, w1, #4\n" - "cmp w1, w2\n" - - "beq 2f\n" - - "1:\n" - "add w1, w1, #4\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - "cmp w1, w2\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - - "add %[packed_ptr], %[packed_ptr], #128\n" - - "bne 1b\n" - - "2:\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "3:\n" - - "ands w2, %w[rows], #3\n" - "beq 4f\n" - "dup v0.16b, wzr\n" - "dup v1.16b, wzr\n" - "dup v2.16b, wzr\n" - "dup v3.16b, wzr\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \ - "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \ - "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \ - "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - "mov x1, #32\n" - -#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ - "cmp w2, #" #ROW "\n" \ - "beq 4f\n" \ - "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n" - - RUY_STORE_ONE_ROW(0, v20) - RUY_STORE_ONE_ROW(1, v21) - RUY_STORE_ONE_ROW(2, v22) - RUY_STORE_ONE_ROW(3, v23) - -#undef RUY_STORE_ONE_ROW - - "4:\n" - - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), - [ packed_ptr ] "+r"(packed_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), - [ src_inc1 ] "r"(static_cast(src_inc1)), - [ src_inc2 ] "r"(static_cast(src_inc2)), - [ src_inc3 ] "r"(static_cast(src_inc3)), - [ rows ] "r"(src_rows) - : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", - "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", - "v13", "v14", 
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", - "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); -} -#endif - -#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col, - int output_stride) { - profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); - asm volatile( - // clang-format off - "mov r1, #0\n" - "and r2, %[rows], #-4\n" - "cmp r1, r2\n" - "beq 3f\n" -#define RUY_LOAD_FOUR_BY_FOUR() \ - /* Load q0 */ \ - "vld1.32 {d0, d1}, [%[src_ptr0]]\n" \ - /* if src_inc0 != 0, add 16 to src_ptr0 */ \ - "and r3, %[src_inc], #1\n" \ - "add %[src_ptr0], %[src_ptr0], r3, lsl #4\n"\ - /* Load q1 */ \ - "vld1.32 {d2, d3}, [%[src_ptr1]]\n" \ - /* if src_inc1 != 0, add 16 to src_ptr0 */ \ - "and r3, %[src_inc], #2\n" \ - "add %[src_ptr1], %[src_ptr1], r3, lsl #3\n"\ - /* Load q2 */ \ - "vld1.32 {d4, d5}, [%[src_ptr2]]\n" \ - /* if src_inc2 != 0, add 16 to src_ptr0 */ \ - "and r3, %[src_inc], #4\n" \ - "add %[src_ptr2], %[src_ptr2], r3, lsl #2\n"\ - /* Load q3 */ \ - "vld1.32 {d6, d7}, [%[src_ptr3]]\n" \ - /* if src_inc3 != 0, add 16 to src_ptr0 */ \ - "and r3, %[src_inc], #8\n" \ - "add %[src_ptr3], %[src_ptr3], r3, lsl #1\n"\ - - RUY_LOAD_FOUR_BY_FOUR() - "add r1, r1, #4\n" - "cmp r1, r2\n" - - "beq 2f\n" - - "1:\n" - "add r1, r1, #4\n" - - // Transpose 4x4 matrix. - "vzip.32 q0, q1\n" - "vzip.32 q2, q3\n" - - "vtrn.32 q0, q2\n" - "vtrn.32 q1, q3\n" - - "vzip.32 q0, q2\n" - "vzip.32 q1, q3\n" - - "vmov q8, q0\n" - "vmov q9, q1\n" - "vmov q10, q2\n" - "vmov q11, q3\n" - - RUY_LOAD_FOUR_BY_FOUR() -#undef RUY_LOAD_FOUR_BY_FOUR - -#define RUY_STORE_FOUR_BY_FOUR() \ - /* Store q8, q10, q9, q11 */ \ - /* q8 = d16, d17 */ \ - "vst1.32 {d16, d17}, [%[packed_ptr]]\n" \ - /* q10 = d20, d21 */ \ - "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ - "vst1.32 {d20, d21}, [%[packed_ptr]]\n" \ - /* q9 = d18, d19 */ \ - "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ - "vst1.32 {d18, d19}, [%[packed_ptr]]\n" \ - /* q11 = d22, d23 */ \ - "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ - "vst1.32 {d22, d23}, [%[packed_ptr]]\n" \ - "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ - - RUY_STORE_FOUR_BY_FOUR() - "cmp r1, r2\n" - - "bne 1b\n" - - "2:\n" - - // Transpose 4x4 matrix. 
- "vzip.32 q0, q1\n" - "vzip.32 q2, q3\n" - - "vtrn.32 q0, q2\n" - "vtrn.32 q1, q3\n" - - "vzip.32 q0, q2\n" - "vzip.32 q1, q3\n" - - "vmov q8, q0\n" - "vmov q9, q1\n" - "vmov q10, q2\n" - "vmov q11, q3\n" - - RUY_STORE_FOUR_BY_FOUR() -#undef RUY_STORE_FOUR_BY_FOUR - "3:\n" - - "ands r2, %[rows], #3\n" - "beq 4f\n" - "mov r0, #0\n" - // Zero out q0 - q3 - "vdup.32 q0, r0\n" - "vdup.32 q1, r0\n" - "vdup.32 q2, r0\n" - "vdup.32 q3, r0\n" -#define RUY_LOAD_ONE_ROW_FIRST_HALF(R, I) \ - "cmp r2, #" #R "\n" \ - "beq 5f\n" \ - "vld1.32 { d0[" #I "] }, [%[src_ptr0]]!\n" \ - "vld1.32 { d2[" #I "] }, [%[src_ptr1]]!\n" \ - "vld1.32 { d4[" #I "] }, [%[src_ptr2]]!\n" \ - "vld1.32 { d6[" #I "] }, [%[src_ptr3]]!\n" - -#define RUY_LOAD_ONE_ROW_SECOND_HALF(R, I) \ - "vld1.32 { d1[" #I "] }, [%[src_ptr0]]!\n" \ - "vld1.32 { d3[" #I "] }, [%[src_ptr1]]!\n" \ - "vld1.32 { d5[" #I "] }, [%[src_ptr2]]!\n" \ - "vld1.32 { d7[" #I "] }, [%[src_ptr3]]!\n" - - RUY_LOAD_ONE_ROW_FIRST_HALF(0, 0) - RUY_LOAD_ONE_ROW_FIRST_HALF(1, 1) - RUY_LOAD_ONE_ROW_SECOND_HALF(2, 0) - RUY_LOAD_ONE_ROW_SECOND_HALF(3, 1) -#undef RUY_LOAD_ONE_ROW_SECOND_HALF -#undef RUY_LOAD_ONE_ROW_FIRST_HALF - "5:\n" - - // Transpose 4x4 matrix. - "vzip.32 q0, q1\n" - "vzip.32 q2, q3\n" - - "vtrn.32 q0, q2\n" - "vtrn.32 q1, q3\n" - - "vzip.32 q0, q2\n" - "vzip.32 q1, q3\n" - - "vmov q8, q0\n" - "vmov q9, q1\n" - "vmov q10, q2\n" - "vmov q11, q3\n" - - "mov r1, #32\n" - -#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ - "cmp r2, #" #ROW "\n" \ - "beq 4f\n" \ - "vst1.32 {" #REGISTER "}, [%[packed_ptr]]\n" \ - "add %[packed_ptr], %[packed_ptr], %[stride]\n" - - // Store q8 - RUY_STORE_ONE_ROW(0, q8) - // Store q10 - RUY_STORE_ONE_ROW(1, q10) - // Store q9 - RUY_STORE_ONE_ROW(2, q9) - // Store q11 - RUY_STORE_ONE_ROW(3, q11) - -#undef RUY_STORE_ONE_ROW - - "4:\n" - - // clang-format on - : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), - [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), - [ packed_ptr ] "+r"(packed_ptr) - : [ src_inc ] "r"(static_cast(src_inc)), - [ rows ] "r"(src_rows), [ stride ] "r"(output_stride) - : "cc", "memory", "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3", - "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"); -} - -#endif // (RUY_PLATFORM(NEON_32) - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col) { - profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)"); - - asm volatile( - // clang-format off - "mov w1, #0\n" - - "and w2, %w[rows], #-4\n" - "cmp w1, w2\n" - "beq 3f\n" - "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" - "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" - "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" - "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, 
[%[src_ptr1], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") - "add w1, w1, #4\n" - "cmp w1, w2\n" - - "beq 2f\n" - - "1:\n" - "add w1, w1, #4\n" - - "ldr x10, [%[src_ptr0], #8]\n" - "trn1 v16.4s, v0.4s, v1.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") - "ldr x11, [%[src_ptr1], #8]\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") - "ldr x12, [%[src_ptr2], #8]\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") - "ldr x13, [%[src_ptr3], #8]\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") - - "ld1 {v0.2s}, [%[src_ptr0]], %[src_inc0]\n" - "trn1 v20.2d, v16.2d, v18.2d\n" - "ld1 {v1.2s}, [%[src_ptr1]], %[src_inc1]\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "ld1 {v2.2s}, [%[src_ptr2]], %[src_inc2]\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "ld1 {v3.2s}, [%[src_ptr3]], %[src_inc3]\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - "cmp w1, w2\n" - - "ins v0.d[1], x10\n" - "str q20, [%[packed_ptr], #0]\n" - "ins v1.d[1], x11\n" - "str q21, [%[packed_ptr], #32]\n" - "ins v2.d[1], x12\n" - "str q22, [%[packed_ptr], #64]\n" - "ins v3.d[1], x13\n" - "str q23, [%[packed_ptr], #96]\n" - - "add %[packed_ptr], %[packed_ptr], #128\n" - - "bne 1b\n" - - "2:\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - "str q20, [%[packed_ptr], #0]\n" - "str q21, [%[packed_ptr], #32]\n" - "str q22, [%[packed_ptr], #64]\n" - "str q23, [%[packed_ptr], #96]\n" - "add %[packed_ptr], %[packed_ptr], #128\n" - - "3:\n" - - "ands w2, %w[rows], #3\n" - "beq 4f\n" - "dup v0.16b, wzr\n" - "dup v1.16b, wzr\n" - "dup v2.16b, wzr\n" - "dup v3.16b, wzr\n" -#define RUY_LOAD_ONE_ROW(R) \ - "cmp w2, #" #R "\n" \ - "beq 5f\n" \ - "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \ - "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \ - "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \ - "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n" - - RUY_LOAD_ONE_ROW(0) - RUY_LOAD_ONE_ROW(1) - RUY_LOAD_ONE_ROW(2) - RUY_LOAD_ONE_ROW(3) -#undef RUY_LOAD_ONE_ROW - "5:\n" - - "trn1 v16.4s, v0.4s, v1.4s\n" - "trn2 v17.4s, v0.4s, v1.4s\n" - "trn1 v18.4s, v2.4s, v3.4s\n" - "trn2 v19.4s, v2.4s, v3.4s\n" - - "trn1 v20.2d, v16.2d, v18.2d\n" - "trn2 v22.2d, v16.2d, v18.2d\n" - "trn1 v21.2d, v17.2d, v19.2d\n" - "trn2 v23.2d, v17.2d, v19.2d\n" - - "mov x1, #32\n" - -#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ - "cmp w2, #" #ROW "\n" \ - "beq 4f\n" \ - "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n" - - RUY_STORE_ONE_ROW(0, v20) - RUY_STORE_ONE_ROW(1, v21) - RUY_STORE_ONE_ROW(2, v22) - RUY_STORE_ONE_ROW(3, v23) - -#undef RUY_STORE_ONE_ROW - - "4:\n" - - // clang-format on - - : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2), - [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr) - : [ src_inc0 ] "r"(static_cast(src_inc0)), [src_inc1] "r"(static_cast(src_inc1)), [src_inc2] "r"(static_cast(src_inc2)), - [src_inc3] "r"(static_cast(src_inc3)), [rows] "r"(src_rows) - : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", - "v22", "v23", "v24", "v25", "v26", "v27", 
"v28", "v29", "v30", "v31"); -} -#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy diff --git a/pack_arm.h b/pack_arm.h deleted file mode 100644 index d93475b..0000000 --- a/pack_arm.h +++ /dev/null @@ -1,497 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// # What is "packing"? -// -// Before feeding data to the gemm kernels (the parts of Ruy that do lots -// of multiply-add operations), Ruy first performs a data transformation (which -// we call "packing") on the input matrices. This transformation has two main -// goals: -// - rearrange data into blocks that are a convenient size/layout for the gemm -// kernels to consume. This helps make the memory access pattern of the gemm -// kernel simpler and more contiguous, and puts the data in a layout most -// convenient for specific arithmetic instructions in the gemm kernel. -// - compute row/column sums needed for handling quantization with non-symmetric -// zero points. -// -// # Simplified algorithmic analysis of packing -// -// Packing is a relatively simple transformation which does a small constant -// amount of work on each element of an input matrix, and hence for an NxM -// matrix performs O(N*M) work. If N and M are of the same order, then this is -// O(N^2) work. -// -// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. -// Note that if N, K, and M are all the same order, then the number of -// multiply-accumulate operations is O(N^3). -// -// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the -// case of all dimensions being roughly the same order. -// -// # Packing cost can be significant -// -// When matrix * matrix multiplications begin to look more like matrix * vector -// multiplications, packing cost can become significant. We sometimes call these -// cases "gemv-like". -// -// Continuing the algorithmic analysis above, if we consider a case where an -// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the -// situation is different. In this case, the multiply-accumulate work is only -// quadratic, so the quadratic cost of packing can be come significant. -// -// Another way to say this is that the cost of packing an input matrix (either -// the LHS or RHS) is amortized across the non-depth dimension of the opposite -// input matrix. Thus, when the LHS has very few rows or the RHS has very few -// columns, the cost of packing the opposite input matrix can become -// significant. -// -// As a rough rule of thumb, the cost of packing starts to become significant -// when either N or M is below 32 (and other dimensions are hundreds), with very -// significant packing costs at 8 or below. This varies by data type, Path, and -// tuning, so these numbers are only rough guides. 
-// -// One practical use case that is affected by this is inference of -// fully connected neural network layers with a low batch size. The weight -// matrix (which is a constant for inference) is the one affected by significant -// packing cost. -// -// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack -// input matrices that are affected by significant packing costs. -// -// # Implementation notes -// -// Ruy's packing routines always operate on a range of columns and can be -// applied to either the LHS or RHS. This is possible because Ruy internally -// implements a TrMul, so the accumulation along depth is done along columns of -// both the LHS and RHS (whereas for a normal Mul the accumulation along depth -// for the LHS is along rows). As another example, we are always computing -// column sums for quantization (and never row sums, since the LHS is -// transposed). - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_ARM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_ARM_H_ - -#include -#include - -#include "check_macros.h" -#include "common.h" -#include "internal_matrix.h" -#include "matrix.h" -#include "opt_set.h" -#include "pack_common.h" -#include "path.h" -#include "platform.h" -#include "profiler/instrumentation.h" -#include "tune.h" - -namespace ruy { - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor); -void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, int src_inc3, - int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor); -void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, - int src_zero_point, std::int8_t* packed_ptr, - int start_col, int end_col, - std::int32_t* sums_ptr, int input_xor); -void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - std::int8_t* packed_ptr, int start_col, - int end_col, std::int32_t* sums_ptr, - int input_xor); - -#elif RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void Pack8bitNeonOutOfOrder4Cols(const PackParams8bit& params); -void Pack8bitNeonOutOfOrder2Cols(const PackParams8bit& params); -#endif // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && \ - RUY_OPT_ENABLED(RUY_OPT_ASM) - -template -struct PackImpl, Scalar, - std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - static constexpr int kInputXor = - std::is_same::value ? 
0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ(start_col % 4, 0); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[16]; - memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); - for (int block_col = start_col; block_col < end_col; block_col += 4) { - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; - const Scalar* src_ptr1 = src_ptr0 + src_stride; - const Scalar* src_ptr2 = src_ptr1 + src_stride; - const Scalar* src_ptr3 = src_ptr2 + src_stride; - int src_inc0 = 16; - int src_inc1 = 16; - int src_inc2 = 16; - int src_inc3 = 16; - if (block_col >= src_matrix.layout.cols - 3) { - if (block_col >= src_matrix.layout.cols - 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (block_col >= src_matrix.layout.cols - 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (block_col >= src_matrix.layout.cols - 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (block_col >= src_matrix.layout.cols - 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - } - std::int8_t* packed_ptr = - packed_matrix->data + packed_matrix->layout.stride * block_col; - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; -#if RUY_PLATFORM(NEON_64) - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - Pack8bitNeonInOrder( - src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, - src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, sums_ptr, kInputXor); - } else { - Pack8bitNeonOutOfOrder( - src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, - src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, sums_ptr, kInputXor); - } -#else - // We have a more limited set of general purpose registers in ARMv7, so - // we use the "params" struct technique from the kernel code to save - // registers. - PackParams8bit params; - MakePackParams8bit(src_ptr0, src_ptr1, src_ptr2, src_ptr3, sums_ptr, - packed_ptr, src_inc0, src_inc1, src_inc2, src_inc3, - src_matrix.layout.rows, src_matrix.zero_point, - kInputXor, ¶ms); - Pack8bitNeonOutOfOrder4Cols(params); -#endif // RUY_PLATFORM(NEON_64) - } - } -}; - -#endif // (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && - // RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) -// The 32-bit float kernel is 4 rows X 2 columns, so we need an additional -// partial specialization for the RHS, which has a FixedKernelLayout with 2 -// columns. -template -struct PackImpl, Scalar, - std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - static constexpr int kInputXor = - std::is_same::value ? 
0 : 0x80; - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ(start_col % 2, 0); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[16]; - memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); - for (int block_col = start_col; block_col < end_col; block_col += 2) { - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; - const Scalar* src_ptr1 = src_ptr0 + src_stride; - int src_inc0 = 16; - int src_inc1 = 16; - if (block_col >= src_matrix.layout.cols - 2) { - if (block_col >= src_matrix.layout.cols - 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (block_col >= src_matrix.layout.cols - 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - } - std::int8_t* packed_ptr = - packed_matrix->data + packed_matrix->layout.stride * block_col; - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - PackParams8bit params; - MakePackParams8bit(src_ptr0, src_ptr1, nullptr, nullptr, sums_ptr, - packed_ptr, src_inc0, src_inc1, -1, -1, - src_matrix.layout.rows, src_matrix.zero_point, - kInputXor, ¶ms); - Pack8bitNeonOutOfOrder2Cols(params); - } - } -}; -#endif // (RUY_PLATFORM(NEON_32)) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) -template -struct PackImpl, - Scalar, std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - static constexpr int kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ(start_col % 8, 0); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[16]; - memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); - for (int block_col = start_col; block_col < end_col; block_col += 4) { - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; - const Scalar* src_ptr1 = src_ptr0 + src_stride; - const Scalar* src_ptr2 = src_ptr1 + src_stride; - const Scalar* src_ptr3 = src_ptr2 + src_stride; - std::int64_t src_inc0 = 16; - std::int64_t src_inc1 = 16; - std::int64_t src_inc2 = 16; - std::int64_t src_inc3 = 16; - if (block_col >= src_matrix.layout.cols - 3) { - if (block_col >= src_matrix.layout.cols - 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (block_col >= src_matrix.layout.cols - 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (block_col >= src_matrix.layout.cols - 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (block_col >= src_matrix.layout.cols - 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - } - std::int8_t* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & ~7) + - ((block_col & 4) * 4); - std::int32_t* sums_ptr = sums ? 
sums + block_col : nullptr; - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - Pack8bitNeonDotprodInOrder( - src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, - src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, sums_ptr, kInputXor); - } else { - Pack8bitNeonDotprodOutOfOrder( - src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, - src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, sums_ptr, kInputXor); - } - } - } -}; -#endif // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col); -void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc0, int src_inc1, int src_inc2, - int src_inc3, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col); - -#elif RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) -void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, - const float* src_ptr2, const float* src_ptr3, - int src_inc, int src_rows, int src_zero_point, - float* packed_ptr, int start_col, int end_col, - int stride); -#endif // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM) - -#if (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && \ - RUY_OPT_ENABLED(RUY_OPT_ASM) - -template <> -struct PackImpl, float, - float, float> { - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ(start_col % 8, 0); - const float zerobuf[4] = {0}; - for (int block_col = start_col; block_col < end_col; block_col += 4) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - std::int64_t src_inc0 = 16; - std::int64_t src_inc1 = 16; - std::int64_t src_inc2 = 16; - std::int64_t src_inc3 = 16; - if (block_col >= src_matrix.layout.cols - 3) { - if (block_col >= src_matrix.layout.cols - 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (block_col >= src_matrix.layout.cols - 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (block_col >= src_matrix.layout.cols - 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (block_col >= src_matrix.layout.cols - 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - } - float* packed_ptr = packed_matrix->data + - packed_matrix->layout.stride * (block_col & ~7) + - ((block_col & 4)); -#if RUY_PLATFORM(NEON_64) - if (__builtin_expect(tuning == Tuning::kInOrder, true)) { - PackFloatNeonInOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, - src_inc1, src_inc2, src_inc3, - src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col); - } else { - PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, - src_inc0, src_inc1, src_inc2, src_inc3, - src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col); - } -#else - // Encode each of src_inc0, ..., src_inc3 in lowest 4 bits 
of src_inc - // to save on registers (we have fewer general purpose registers in - // 32-bit ARM than in 64-bit ARM). For the 64-bit case, we pass four - // values that are each either 16 or 0 and use them directly. For the - // 32-bit case, bits 0, 1, 2, and 3 are used to determine if we should - // use the value 16 (bit is set) or 0 (bit is not set) for the - // respective increment value. - std::int64_t src_inc = 0; - src_inc += src_inc0 == 16 ? 1 : 0; - src_inc += src_inc1 == 16 ? 2 : 0; - src_inc += src_inc2 == 16 ? 4 : 0; - src_inc += src_inc3 == 16 ? 8 : 0; - const int kOutputStride = 32; - PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc, - src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, kOutputStride); -#endif // RUY_PLATFORM(NEON_64) - } - } -}; - -#if RUY_PLATFORM(NEON_32) -// The 32-bit float kernel is 8 rows X 4 columns, so we need an additional -// specialization for a FixedKernelLayout with 4 columns. -template <> -struct PackImpl, float, - float, float> { - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ(start_col % 4, 0); - const float zerobuf[4] = {0}; - for (int block_col = start_col; block_col < end_col; block_col += 4) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - std::int64_t src_inc0 = 16; - std::int64_t src_inc1 = 16; - std::int64_t src_inc2 = 16; - std::int64_t src_inc3 = 16; - if (block_col >= src_matrix.layout.cols - 3) { - if (block_col >= src_matrix.layout.cols - 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (block_col >= src_matrix.layout.cols - 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (block_col >= src_matrix.layout.cols - 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (block_col >= src_matrix.layout.cols - 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - } - float* packed_ptr = - packed_matrix->data + packed_matrix->layout.stride * (block_col); - // Encode each of src_inc0, ..., src_inc1 in lowest 4 bits of scrc_inc - // to save registers. - std::int64_t src_inc = 0; - src_inc += src_inc0 == 16 ? 1 : 0; - src_inc += src_inc1 == 16 ? 2 : 0; - src_inc += src_inc2 == 16 ? 4 : 0; - src_inc += src_inc3 == 16 ? 8 : 0; - const int kOutputStride = 16; - PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc, - src_matrix.layout.rows, src_matrix.zero_point, - packed_ptr, start_col, end_col, kOutputStride); - } - } -}; -#endif // (RUY_PLATFORM(NEON_32)) -#endif // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \ - // RUY_OPT_ENABLED(RUY_OPT_ASM) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_ARM_H_ diff --git a/pack_avx2.cc b/pack_avx2.cc deleted file mode 100644 index 65efaab..0000000 --- a/pack_avx2.cc +++ /dev/null @@ -1,816 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "check_macros.h" -#include "matrix.h" -#include "opt_set.h" -#include "pack.h" -#include "path.h" -#include "platform.h" -#include "profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, - std::int32_t* sums_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// The first int8_t template parameter is arbitrary: this routine is common to -// all 8-bit source matrix types. -using PackImpl8bitAvx2 = - PackImpl, - std::int8_t, std::int8_t, std::int32_t>; - -using PackImplFloatAvx2 = - PackImpl, float, - float, float>; - -namespace { - -inline __m256i MaskLoadu(int available_src_rows, std::int8_t zero_point, - const std::int8_t* addr) { - RUY_DCHECK_LT(available_src_rows, 32); - __m256i padded_data; - - if (available_src_rows >= 16) { - __m128i load_hi = _mm_set1_epi8(zero_point); - __m128i load_lo = _mm_loadu_si128(reinterpret_cast(addr)); - memcpy(&load_hi, addr + 16, available_src_rows - 16); - padded_data = _mm256_set_m128i(load_hi, load_lo); - } else { - __m128i load_hi = _mm_set1_epi8(zero_point); - __m128i load_lo = load_hi; - memcpy(&load_lo, addr, available_src_rows); - padded_data = _mm256_set_m128i(load_hi, load_lo); - } - return padded_data; -} - -inline void Pack8bitAvx2Packer(const std::int8_t* src_ptr, - std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr, - std::int8_t* trailing_buf) { - using Layout = PackImpl8bitAvx2::Layout; - RUY_DCHECK_EQ(Layout::kCols, 8); - RUY_DCHECK_EQ(Layout::kRows, 4); - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. 
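  // (Annotation, not part of the original patch.) For a full 32-row block, the
  // SIMD code below is equivalent to the following scalar reference, where
  // src_col[c] stands in for src_ptr0..src_ptr7 (column-major source data),
  // row indices are relative to the start of the current 32-row block, and the
  // packed layout is chunk-major as described above:
  //
  //   for (int chunk = 0; chunk < 8; ++chunk) {
  //     for (int col = 0; col < 8; ++col) {
  //       for (int i = 0; i < 4; ++i) {
  //         std::int8_t v = src_col[col][4 * chunk + i] ^ input_xor;
  //         packed_ptr[32 * chunk + 4 * col + i] = v;
  //         if (sums_ptr) sums_ptr[col] += v;
  //       }
  //     }
  //   }
  //
  // Short trailing blocks are instead padded with zero_point, with the effect
  // on the column sums compensated through sums_adjustment further below.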
- constexpr int kNumRowChunks = 8; - constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; - - const std::int8_t* src_ptr0 = src_ptr; - const std::int8_t* src_ptr1 = src_ptr0 + src_stride; - const std::int8_t* src_ptr2 = src_ptr1 + src_stride; - const std::int8_t* src_ptr3 = src_ptr2 + src_stride; - const std::int8_t* src_ptr4 = src_ptr3 + src_stride; - const std::int8_t* src_ptr5 = src_ptr4 + src_stride; - const std::int8_t* src_ptr6 = src_ptr5 + src_stride; - const std::int8_t* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = kNumChunkedSrcRows; - std::int64_t src_inc1 = kNumChunkedSrcRows; - std::int64_t src_inc2 = kNumChunkedSrcRows; - std::int64_t src_inc3 = kNumChunkedSrcRows; - std::int64_t src_inc4 = kNumChunkedSrcRows; - std::int64_t src_inc5 = kNumChunkedSrcRows; - std::int64_t src_inc6 = kNumChunkedSrcRows; - std::int64_t src_inc7 = kNumChunkedSrcRows; - // Handle cases where source does not have Layout::kCols (8) columns. - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - const std::int8_t zero_point = zerobuf[0]; - - if (sums_ptr) { - // i: Layout::kCols. - for (int i = 0; i < 8; ++i) { - sums_ptr[i] = 0; - } - } - std::int32_t sums_adjustment = 0; - const __m256i ones_16bit = _mm256_set1_epi16(1); - __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0); - __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0); - - // The overall packing effectively pads the source rows to - // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we - // only pack for (src_rows + 31) & ~31. When there is an incomplete - // destination block, this is stored into trailing_buf instead of packed_ptr. - for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) { - // Available source rows. - // If this is less than 0 (for m=1), we skip, having filled trailing - // buffer for m=0. Also, if source rows is zero on m=1, then we filled - // exactly to the end of the column in the packed buffer. - const int available_src_rows = src_rows - k; - // Effectively, - // available rows = std::max(0, std::min(8, src_rows - k)); - // treat each case separately. 
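    // (Annotation.) Concretely: a full 32-row chunk takes the first branch
    // below (with or without the column-sums computation, depending on
    // sums_ptr); a trailing chunk of 1..31 rows takes the masked-load branch,
    // which pads with zero_point and compensates the column sums via
    // sums_adjustment; if no rows remain, neither branch runs.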
- if (available_src_rows >= kNumChunkedSrcRows) { - if (sums_ptr) { - __m256i t0, t1, t2, t3, t4, t5, t6, t7; - __m256i r0, r1, r2, r3, r4, r5, r6, r7; - const __m256i input_xor_v = _mm256_set1_epi8(input_xor); - - t0 = _mm256_loadu_si256(reinterpret_cast(src_ptr0)); - t4 = _mm256_loadu_si256(reinterpret_cast(src_ptr4)); - t1 = _mm256_loadu_si256(reinterpret_cast(src_ptr1)); - t5 = _mm256_loadu_si256(reinterpret_cast(src_ptr5)); - t2 = _mm256_loadu_si256(reinterpret_cast(src_ptr2)); - t6 = _mm256_loadu_si256(reinterpret_cast(src_ptr6)); - t3 = _mm256_loadu_si256(reinterpret_cast(src_ptr3)); - t7 = _mm256_loadu_si256(reinterpret_cast(src_ptr7)); - - r0 = _mm256_unpacklo_epi32(t0, t1); - r4 = _mm256_unpacklo_epi32(t4, t5); - r2 = _mm256_unpackhi_epi32(t0, t1); - r6 = _mm256_unpackhi_epi32(t4, t5); - r1 = _mm256_unpacklo_epi32(t2, t3); - r5 = _mm256_unpacklo_epi32(t6, t7); - r3 = _mm256_unpackhi_epi32(t2, t3); - r7 = _mm256_unpackhi_epi32(t6, t7); - - t0 = _mm256_unpacklo_epi64(r0, r1); - t4 = _mm256_unpacklo_epi64(r4, r5); - t2 = _mm256_unpackhi_epi64(r0, r1); - t6 = _mm256_unpackhi_epi64(r4, r5); - t1 = _mm256_unpacklo_epi64(r2, r3); - t5 = _mm256_unpacklo_epi64(r6, r7); - t3 = _mm256_unpackhi_epi64(r2, r3); - t7 = _mm256_unpackhi_epi64(r6, r7); - - // The preceding sets of rearrangement operations interleaved by 4 bytes - // and then by 8 bytes *within* lanes. The following set interleave by - // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, - // t4) are interleaved to create (r0, r1). This complexity follows from - // the way that AVX is centered around MM 128-bit lanes. - r0 = _mm256_permute2x128_si256(t0, t4, 0x20); - r4 = _mm256_permute2x128_si256(t1, t5, 0x20); - r1 = _mm256_permute2x128_si256(t0, t4, 0x31); - r5 = _mm256_permute2x128_si256(t1, t5, 0x31); - r2 = _mm256_permute2x128_si256(t2, t6, 0x20); - r6 = _mm256_permute2x128_si256(t3, t7, 0x20); - r3 = _mm256_permute2x128_si256(t2, t6, 0x31); - r7 = _mm256_permute2x128_si256(t3, t7, 0x31); - - r0 = _mm256_xor_si256(r0, input_xor_v); - r1 = _mm256_xor_si256(r1, input_xor_v); - r2 = _mm256_xor_si256(r2, input_xor_v); - r3 = _mm256_xor_si256(r3, input_xor_v); - r4 = _mm256_xor_si256(r4, input_xor_v); - r5 = _mm256_xor_si256(r5, input_xor_v); - r6 = _mm256_xor_si256(r6, input_xor_v); - r7 = _mm256_xor_si256(r7, input_xor_v); - - __m256i sums_4x4_16bit_lo; - sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); - sums_4x4_16bit_lo = - _mm256_add_epi16(sums_4x4_16bit_lo, - _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); - - // The sums have been performed across columns, and now we have 4x16-bit - // sums packed together. We use madd for pairwise 32-bit sums. 
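        // (Annotation.) _mm256_madd_epi16(x, ones_16bit) multiplies each
        // 16-bit lane of x by 1 and adds adjacent pairs, i.e. it folds the
        // sixteen 16-bit partial sums into eight 32-bit lanes
        // (out32[j] = x16[2*j] + x16[2*j+1]) so they can keep accumulating
        // across chunks without overflowing 16 bits.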
- const __m256i sums_4x2_32bit_lo_new = - _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); - sums_4x2_32bit_lo = - _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); - - __m256i sums_4x4_16bit_hi; - sums_4x4_16bit_hi = - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1)); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1))); - - const __m256i sums_4x2_32bit_hi_new = - _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); - sums_4x2_32bit_hi = - _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); - - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), - r0); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), - r4); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), - r1); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), - r5); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), - r2); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), - r6); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), - r3); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), - r7); - } else { - __m256i t0, t1, t2, t3, t4, t5, t6, t7; - __m256i r0, r1, r2, r3, r4, r5, r6, r7; - const __m256i input_xor_v = _mm256_set1_epi8(input_xor); - - t0 = _mm256_loadu_si256(reinterpret_cast(src_ptr0)); - t4 = _mm256_loadu_si256(reinterpret_cast(src_ptr4)); - t1 = _mm256_loadu_si256(reinterpret_cast(src_ptr1)); - t5 = _mm256_loadu_si256(reinterpret_cast(src_ptr5)); - t2 = _mm256_loadu_si256(reinterpret_cast(src_ptr2)); - t6 = _mm256_loadu_si256(reinterpret_cast(src_ptr6)); - t3 = _mm256_loadu_si256(reinterpret_cast(src_ptr3)); - t7 = _mm256_loadu_si256(reinterpret_cast(src_ptr7)); - - r0 = _mm256_unpacklo_epi32(t0, t1); - r4 = _mm256_unpacklo_epi32(t4, t5); - r2 = _mm256_unpackhi_epi32(t0, t1); - r6 = _mm256_unpackhi_epi32(t4, t5); - r1 = _mm256_unpacklo_epi32(t2, t3); - r5 = _mm256_unpacklo_epi32(t6, t7); - r3 = _mm256_unpackhi_epi32(t2, t3); - r7 = _mm256_unpackhi_epi32(t6, t7); - - t0 = _mm256_unpacklo_epi64(r0, r1); - t4 = _mm256_unpacklo_epi64(r4, r5); - t2 = _mm256_unpackhi_epi64(r0, r1); - t6 = _mm256_unpackhi_epi64(r4, r5); - t1 = _mm256_unpacklo_epi64(r2, r3); - t5 = _mm256_unpacklo_epi64(r6, r7); - t3 = _mm256_unpackhi_epi64(r2, r3); - t7 = _mm256_unpackhi_epi64(r6, r7); - - // The preceding sets of rearrangement operations interleaved by 4 bytes - // and then by 8 bytes *within* lanes. The following set interleave by - // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, - // t4) are interleaved to create (r0, r1). This complexity follows from - // the way that AVX is centered around MM 128-bit lanes. 
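        // (Annotation.) _mm256_permute2x128_si256(a, b, 0x20) concatenates the
        // low 128-bit halves of a and b, and selector 0x31 concatenates their
        // high halves; this is what recombines the lane-local results of the
        // epi32/epi64 unpacks across the two AVX lanes.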
- r0 = _mm256_permute2x128_si256(t0, t4, 0x20); - r4 = _mm256_permute2x128_si256(t1, t5, 0x20); - r1 = _mm256_permute2x128_si256(t0, t4, 0x31); - r5 = _mm256_permute2x128_si256(t1, t5, 0x31); - r2 = _mm256_permute2x128_si256(t2, t6, 0x20); - r6 = _mm256_permute2x128_si256(t3, t7, 0x20); - r3 = _mm256_permute2x128_si256(t2, t6, 0x31); - r7 = _mm256_permute2x128_si256(t3, t7, 0x31); - - r0 = _mm256_xor_si256(r0, input_xor_v); - r1 = _mm256_xor_si256(r1, input_xor_v); - r2 = _mm256_xor_si256(r2, input_xor_v); - r3 = _mm256_xor_si256(r3, input_xor_v); - r4 = _mm256_xor_si256(r4, input_xor_v); - r5 = _mm256_xor_si256(r5, input_xor_v); - r6 = _mm256_xor_si256(r6, input_xor_v); - r7 = _mm256_xor_si256(r7, input_xor_v); - - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), - r0); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), - r4); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), - r1); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), - r5); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), - r2); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), - r6); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), - r3); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), - r7); - } - } else if (available_src_rows > 0) { - RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows); - // We do not care what goes into the trailing buffer, but we want - // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. - // - // We compensate for padding-with-zero_point by initializing the - // summations with the compensating offset, effectively - // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * - // 4 * (8 - ((available_src_rows + 3) >> 2)). - // - // Note that (zero_point ^ input_xor) is performed in 8-bits and then - // cast. - sums_adjustment += - -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2)); - - __m256i t0, t1, t2, t3, t4, t5, t6, t7; - __m256i r0, r1, r2, r3, r4, r5, r6, r7; - const __m256i input_xor_v = _mm256_set1_epi8(input_xor); - - t0 = MaskLoadu(available_src_rows, zero_point, src_ptr0); - t4 = MaskLoadu(available_src_rows, zero_point, src_ptr4); - t1 = MaskLoadu(available_src_rows, zero_point, src_ptr1); - t5 = MaskLoadu(available_src_rows, zero_point, src_ptr5); - t2 = MaskLoadu(available_src_rows, zero_point, src_ptr2); - t6 = MaskLoadu(available_src_rows, zero_point, src_ptr6); - t3 = MaskLoadu(available_src_rows, zero_point, src_ptr3); - t7 = MaskLoadu(available_src_rows, zero_point, src_ptr7); - - r0 = _mm256_unpacklo_epi32(t0, t1); - r4 = _mm256_unpacklo_epi32(t4, t5); - r2 = _mm256_unpackhi_epi32(t0, t1); - r6 = _mm256_unpackhi_epi32(t4, t5); - r1 = _mm256_unpacklo_epi32(t2, t3); - r5 = _mm256_unpacklo_epi32(t6, t7); - r3 = _mm256_unpackhi_epi32(t2, t3); - r7 = _mm256_unpackhi_epi32(t6, t7); - - t0 = _mm256_unpacklo_epi64(r0, r1); - t4 = _mm256_unpacklo_epi64(r4, r5); - t2 = _mm256_unpackhi_epi64(r0, r1); - t6 = _mm256_unpackhi_epi64(r4, r5); - t1 = _mm256_unpacklo_epi64(r2, r3); - t5 = _mm256_unpacklo_epi64(r6, r7); - t3 = _mm256_unpackhi_epi64(r2, r3); - t7 = _mm256_unpackhi_epi64(r6, r7); - - // The preceding sets of rearrangement operations interleaved by 4 bytes - // and then by 8 bytes *within* lanes. The following set interleave by - // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, - // t4) are interleaved to create (r0, r1). 
This complexity follows from - // the way that AVX is centered around MM 128-bit lanes. - r0 = _mm256_permute2x128_si256(t0, t4, 0x20); - r4 = _mm256_permute2x128_si256(t1, t5, 0x20); - r1 = _mm256_permute2x128_si256(t0, t4, 0x31); - r5 = _mm256_permute2x128_si256(t1, t5, 0x31); - r2 = _mm256_permute2x128_si256(t2, t6, 0x20); - r6 = _mm256_permute2x128_si256(t3, t7, 0x20); - r3 = _mm256_permute2x128_si256(t2, t6, 0x31); - r7 = _mm256_permute2x128_si256(t3, t7, 0x31); - - r0 = _mm256_xor_si256(r0, input_xor_v); - r1 = _mm256_xor_si256(r1, input_xor_v); - r2 = _mm256_xor_si256(r2, input_xor_v); - r3 = _mm256_xor_si256(r3, input_xor_v); - r4 = _mm256_xor_si256(r4, input_xor_v); - r5 = _mm256_xor_si256(r5, input_xor_v); - r6 = _mm256_xor_si256(r6, input_xor_v); - r7 = _mm256_xor_si256(r7, input_xor_v); - - __m256i sums_4x4_16bit_lo; - sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); - sums_4x4_16bit_lo = _mm256_add_epi16( - sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); - - // The sums have been performed across columns, and now we have 4x16-bit - // sums packed together. We use madd for pairwise 32-bit sums. 
- const __m256i sums_4x2_32bit_lo_new = - _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); - sums_4x2_32bit_lo = - _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); - - __m256i sums_4x4_16bit_hi; - sums_4x4_16bit_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1)); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1))); - sums_4x4_16bit_hi = _mm256_add_epi16( - sums_4x4_16bit_hi, - _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1))); - - const __m256i sums_4x2_32bit_hi_new = - _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); - sums_4x2_32bit_hi = - _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); - - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4), - r0); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4), - r4); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4), - r1); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4), - r5); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4), - r2); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4), - r6); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4), - r3); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4), - r7); - } - - packed_ptr += 8 * kNumChunkedSrcRows; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } - - if (sums_ptr) { - const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); - - __m256i sums = - _mm256_loadu_si256(reinterpret_cast(sums_ptr)); - const __m256i idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); - - // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the - // neighbours, finshing up by adding them to the stored accumulated sums. 
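The deinterlacing mentioned above relies on _mm256_permutevar8x32_epi32 gathering even-positioned 32-bit lanes into the low half of the register and odd-positioned lanes into the high half. A minimal standalone sketch (not ruy code; assumes AVX2, example values only):

    #include <immintrin.h>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const __m256i v = _mm256_setr_epi32(10, 11, 20, 21, 30, 31, 40, 41);
      // _mm256_set_epi32 takes arguments from lane 7 down to lane 0, so in lane
      // order this index vector is {0, 2, 4, 6, 1, 3, 5, 7}.
      const __m256i idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
      const __m256i deinterlaced = _mm256_permutevar8x32_epi32(v, idx);
      std::int32_t out[8];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(out), deinterlaced);
      // Prints: 10 20 30 40 11 21 31 41 -- even lanes first, then odd lanes.
      for (int i = 0; i < 8; ++i) std::printf("%d ", out[i]);
      std::printf("\n");
      return 0;
    }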
- const __m256i sums_2x4_32bit_lo = - _mm256_permutevar8x32_epi32(sums_4x2_32bit_lo, idx); - const __m256i sums_2x4_32bit_hi = - _mm256_permutevar8x32_epi32(sums_4x2_32bit_hi, idx); - const __m256i sums_2x4_32bit_a = - _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20); - const __m256i sums_2x4_32bit_b = - _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31); - sums = _mm256_add_epi32(sums, sums_adjustment_v); - sums = _mm256_add_epi32(sums, sums_2x4_32bit_a); - sums = _mm256_add_epi32(sums, sums_2x4_32bit_b); - - _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums); - } -} - -inline __m256 Mm256UnpackloPsx2(const __m256 a, const __m256 b) { - return _mm256_castpd_ps( - _mm256_unpacklo_pd(_mm256_castps_pd(a), _mm256_castps_pd(b))); -} - -inline __m256 Mm256UnpackhiPsx2(const __m256 a, const __m256 b) { - return _mm256_castpd_ps( - _mm256_unpackhi_pd(_mm256_castps_pd(a), _mm256_castps_pd(b))); -} - -inline void PackFloatAvx2Packer(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, - int src_rows, float* packed_ptr, - float* trailing_buf) { - RUY_DCHECK_EQ(PackImplFloatAvx2::Layout::kCols, 8); - RUY_DCHECK_EQ(PackImplFloatAvx2::Layout::kRows, 1); - - // This packing amounts to transposition of 8x8 blocks. - static constexpr int kPackCols = 8; // Source cols packed together. - static constexpr int kPackRows = 8; // Short input is padded. - - const float* src_ptr0 = src_ptr; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - const float* src_ptr4 = src_ptr3 + src_stride; - const float* src_ptr5 = src_ptr4 + src_stride; - const float* src_ptr6 = src_ptr5 + src_stride; - const float* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = 8; - std::int64_t src_inc1 = 8; - std::int64_t src_inc2 = 8; - std::int64_t src_inc3 = 8; - std::int64_t src_inc4 = 8; - std::int64_t src_inc5 = 8; - std::int64_t src_inc6 = 8; - std::int64_t src_inc7 = 8; - // Handle cases where source does not have kPackDim (8) columns. - if (remaining_src_cols < kPackCols) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - for (int k = 0; k < src_rows; k += kPackRows) { - const int available_src_rows = src_rows - k; - // Effectively, - // available_src_rows = std::max(0, std::min(kPackDim, src_rows - k)); - // but treat each case separately. 
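Functionally, the two branches below produce the layout of a plain 8x8 transpose: element (row r, column c) of the source block lands at packed[8 * r + c]. A scalar reference sketch of that intent (not ruy code; PackFloat8x8Reference is an illustrative name, and the column-major source with a given stride mirrors src_ptr0..src_ptr7):

    // Scalar reference for packing one full 8x8 block of a column-major source.
    void PackFloat8x8Reference(const float* src, int src_stride, float* packed) {
      for (int r = 0; r < 8; ++r) {    // rows within the block
        for (int c = 0; c < 8; ++c) {  // source columns
          packed[8 * r + c] = src[c * src_stride + r];
        }
      }
    }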
- if (available_src_rows >= kPackRows) { - __m256 t0, t1, t2, t3, t4, t5, t6, t7; - __m256 r0, r1, r2, r3, r4, r5, r6, r7; - - t0 = _mm256_loadu_ps(src_ptr0); - t4 = _mm256_loadu_ps(src_ptr4); - t1 = _mm256_loadu_ps(src_ptr1); - t5 = _mm256_loadu_ps(src_ptr5); - t2 = _mm256_loadu_ps(src_ptr2); - t6 = _mm256_loadu_ps(src_ptr6); - t3 = _mm256_loadu_ps(src_ptr3); - t7 = _mm256_loadu_ps(src_ptr7); - - r0 = _mm256_unpacklo_ps(t0, t1); - r4 = _mm256_unpacklo_ps(t4, t5); - r2 = _mm256_unpackhi_ps(t0, t1); - r6 = _mm256_unpackhi_ps(t4, t5); - r1 = _mm256_unpacklo_ps(t2, t3); - r5 = _mm256_unpacklo_ps(t6, t7); - r3 = _mm256_unpackhi_ps(t2, t3); - r7 = _mm256_unpackhi_ps(t6, t7); - - t0 = Mm256UnpackloPsx2(r0, r1); - t4 = Mm256UnpackloPsx2(r4, r5); - t2 = Mm256UnpackhiPsx2(r0, r1); - t6 = Mm256UnpackhiPsx2(r4, r5); - t1 = Mm256UnpackloPsx2(r2, r3); - t5 = Mm256UnpackloPsx2(r6, r7); - t3 = Mm256UnpackhiPsx2(r2, r3); - t7 = Mm256UnpackhiPsx2(r6, r7); - - // The preceding sets of rearrangement operations interleaved by 4 bytes - // and then by 8 bytes *within* lanes. The following set interleave by 16 - // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4) - // are interleaved to create (r0, r1). This complexity follows from the - // way that AVX is centered around MM 128-bit lanes. - r0 = _mm256_permute2f128_ps(t0, t4, 0x20); - r4 = _mm256_permute2f128_ps(t1, t5, 0x20); - r1 = _mm256_permute2f128_ps(t0, t4, 0x31); - r5 = _mm256_permute2f128_ps(t1, t5, 0x31); - r2 = _mm256_permute2f128_ps(t2, t6, 0x20); - r6 = _mm256_permute2f128_ps(t3, t7, 0x20); - r3 = _mm256_permute2f128_ps(t2, t6, 0x31); - r7 = _mm256_permute2f128_ps(t3, t7, 0x31); - - _mm256_storeu_ps(packed_ptr + 0 * 8, r0); - _mm256_storeu_ps(packed_ptr + 2 * 8, r4); - _mm256_storeu_ps(packed_ptr + 4 * 8, r1); - _mm256_storeu_ps(packed_ptr + 6 * 8, r5); - _mm256_storeu_ps(packed_ptr + 1 * 8, r2); - _mm256_storeu_ps(packed_ptr + 3 * 8, r6); - _mm256_storeu_ps(packed_ptr + 5 * 8, r3); - _mm256_storeu_ps(packed_ptr + 7 * 8, r7); - } else if (available_src_rows > 0) { - const __m256i series = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); - const __m256i row_mask_v = - _mm256_cmpgt_epi32(_mm256_set1_epi32(available_src_rows), series); - - __m256 t0, t1, t2, t3, t4, t5, t6, t7; - __m256 r0, r1, r2, r3, r4, r5, r6, r7; - - t0 = _mm256_maskload_ps(src_ptr0, row_mask_v); - t4 = _mm256_maskload_ps(src_ptr4, row_mask_v); - t1 = _mm256_maskload_ps(src_ptr1, row_mask_v); - t5 = _mm256_maskload_ps(src_ptr5, row_mask_v); - t2 = _mm256_maskload_ps(src_ptr2, row_mask_v); - t6 = _mm256_maskload_ps(src_ptr6, row_mask_v); - t3 = _mm256_maskload_ps(src_ptr3, row_mask_v); - t7 = _mm256_maskload_ps(src_ptr7, row_mask_v); - - r0 = _mm256_unpacklo_ps(t0, t1); - r4 = _mm256_unpacklo_ps(t4, t5); - r2 = _mm256_unpackhi_ps(t0, t1); - r6 = _mm256_unpackhi_ps(t4, t5); - r1 = _mm256_unpacklo_ps(t2, t3); - r5 = _mm256_unpacklo_ps(t6, t7); - r3 = _mm256_unpackhi_ps(t2, t3); - r7 = _mm256_unpackhi_ps(t6, t7); - - t0 = Mm256UnpackloPsx2(r0, r1); - t4 = Mm256UnpackloPsx2(r4, r5); - t2 = Mm256UnpackhiPsx2(r0, r1); - t6 = Mm256UnpackhiPsx2(r4, r5); - t1 = Mm256UnpackloPsx2(r2, r3); - t5 = Mm256UnpackloPsx2(r6, r7); - t3 = Mm256UnpackhiPsx2(r2, r3); - t7 = Mm256UnpackhiPsx2(r6, r7); - - // The preceding sets of rearrangement operations interleaved by 4 bytes - // and then by 8 bytes *within* lanes. The following set interleave by 16 - // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4) - // are interleaved to create (r0, r1). 
This complexity follows from the - // way that AVX is centered around MM 128-bit lanes. - r0 = _mm256_permute2f128_ps(t0, t4, 0x20); - r4 = _mm256_permute2f128_ps(t1, t5, 0x20); - r1 = _mm256_permute2f128_ps(t0, t4, 0x31); - r5 = _mm256_permute2f128_ps(t1, t5, 0x31); - r2 = _mm256_permute2f128_ps(t2, t6, 0x20); - r6 = _mm256_permute2f128_ps(t3, t7, 0x20); - r3 = _mm256_permute2f128_ps(t2, t6, 0x31); - // r7 no longer needed. - - _mm256_storeu_ps(trailing_buf + 0 * 8, r0); - _mm256_storeu_ps(trailing_buf + 2 * 8, r4); - _mm256_storeu_ps(trailing_buf + 4 * 8, r1); - _mm256_storeu_ps(trailing_buf + 6 * 8, r5); - _mm256_storeu_ps(trailing_buf + 1 * 8, r2); - _mm256_storeu_ps(trailing_buf + 3 * 8, r6); - _mm256_storeu_ps(trailing_buf + 5 * 8, r3); - // No store to (trailing_buf + 7 * 8), space not allocated. - } - - packed_ptr += kPackRows * kPackCols; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } -} - -} // namespace. - -void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, - std::int32_t* sums_ptr) { - profiler::ScopeLabel label("Pack kAvx2 8bit"); - - using Layout = PackImpl8bitAvx2::Layout; - RUY_DCHECK_EQ(Layout::kCols, 8); - RUY_DCHECK_EQ(Layout::kRows, 4); - - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. - static constexpr int kNumRowChunks = 8; // Short input is padded. - - // Each packed block is 4*8, and there are normally 8. The trailing block is - // only slightly shorter. - constexpr int kTrailingBufSize = - kNumRowChunks * Layout::kCols * Layout::kRows; - std::int8_t trailing_buf[kTrailingBufSize]; - memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); - - Pack8bitAvx2Packer(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - - constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; - const bool trailing_data = (src_rows & kChunkedRowMask) > 0; - // If the number of source rows is not a multiple of kChunkedRowMask, there - // will be data in the trailing buffer, - if (trailing_data > 0) { - const int non_trailing_rows = src_rows & ~kChunkedRowMask; - // Destination "rows" are padded to next highest multiple of Layout::kRows. - const int dst_rows = (src_rows + 3) & ~3; - const int trailing_rows = dst_rows - non_trailing_rows; - memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, - Layout::kCols * trailing_rows * sizeof(std::int8_t)); - } -} - -void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - profiler::ScopeLabel label("Pack kAvx2 float"); - static constexpr int kPackCols = 8; // Source cols packed together. - static constexpr int kPackRows = 8; // Short input is padded. 
- float trailing_buf[(kPackRows - 1) * kPackCols]; - if (remaining_src_cols < 8) { - memset(trailing_buf, 0, sizeof(trailing_buf)); - } - PackFloatAvx2Packer(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - - const int trailing_rows = src_rows & (kPackRows - 1); - if (trailing_rows > 0) { - const int non_trailing_rows = src_rows & ~(kPackRows - 1); - memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf, - kPackCols * trailing_rows * sizeof(float)); - } -} - -#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) - -} // namespace ruy diff --git a/pack_avx512.cc b/pack_avx512.cc deleted file mode 100644 index 9d27b9a..0000000 --- a/pack_avx512.cc +++ /dev/null @@ -1,693 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "check_macros.h" -#include "matrix.h" -#include "opt_set.h" -#include "pack.h" -#include "path.h" -#include "platform.h" -#include "profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// The first int8_t template parameter is arbitrary: this routine is common to -// all 8-bit source matrix types. -using PackImpl8bitAvx512 = - PackImpl, - std::int8_t, std::int8_t, std::int32_t>; - -namespace { - -inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point, - std::int8_t* packed_ptr) { - using Layout = PackImpl8bitAvx512::Layout; - static constexpr int kHalfLayoutCols = - PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a - // block. - RUY_DCHECK_EQ(kHalfLayoutCols, 8); - RUY_DCHECK_EQ(Layout::kCols, 16); - RUY_DCHECK_EQ(Layout::kRows, 4); - - const int non_trailing_blocks = (src_rows & ~31) >> 2; - // This routine fills half blocks, and typically fills the second halves. - // Thus packed_ptr is already offset by 8 * 4. 
- for (int k = 0; k < non_trailing_blocks; ++k) { - for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) { - packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point; - } - } -} - -inline __m512i LoaduTwo(const std::int8_t* addr_lo, - const std::int8_t* addr_hi) { - __m512i lower_filled = _mm512_castsi256_si512(_mm256_loadu_epi8(addr_lo)); - return _mm512_inserti32x8(lower_filled, _mm256_loadu_epi8(addr_hi), 1); -} - -inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v, - const std::int8_t* addr_lo, - const std::int8_t* addr_hi) { - const __m512i lower_filled = _mm512_castsi256_si512( - _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_lo)); - return _mm512_inserti32x8( - lower_filled, _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_hi), - 1); -} - -inline void HalfPack8bitAvx512(const std::int8_t* src_ptr, - std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr, - std::int8_t* trailing_buf) { - using Layout = PackImpl8bitAvx512::Layout; - RUY_DCHECK_EQ(Layout::kCols, 16); - RUY_DCHECK_EQ(Layout::kRows, 4); - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. - constexpr int kNumRowChunks = 8; - constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; - - const std::int8_t* src_ptr0 = src_ptr; - const std::int8_t* src_ptr1 = src_ptr0 + src_stride; - const std::int8_t* src_ptr2 = src_ptr1 + src_stride; - const std::int8_t* src_ptr3 = src_ptr2 + src_stride; - const std::int8_t* src_ptr4 = src_ptr3 + src_stride; - const std::int8_t* src_ptr5 = src_ptr4 + src_stride; - const std::int8_t* src_ptr6 = src_ptr5 + src_stride; - const std::int8_t* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = kNumChunkedSrcRows; - std::int64_t src_inc1 = kNumChunkedSrcRows; - std::int64_t src_inc2 = kNumChunkedSrcRows; - std::int64_t src_inc3 = kNumChunkedSrcRows; - std::int64_t src_inc4 = kNumChunkedSrcRows; - std::int64_t src_inc5 = kNumChunkedSrcRows; - std::int64_t src_inc6 = kNumChunkedSrcRows; - std::int64_t src_inc7 = kNumChunkedSrcRows; - // Handle cases where source does not have kHalfLayoutCols (8) columns. - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - const std::int8_t zero_point = zerobuf[0]; - - if (sums_ptr) { - // i: kHalfLayoutCols. - for (int i = 0; i < 8; ++i) { - sums_ptr[i] = 0; - } - } - std::int32_t sums_adjustment = 0; - const __m512i ones_16bit = _mm512_set1_epi16(1); - __m512i sums_8x2_32bit = _mm512_set1_epi32(0); - - // The overall packing effectively pads the source rows to - // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we - // only pack for (src_rows + 31) & ~31. When there is an incomplete - // destination block, this is stored into trailing_buf instead of packed_ptr. 
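The two roundings mentioned in the comment above are the usual power-of-two round-up bit tricks; a minimal standalone sketch with an arbitrary example value (not ruy code):

    #include <cstdio>

    int main() {
      const int src_rows = 70;                         // arbitrary example value
      const int padded_to_64 = (src_rows + 63) & ~63;  // rounds up to 128
      const int padded_to_32 = (src_rows + 31) & ~31;  // rounds up to 96
      std::printf("%d %d\n", padded_to_64, padded_to_32);
      return 0;
    }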
- for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) { - // m: {0, 1} for 2 chunks of rows. - for (int m = 0; m < 2; ++m) { - // Available source rows. - // If this is less than 0 (for m=1), we skip, having filled trailing - // buffer for m=0. Also, if source rows is zero on m=1, then we filled - // exactly to the end of the column in the packed buffer. - const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows; - // Effectively, - // available rows = std::max(0, std::min(8, src_rows - k - 8 * 4 * m)); - // treat each case separately. - if (available_src_rows >= kNumChunkedSrcRows) { - // i: chunks, s: Layout::Rows. - if (sums_ptr) { - __m512i t0, t1, t2, t3; - __m512i r0, r1, r2, r3; - const __m512i input_xor_v = _mm512_set1_epi8(input_xor); - - t0 = LoaduTwo(src_ptr0, src_ptr4); - t1 = LoaduTwo(src_ptr1, src_ptr5); - t2 = LoaduTwo(src_ptr2, src_ptr6); - t3 = LoaduTwo(src_ptr3, src_ptr7); - - r0 = _mm512_unpacklo_epi32(t0, t1); - r2 = _mm512_unpackhi_epi32(t0, t1); - r1 = _mm512_unpacklo_epi32(t2, t3); - r3 = _mm512_unpackhi_epi32(t2, t3); - - t0 = _mm512_unpacklo_epi64(r0, r1); - t2 = _mm512_unpackhi_epi64(r0, r1); - t1 = _mm512_unpacklo_epi64(r2, r3); - t3 = _mm512_unpackhi_epi64(r2, r3); - - r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); - r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); - r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); - r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); - - r0 = _mm512_xor_si512(r0, input_xor_v); - r1 = _mm512_xor_si512(r1, input_xor_v); - r2 = _mm512_xor_si512(r2, input_xor_v); - r3 = _mm512_xor_si512(r3, input_xor_v); - - const __m256i r0_0 = _mm512_castsi512_si256(r0); - const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); - const __m256i r1_0 = _mm512_castsi512_si256(r1); - const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); - const __m256i r2_0 = _mm512_castsi512_si256(r2); - const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); - const __m256i r3_0 = _mm512_castsi512_si256(r3); - const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); - - __m512i sums_8x4_16bit; - sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1)); - // The sums have been performed across columns, and now we have - // 4x16-bit sums packed together. We use madd for pairwise 32-bit - // sums. 
- const __m512i sums_8x2_32bit_new = - _mm512_madd_epi16(sums_8x4_16bit, ones_16bit); - sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new); - - _mm256_storeu_epi8(packed_ptr + 0 * 16 * 4, r0_0); - _mm256_storeu_epi8(packed_ptr + 2 * 16 * 4, r0_1); - _mm256_storeu_epi8(packed_ptr + 4 * 16 * 4, r1_0); - _mm256_storeu_epi8(packed_ptr + 6 * 16 * 4, r1_1); - _mm256_storeu_epi8(packed_ptr + 1 * 16 * 4, r2_0); - _mm256_storeu_epi8(packed_ptr + 3 * 16 * 4, r2_1); - _mm256_storeu_epi8(packed_ptr + 5 * 16 * 4, r3_0); - _mm256_storeu_epi8(packed_ptr + 7 * 16 * 4, r3_1); - } else { - __m512i t0, t1, t2, t3; - __m512i r0, r1, r2, r3; - const __m512i input_xor_v = _mm512_set1_epi8(input_xor); - - t0 = LoaduTwo(src_ptr0, src_ptr4); - t1 = LoaduTwo(src_ptr1, src_ptr5); - t2 = LoaduTwo(src_ptr2, src_ptr6); - t3 = LoaduTwo(src_ptr3, src_ptr7); - - r0 = _mm512_unpacklo_epi32(t0, t1); - r2 = _mm512_unpackhi_epi32(t0, t1); - r1 = _mm512_unpacklo_epi32(t2, t3); - r3 = _mm512_unpackhi_epi32(t2, t3); - - t0 = _mm512_unpacklo_epi64(r0, r1); - t2 = _mm512_unpackhi_epi64(r0, r1); - t1 = _mm512_unpacklo_epi64(r2, r3); - t3 = _mm512_unpackhi_epi64(r2, r3); - - r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); - r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); - r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); - r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); - - r0 = _mm512_xor_si512(r0, input_xor_v); - r1 = _mm512_xor_si512(r1, input_xor_v); - r2 = _mm512_xor_si512(r2, input_xor_v); - r3 = _mm512_xor_si512(r3, input_xor_v); - - const __m256i r0_0 = _mm512_castsi512_si256(r0); - const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); - const __m256i r1_0 = _mm512_castsi512_si256(r1); - const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); - const __m256i r2_0 = _mm512_castsi512_si256(r2); - const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); - const __m256i r3_0 = _mm512_castsi512_si256(r3); - const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); - _mm256_storeu_epi8(packed_ptr + 0 * 16 * 4, r0_0); - _mm256_storeu_epi8(packed_ptr + 2 * 16 * 4, r0_1); - _mm256_storeu_epi8(packed_ptr + 4 * 16 * 4, r1_0); - _mm256_storeu_epi8(packed_ptr + 6 * 16 * 4, r1_1); - _mm256_storeu_epi8(packed_ptr + 1 * 16 * 4, r2_0); - _mm256_storeu_epi8(packed_ptr + 3 * 16 * 4, r2_1); - _mm256_storeu_epi8(packed_ptr + 5 * 16 * 4, r3_0); - _mm256_storeu_epi8(packed_ptr + 7 * 16 * 4, r3_1); - } - } else if (available_src_rows > 0) { - RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows); - const __mmask32 row_mask = - (static_cast(1) << available_src_rows) - 1; - - // We do not care what goes into the trailing buffer, but we want - // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. - // - // We compensate for padding-with-zero_point by initializing the - // summations with the compensating offset, effectively - // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * - // 4 * (8 - ((available_src_rows + 3) >> 2)). - // - // Note that (zero_point ^ input_xor) is performed in 8-bits and then - // cast. 
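The compensation described above can be checked with scalar arithmetic. A minimal standalone sketch (not ruy code; zero_point, input_xor and available_src_rows are arbitrary example values):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Padding bytes are filled with zero_point and then XORed with input_xor,
      // so each one contributes (zero_point ^ input_xor) to a column sum. The
      // adjustment removes the contribution of the 4 * (8 - ceil(rows / 4))
      // padding bytes that lie beyond the padded destination rows.
      const std::int8_t zero_point = 5, input_xor = -128;  // example values
      const int available_src_rows = 10;                   // example value
      const int chunks_in_use = (available_src_rows + 3) >> 2;      // ceil(10 / 4) = 3
      const int padding_bytes = 4 * (8 - chunks_in_use);            // 20
      const std::int32_t per_byte =
          static_cast<std::int8_t>(zero_point ^ input_xor);  // 8-bit XOR, then widened
      const std::int32_t sums_adjustment = -per_byte * padding_bytes;
      std::printf("adjustment per column = %d\n", sums_adjustment);
      return 0;
    }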
- sums_adjustment += -(zero_point ^ input_xor) * 4 * - (8 - ((available_src_rows + 3) >> 2)); - - __m512i t0, t1, t2, t3; - __m512i r0, r1, r2, r3; - const __m512i input_xor_v = _mm512_set1_epi8(input_xor); - const __m256i zero_point_v = _mm256_set1_epi8(zero_point); - - t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4); - t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5); - t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6); - t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7); - - r0 = _mm512_unpacklo_epi32(t0, t1); - r2 = _mm512_unpackhi_epi32(t0, t1); - r1 = _mm512_unpacklo_epi32(t2, t3); - r3 = _mm512_unpackhi_epi32(t2, t3); - - t0 = _mm512_unpacklo_epi64(r0, r1); - t2 = _mm512_unpackhi_epi64(r0, r1); - t1 = _mm512_unpacklo_epi64(r2, r3); - t3 = _mm512_unpackhi_epi64(r2, r3); - - r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); - r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); - r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); - r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); - - r0 = _mm512_xor_si512(r0, input_xor_v); - r1 = _mm512_xor_si512(r1, input_xor_v); - r2 = _mm512_xor_si512(r2, input_xor_v); - r3 = _mm512_xor_si512(r3, input_xor_v); - - const __m256i r0_0 = _mm512_castsi512_si256(r0); - const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); - const __m256i r1_0 = _mm512_castsi512_si256(r1); - const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); - const __m256i r2_0 = _mm512_castsi512_si256(r2); - const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); - const __m256i r3_0 = _mm512_castsi512_si256(r3); - const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); - - __m512i sums_8x4_16bit; - sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0)); - sums_8x4_16bit = - _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1)); - // The sums have been performed across columns, and now we have - // 4x16-bit sums packed together. We use madd for pairwise 32-bit - // sums. 
- const __m512i sums_8x2_32bit_new = - _mm512_madd_epi16(sums_8x4_16bit, ones_16bit); - sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new); - - _mm256_storeu_epi8(trailing_buf + 0 * 16 * 4, r0_0); - _mm256_storeu_epi8(trailing_buf + 2 * 16 * 4, r0_1); - _mm256_storeu_epi8(trailing_buf + 4 * 16 * 4, r1_0); - _mm256_storeu_epi8(trailing_buf + 6 * 16 * 4, r1_1); - _mm256_storeu_epi8(trailing_buf + 1 * 16 * 4, r2_0); - _mm256_storeu_epi8(trailing_buf + 3 * 16 * 4, r2_1); - _mm256_storeu_epi8(trailing_buf + 5 * 16 * 4, r3_0); - _mm256_storeu_epi8(trailing_buf + 7 * 16 * 4, r3_1); - } - - packed_ptr += 16 * kNumChunkedSrcRows; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } - } - - if (sums_ptr) { - const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); - - __m256i sums = _mm256_loadu_epi32(sums_ptr); - const __m512i idx = - _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0); - - // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the - // neighbours, finshing up by adding them to the stored accumulated sums. - const __m512i sums_2x8_32bit = - _mm512_permutexvar_epi32(idx, sums_8x2_32bit); - sums = _mm256_add_epi32(sums, sums_adjustment_v); - sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit)); - sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1)); - - _mm256_storeu_epi32(sums_ptr, sums); - } -} - -inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) { - const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo)); - return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1); -} - -inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo, - const float* addr_hi) { - const __m512 lower_filled = - _mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo)); - return _mm512_insertf32x8(lower_filled, - _mm256_maskz_loadu_ps(row_mask, addr_hi), 1); -} - -inline __m512 Mm512UnpackloPsx2(const __m512 a, const __m512 b) { - return _mm512_castpd_ps( - _mm512_unpacklo_pd(_mm512_castps_pd(a), _mm512_castps_pd(b))); -} - -inline __m512 Mm512UnpackhiPsx2(const __m512 a, const __m512 b) { - return _mm512_castpd_ps( - _mm512_unpackhi_pd(_mm512_castps_pd(a), _mm512_castps_pd(b))); -} - -inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, - int src_rows, float* packed_ptr, - float* trailing_buf) { - const float* src_ptr0 = src_ptr; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - const float* src_ptr4 = src_ptr3 + src_stride; - const float* src_ptr5 = src_ptr4 + src_stride; - const float* src_ptr6 = src_ptr5 + src_stride; - const float* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = 8; - std::int64_t src_inc1 = 8; - std::int64_t src_inc2 = 8; - std::int64_t src_inc3 = 8; - std::int64_t src_inc4 = 8; - std::int64_t src_inc5 = 8; - std::int64_t src_inc6 = 8; - std::int64_t src_inc7 = 8; - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if 
(remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - for (int k = 0; k < src_rows; k += 16) { - for (int m = 0; m < 2; ++m) { - const int available_src_rows = src_rows - k - 8 * m; - // Effectively, - // available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); - // but treat each case separately. - if (available_src_rows > 7) { - __m512 t0, t1, t2, t3; - __m512 r0, r1, r2, r3; - - t0 = LoaduTwo(src_ptr0, src_ptr4); - t1 = LoaduTwo(src_ptr1, src_ptr5); - t2 = LoaduTwo(src_ptr2, src_ptr6); - t3 = LoaduTwo(src_ptr3, src_ptr7); - - r0 = _mm512_unpacklo_ps(t0, t1); - r2 = _mm512_unpackhi_ps(t0, t1); - r1 = _mm512_unpacklo_ps(t2, t3); - r3 = _mm512_unpackhi_ps(t2, t3); - - t0 = Mm512UnpackloPsx2(r0, r1); - t2 = Mm512UnpackhiPsx2(r0, r1); - t1 = Mm512UnpackloPsx2(r2, r3); - t3 = Mm512UnpackhiPsx2(r2, r3); - - r0 = _mm512_shuffle_f32x4(t0, t1, 0x88); - r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd); - r2 = _mm512_shuffle_f32x4(t2, t3, 0x88); - r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd); - - _mm256_storeu_ps(packed_ptr + 0 * 16, _mm512_castps512_ps256(r0)); - _mm256_storeu_ps(packed_ptr + 2 * 16, _mm512_extractf32x8_ps(r0, 1)); - _mm256_storeu_ps(packed_ptr + 4 * 16, _mm512_castps512_ps256(r1)); - _mm256_storeu_ps(packed_ptr + 6 * 16, _mm512_extractf32x8_ps(r1, 1)); - _mm256_storeu_ps(packed_ptr + 1 * 16, _mm512_castps512_ps256(r2)); - _mm256_storeu_ps(packed_ptr + 3 * 16, _mm512_extractf32x8_ps(r2, 1)); - _mm256_storeu_ps(packed_ptr + 5 * 16, _mm512_castps512_ps256(r3)); - _mm256_storeu_ps(packed_ptr + 7 * 16, _mm512_extractf32x8_ps(r3, 1)); - } else if (available_src_rows > 0) { - const __mmask8 row_mask = - (static_cast(1) << available_src_rows) - 1; - - __m512 t0, t1, t2, t3; - __m512 r0, r1, r2, r3; - - t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4); - t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5); - t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6); - t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7); - - r0 = _mm512_unpacklo_ps(t0, t1); - r2 = _mm512_unpackhi_ps(t0, t1); - r1 = _mm512_unpacklo_ps(t2, t3); - r3 = _mm512_unpackhi_ps(t2, t3); - - t0 = Mm512UnpackloPsx2(r0, r1); - t2 = Mm512UnpackhiPsx2(r0, r1); - t1 = Mm512UnpackloPsx2(r2, r3); - t3 = Mm512UnpackhiPsx2(r2, r3); - - r0 = _mm512_shuffle_f32x4(t0, t1, 0x88); - r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd); - r2 = _mm512_shuffle_f32x4(t2, t3, 0x88); - r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd); - - _mm256_storeu_ps(trailing_buf + 0 * 16, _mm512_castps512_ps256(r0)); - _mm256_storeu_ps(trailing_buf + 2 * 16, _mm512_extractf32x8_ps(r0, 1)); - _mm256_storeu_ps(trailing_buf + 4 * 16, _mm512_castps512_ps256(r1)); - _mm256_storeu_ps(trailing_buf + 6 * 16, _mm512_extractf32x8_ps(r1, 1)); - _mm256_storeu_ps(trailing_buf + 1 * 16, _mm512_castps512_ps256(r2)); - _mm256_storeu_ps(trailing_buf + 3 * 16, _mm512_extractf32x8_ps(r2, 1)); - _mm256_storeu_ps(trailing_buf + 5 * 16, _mm512_castps512_ps256(r3)); - // Do not store _mm512_extractf32x8_ps(r3, 1). 
- } - - packed_ptr += 16 * 8; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } - } -} - -inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) { - const int non_trailing_rows = src_rows & ~7; - for (int k = 0; k < non_trailing_rows; ++k) { - for (int j = 0; j < 8; ++j) { - packed_ptr[j] = 0.0f; - } - packed_ptr += 16; - } -} - -} // namespace. - -void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - profiler::ScopeLabel label("Pack kAvx512 8bit"); - - using Layout = PackImpl8bitAvx512::Layout; - constexpr int kHalfBlockOffset = 32; - RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols); - static constexpr int kHalfLayoutCols = - PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a - // block. - RUY_DCHECK_EQ(kHalfLayoutCols, 8); - RUY_DCHECK_EQ(Layout::kCols, 16); - RUY_DCHECK_EQ(Layout::kRows, 4); - - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. - constexpr int kNumRowChunks = 8; - - // Each packed block is 4*16, and there are normally 8. The trailing block is - // only slightly shorter. - constexpr int kTrailingBufSize = - kNumRowChunks * Layout::kCols * Layout::kRows; - std::int8_t trailing_buf[kTrailingBufSize]; - memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); - - std::int32_t* second_sums_ptr = - sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr; - if (remaining_src_cols > kHalfLayoutCols) { - HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor, - zerobuf, src_stride, - remaining_src_cols - kHalfLayoutCols, src_rows, - packed_ptr + kHalfBlockOffset, second_sums_ptr, - trailing_buf + kHalfBlockOffset); - } else { - HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor, - packed_ptr + kHalfBlockOffset); - // The kernel may not need the second half-blocks sums to be set. - if (second_sums_ptr) { - for (int i = 0; i < kHalfLayoutCols; ++i) { - second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3); - } - } - } - constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; - const bool trailing_data = (src_rows & kChunkedRowMask) > 0; - // If the number of source rows is not a multiple of kChunkedRowMask, there - // will be data in the trailing buffer, - if (trailing_data > 0) { - const int non_trailing_rows = src_rows & ~kChunkedRowMask; - // Destination "rows" are padded to next highest multiple of Layout::kRows. 
- const int dst_rows = (src_rows + 3) & ~3; - const int trailing_rows = dst_rows - non_trailing_rows; - memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, - Layout::kCols * trailing_rows * sizeof(std::int8_t)); - } -} - -void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - profiler::ScopeLabel label("Pack kAvx512 float"); - float trailing_buf[7 * 16]; - if (remaining_src_cols > 8) { - HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride, - remaining_src_cols - 8, src_rows, packed_ptr + 8, - trailing_buf + 8); - } else { - memset(trailing_buf, 0, sizeof(trailing_buf)); - HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - ZeroHalfFloatAvx512(src_rows, packed_ptr + 8); - } - const int trailing_rows = src_rows & 7; - if (trailing_rows > 0) { - const int non_trailing_rows = src_rows & ~7; - memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, - 16 * trailing_rows * sizeof(float)); - } -} - -#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) - -} // namespace ruy diff --git a/pack_avxvnni.cc b/pack_avxvnni.cc deleted file mode 100644 index 6b08415..0000000 --- a/pack_avxvnni.cc +++ /dev/null @@ -1,478 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "check_macros.h" -#include "matrix.h" -#include "opt_set.h" -#include "pack.h" -#include "path.h" -#include "platform.h" -#include "profiler/instrumentation.h" - -#if RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, int src_rows, - float* packed_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// The first int8_t template parameter is arbitrary: this routine is common to -// all 8-bit source matrix types. 
-using PackImpl8bitAvxVnni = - PackImpl, - std::int8_t, std::int8_t, std::int32_t>; - -namespace { - -inline void ZeroHalf8bitAvxVnni(int src_rows, std::int8_t packed_zero_point, - std::int8_t* packed_ptr) { - const int non_trailing_blocks = (src_rows & ~31) >> 2; - // This routine fills half blocks, and typically fills the second halves. Thus - // packed_ptr is already offset by 8*4. - for (int k = 0; k < non_trailing_blocks; ++k) { - for (int j = 0; j < (8 * 4); ++j) { - packed_ptr[16 * 4 * k + j] = packed_zero_point; - } - } -} - -inline void HalfPack8bitAvxVnni(const std::int8_t* src_ptr, - std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr, - std::int8_t* trailing_buf) { - std::int8_t in_data[8][8][4]; - - const std::int8_t* src_ptr0 = src_ptr; - const std::int8_t* src_ptr1 = src_ptr0 + src_stride; - const std::int8_t* src_ptr2 = src_ptr1 + src_stride; - const std::int8_t* src_ptr3 = src_ptr2 + src_stride; - const std::int8_t* src_ptr4 = src_ptr3 + src_stride; - const std::int8_t* src_ptr5 = src_ptr4 + src_stride; - const std::int8_t* src_ptr6 = src_ptr5 + src_stride; - const std::int8_t* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = 8 * 4; - std::int64_t src_inc1 = 8 * 4; - std::int64_t src_inc2 = 8 * 4; - std::int64_t src_inc3 = 8 * 4; - std::int64_t src_inc4 = 8 * 4; - std::int64_t src_inc5 = 8 * 4; - std::int64_t src_inc6 = 8 * 4; - std::int64_t src_inc7 = 8 * 4; - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - const std::int8_t zero_point = zerobuf[0]; - - if (sums_ptr) { - for (int i = 0; i < 8; ++i) { - sums_ptr[i] = 0; - } - } - - // The overall packing effectively pads the source rows to - // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we - // only pack for (src_rows + 31) & ~31. When there is an incomplete - // destination block, this is stored into trailing_buf instead of packed_ptr. - for (int k = 0; k < src_rows; k += 16 * 4) { - for (int m = 0; m < 2; ++m) { - // Available source rows. - // If this is less than 0 (for m=1), we skip, having filled trailing - // buffer for m=0. Also, if source rows is zero on m=1, then we filled - // exactly to the end of the column in the packed buffer. - const int packed_rows = src_rows - k - 8 * m * 4; - // Effectively, - // packed_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); - // but treat each case separately. 
- if (packed_rows >= (8 * 4)) { - for (int i = 0; i < 8; ++i) { - for (int s = 0; s < 4; ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - } - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - packed_ptr[(16 * i + j) * 4 + s] = - static_cast(in_data[j][i][s] ^ input_xor); - } - if (sums_ptr) { - for (int s = 0; s < 4; ++s) { - sums_ptr[j] += in_data[j][i][s] ^ input_xor; - } - } - } - } - } else if (packed_rows > 0) { - RUY_DCHECK_LT(packed_rows >> 2, 8); - int i = 0; - for (; i < (packed_rows >> 2); ++i) { - for (int s = 0; s < 4; ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - } - if (i < ((packed_rows + 3) >> 2)) { - int s = 0; - for (; s < (packed_rows & 3); ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - RUY_DCHECK_LE(s, 4); - for (; s < 4; ++s) { - for (int j = 0; j < 8; ++j) { - in_data[j][i][s] = zero_point; - } - } - ++i; - } - // We do not care what goes into the trailing buffer, but we want - // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. - // - // It might prove better in optimized code to pad uniformly with - // zero_point, and compensate by initializing the summations with the - // compensating offset, effectively - // ((input_xor - zero_point) ^ input_xor) * - // 4 * (8 - ((packed_rows + 3) >> 2)). - for (; i < 8; ++i) { - for (int s = 0; s < 4; ++s) { - for (int j = 0; j < 8; ++j) { - in_data[j][i][s] = input_xor; - } - } - } - // We loop through [0, 8) rather than [0, (packed_rows + 3) >> 2), since - // that emulates what we might do in fully-optimized code. 
- if (sums_ptr) { - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - trailing_buf[(16 * i + j) * 4 + s] = - static_cast(in_data[j][i][s] ^ input_xor); - sums_ptr[j] += in_data[j][i][s] ^ input_xor; - } - } - } - } else { - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - trailing_buf[(16 * i + j) * 4 + s] = - static_cast(in_data[j][i][s] ^ input_xor); - } - } - } - } - } - - packed_ptr += 16 * 8 * 4; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } - } -} - -inline void HalfPackFloatAvxVnni(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, - int src_rows, float* packed_ptr, - float* trailing_buf) { - float in_data[8][8]; - - const float* src_ptr0 = src_ptr; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - const float* src_ptr4 = src_ptr3 + src_stride; - const float* src_ptr5 = src_ptr4 + src_stride; - const float* src_ptr6 = src_ptr5 + src_stride; - const float* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = 8; - std::int64_t src_inc1 = 8; - std::int64_t src_inc2 = 8; - std::int64_t src_inc3 = 8; - std::int64_t src_inc4 = 8; - std::int64_t src_inc5 = 8; - std::int64_t src_inc6 = 8; - std::int64_t src_inc7 = 8; - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - for (int k = 0; k < src_rows; k += 16) { - for (int m = 0; m < 2; ++m) { - const int packed_rows = src_rows - k - 8 * m; - // Effectively, - // packed_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); - // but treat each case separately. - if (packed_rows > 7) { - for (int i = 0; i < 8; ++i) { - in_data[0][i] = src_ptr0[i]; - in_data[1][i] = src_ptr1[i]; - in_data[2][i] = src_ptr2[i]; - in_data[3][i] = src_ptr3[i]; - in_data[4][i] = src_ptr4[i]; - in_data[5][i] = src_ptr5[i]; - in_data[6][i] = src_ptr6[i]; - in_data[7][i] = src_ptr7[i]; - } - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - packed_ptr[16 * i + j] = in_data[j][i]; - } - } - } else if (packed_rows > 0) { - for (int i = 0; i < packed_rows; ++i) { - in_data[0][i] = src_ptr0[i]; - in_data[1][i] = src_ptr1[i]; - in_data[2][i] = src_ptr2[i]; - in_data[3][i] = src_ptr3[i]; - in_data[4][i] = src_ptr4[i]; - in_data[5][i] = src_ptr5[i]; - in_data[6][i] = src_ptr6[i]; - in_data[7][i] = src_ptr7[i]; - } - for (int i = packed_rows; i < 8; ++i) { - in_data[0][i] = 0.0f; - in_data[1][i] = 0.0f; - in_data[2][i] = 0.0f; - in_data[3][i] = 0.0f; - in_data[4][i] = 0.0f; - in_data[5][i] = 0.0f; - in_data[6][i] = 0.0f; - in_data[7][i] = 0.0f; - } - // We loop through [0, 7) rather than [0, packed_rows), since that - // emulates what we might do in fully-optimized code. 
- for (int i = 0; i < 7; ++i) { - for (int j = 0; j < 8; ++j) { - trailing_buf[16 * i + j] = in_data[j][i]; - } - } - } - - packed_ptr += 16 * 8; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } - } -} - -inline void ZeroHalfFloatAvxVnni(int src_rows, float* packed_ptr) { - const int non_trailing_rows = src_rows & ~7; - for (int k = 0; k < non_trailing_rows; ++k) { - for (int j = 0; j < 8; ++j) { - packed_ptr[j] = 0.0f; - } - packed_ptr += 16; - } -} - -} // namespace. - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - profiler::ScopeLabel label("Pack kAvxVnni 8bit (UNFINISHED)"); - - // Each packed block is 4*16, and there are normally 8. The trailing block is - // only slightly shorter. - std::int8_t trailing_buf[8 * 16 * 4]; - memset(trailing_buf, 0, 8 * 16 * 4 * sizeof(std::int8_t)); - - std::int32_t* second_sums_ptr = sums_ptr ? sums_ptr + 8 : nullptr; - if (remaining_src_cols > 8) { - HalfPack8bitAvxVnni(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - HalfPack8bitAvxVnni(src_ptr + src_stride * 8, input_xor, zerobuf, - src_stride, remaining_src_cols - 8, src_rows, - packed_ptr + 8 * 4, second_sums_ptr, - trailing_buf + 8 * 4); - } else { - HalfPack8bitAvxVnni(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - ZeroHalf8bitAvxVnni(src_rows, zerobuf[0] ^ input_xor, packed_ptr + 8 * 4); - // The kernel may not need the second half-blocks sums to be set. - if (second_sums_ptr) { - for (int i = 0; i < 8; ++i) { - second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3); - } - } - } - const bool trailing_data = (src_rows & 31) > 0; - // If the number of source rows is not a multiple of 32, there will be data in - // the trailing buffer, - if (trailing_data > 0) { - const int non_trailing_rows = src_rows & ~31; - // Destination "rows" are padded to next highest multiple of 4. - const int dst_rows = (src_rows + 3) & ~3; - const int trailing_rows = dst_rows - non_trailing_rows; - memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, - 16 * trailing_rows * sizeof(std::int8_t)); - } -} - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. 
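The trailing-buffer copy-back in the 8-bit entry point above reduces to a little index arithmetic; a minimal standalone sketch with an arbitrary example row count (not ruy code):

    #include <cstdio>

    int main() {
      const int src_rows = 45;                       // example value
      const int non_trailing_rows = src_rows & ~31;  // 32: rows packed directly
      const int dst_rows = (src_rows + 3) & ~3;      // 48: padded up to a multiple of 4
      const int trailing_rows = dst_rows - non_trailing_rows;  // 16 rows copied from trailing_buf
      std::printf("%d %d %d\n", non_trailing_rows, dst_rows, trailing_rows);
      return 0;
    }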
-void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, int src_rows, - float* packed_ptr) { - profiler::ScopeLabel label("Pack kAvxVnni float (UNFINISHED)"); - float trailing_buf[7 * 16]; - if (remaining_src_cols > 8) { - HalfPackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - HalfPackFloatAvxVnni(src_ptr + src_stride * 8, zerobuf, src_stride, - remaining_src_cols - 8, src_rows, packed_ptr + 8, - trailing_buf + 8); - } else { - memset(trailing_buf, 0, sizeof(trailing_buf)); - HalfPackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - ZeroHalfFloatAvxVnni(src_rows, packed_ptr + 8); - } - const int trailing_rows = src_rows & 7; - if (trailing_rows > 0) { - const int non_trailing_rows = src_rows & ~7; - memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, - 16 * trailing_rows * sizeof(float)); - } -} - -#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) - -} // namespace ruy diff --git a/pack_common.h b/pack_common.h deleted file mode 100644 index 0fe2797..0000000 --- a/pack_common.h +++ /dev/null @@ -1,246 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// # What is "packing"? -// -// Before feeding data to the gemm kernels (the parts of Ruy that do lots -// of multiply-add operations), Ruy first performs a data transformation (which -// we call "packing") on the input matrices. This transformation has two main -// goals: -// - rearrange data into blocks that are a convenient size/layout for the gemm -// kernels to consume. This helps make the memory access pattern of the gemm -// kernel simpler and more contiguous, and puts the data in a layout most -// convenient for specific arithmetic instructions in the gemm kernel. -// - compute row/column sums needed for handling quantization with non-symmetric -// zero points. -// -// # Simplified algorithmic analysis of packing -// -// Packing is a relatively simple transformation which does a small constant -// amount of work on each element of an input matrix, and hence for an NxM -// matrix performs O(N*M) work. If N and M are of the same order, then this is -// O(N^2) work. -// -// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. -// Note that if N, K, and M are all the same order, then the number of -// multiply-accumulate operations is O(N^3). -// -// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the -// case of all dimensions being roughly the same order. -// -// # Packing cost can be significant -// -// When matrix * matrix multiplications begin to look more like matrix * vector -// multiplications, packing cost can become significant. We sometimes call these -// cases "gemv-like". 
-//
-// Continuing the algorithmic analysis above, if we consider a case where an
-// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the
-// situation is different. In this case, the multiply-accumulate work is only
-// quadratic, so the quadratic cost of packing can become significant.
-//
-// Another way to say this is that the cost of packing an input matrix (either
-// the LHS or RHS) is amortized across the non-depth dimension of the opposite
-// input matrix. Thus, when the LHS has very few rows or the RHS has very few
-// columns, the cost of packing the opposite input matrix can become
-// significant.
-//
-// As a rough rule of thumb, the cost of packing starts to become significant
-// when either N or M is below 32 (and other dimensions are hundreds), with very
-// significant packing costs at 8 or below. This varies by data type, Path, and
-// tuning, so these numbers are only rough guides.
-//
-// One practical use case that is affected by this is inference of
-// fully connected neural network layers with a low batch size. The weight
-// matrix (which is a constant for inference) is the one affected by significant
-// packing cost.
-//
-// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack
-// input matrices that are affected by significant packing costs.
-//
-// # Implementation notes
-//
-// Ruy's packing routines always operate on a range of columns and can be
-// applied to either the LHS or RHS. This is possible because Ruy internally
-// implements a TrMul, so the accumulation along depth is done along columns of
-// both the LHS and RHS (whereas for a normal Mul the accumulation along depth
-// for the LHS is along rows). As another example, we are always computing
-// column sums for quantization (and never row sums, since the LHS is
-// transposed).
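The role of those row/column sums may be easier to see with a small standalone example. The sketch below is editorial and is not ruy code (all names in it are hypothetical); it only illustrates the standard zero-point algebra that makes depth-wise sums useful: expanding (a - a0)(b - b0) and summing over the depth index k gives sum(a*b) - b0*sum(a) - a0*sum(b) + depth*a0*b0, so a kernel that accumulates raw products can be corrected using sums that packing has already computed.

#include <cstddef>
#include <cstdint>
#include <vector>

// Editorial sketch, not ruy code: one accumulator of a quantized matmul,
// computed once directly on zero-point-corrected values and once from raw
// products plus a sum-based correction. Both must agree.
std::int32_t DirectAccum(const std::vector<std::int8_t>& lhs,
                         const std::vector<std::int8_t>& rhs,
                         std::int32_t lhs_zero, std::int32_t rhs_zero) {
  std::int32_t accum = 0;
  for (std::size_t k = 0; k < lhs.size(); ++k) {
    accum += (lhs[k] - lhs_zero) * (rhs[k] - rhs_zero);
  }
  return accum;
}

std::int32_t AccumFromSums(const std::vector<std::int8_t>& lhs,
                           const std::vector<std::int8_t>& rhs,
                           std::int32_t lhs_zero, std::int32_t rhs_zero) {
  const std::int32_t depth = static_cast<std::int32_t>(lhs.size());
  std::int32_t raw = 0;      // what the kernel accumulates
  std::int32_t lhs_sum = 0;  // what packing precomputes along depth
  std::int32_t rhs_sum = 0;  // likewise for the other operand
  for (std::int32_t k = 0; k < depth; ++k) {
    raw += lhs[k] * rhs[k];
    lhs_sum += lhs[k];
    rhs_sum += rhs[k];
  }
  // sum((a - a0)(b - b0)) == sum(ab) - b0*sum(a) - a0*sum(b) + depth*a0*b0.
  return raw - rhs_zero * lhs_sum - lhs_zero * rhs_sum +
         depth * lhs_zero * rhs_zero;
}

Because ruy's TrMul arrangement packs both operands by columns, as noted above, the depth-wise sums needed on either side are column sums of the respective packed matrix.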
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_COMMON_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_COMMON_H_ - -#include - -#include "check_macros.h" -#include "common.h" -#include "internal_matrix.h" -#include "matrix.h" -#include "opt_set.h" -#include "path.h" -#include "platform.h" -#include "profiler/instrumentation.h" -#include "tune.h" - -namespace ruy { - -template -struct PackedTypeImpl { - using Type = Scalar; -}; - -#if RUY_PLATFORM(NEON_32) -struct PackParams8bit { - const void* src_ptr0; - const void* src_ptr1; - const void* src_ptr2; - const void* src_ptr3; - const std::int32_t* sums_ptr; - const std::int8_t* packed_ptr; - int src_inc0; - int src_inc1; - int src_inc2; - int src_inc3; - int src_rows; - int src_zero_point; - int input_xor; -}; - -inline void MakePackParams8bit(const void* src_ptr0, const void* src_ptr1, - const void* src_ptr2, const void* src_ptr3, - const std::int32_t* sums_ptr, - const std::int8_t* packed_ptr, int src_inc0, - int src_inc1, int src_inc2, int src_inc3, - int src_rows, int src_zero_point, int input_xor, - PackParams8bit* params) { - params->src_ptr0 = src_ptr0; - params->src_ptr1 = src_ptr1; - params->src_ptr2 = src_ptr2; - params->src_ptr3 = src_ptr3; - params->sums_ptr = sums_ptr; - params->packed_ptr = packed_ptr; - params->src_inc0 = src_inc0; - params->src_inc1 = src_inc1; - params->src_inc2 = src_inc2; - params->src_inc3 = src_inc3; - params->src_rows = src_rows; - params->src_zero_point = src_zero_point; - params->input_xor = input_xor; -} -#endif - -#if RUY_PLATFORM(NEON) -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -#elif RUY_PLATFORM(X86) -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -template <> -struct PackedTypeImpl { - using Type = std::int8_t; -}; -#endif - -template -using PackedType = typename PackedTypeImpl::Type; - -template -PackedScalar Pack(Scalar x) { - return x - SymmetricZeroPoint() + SymmetricZeroPoint(); -} - -template -struct PackImpl {}; - -#define RUY_INHERIT_PACK(PARENT, CHILD) \ - template \ - struct PackImpl \ - : PackImpl { \ - }; - -template -struct PackImpl { - static void Run(Tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (generic)"); - RUY_DCHECK_EQ((end_col - start_col) % FixedKernelLayout::kCols, 0); - SumsType* sums = packed_matrix->sums; - for (int col = start_col; col < end_col; col++) { - SumsType accum = 0; - for (int row = 0; row < packed_matrix->layout.rows; row++) { - PackedScalar packed_val; - if (col < src_matrix.layout.cols && row < src_matrix.layout.rows) { - packed_val = Pack(Element(src_matrix, row, col)); - } else { - packed_val = packed_matrix->zero_point; - } - accum += packed_val; - *ElementPtr(packed_matrix, row, col) = packed_val; - } - if (sums) { - sums[col] = accum; - } - } - } -}; - -#if RUY_PLATFORM(NEON) -RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon) -RUY_INHERIT_PACK(Path::kNeon, Path::kNeonDotprod) -#elif RUY_PLATFORM(X86) -RUY_INHERIT_PACK(Path::kStandardCpp, Path::kSse42) -RUY_INHERIT_PACK(Path::kSse42, Path::kAvx2) -RUY_INHERIT_PACK(Path::kAvx2, Path::kAvx512) -RUY_INHERIT_PACK(Path::kAvx512, Path::kAvxVnni) -#endif - -// Main entry point for packing. 
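The RUY_INHERIT_PACK macro above implements path fallback through template specialization: a Path with no dedicated PackImpl simply inherits the implementation of its parent Path, while a dedicated specialization, where one exists, is more specialized and is therefore selected. A toy analogue of that pattern, with hypothetical names and deliberately reduced to two template parameters, might look as follows (RunPack, the entry point announced just above, comes right after).

// Editorial sketch only: the specialization-inheritance pattern behind
// RUY_INHERIT_PACK, with hypothetical names.
enum class DemoPath { kStandardCpp, kNeon };

template <DemoPath tPath, typename Scalar>
struct DemoImpl {};  // primary template: no implementation

// Generic implementation on the base path, for any Scalar.
template <typename Scalar>
struct DemoImpl<DemoPath::kStandardCpp, Scalar> {
  static const char* Run() { return "generic"; }
};

// The RUY_INHERIT_PACK-style fallback: kNeon inherits kStandardCpp.
template <typename Scalar>
struct DemoImpl<DemoPath::kNeon, Scalar>
    : DemoImpl<DemoPath::kStandardCpp, Scalar> {};

// A dedicated implementation for one Scalar on kNeon is more specialized
// than the fallback above, so it wins where it applies.
template <>
struct DemoImpl<DemoPath::kNeon, float> {
  static const char* Run() { return "neon float"; }
};

With this arrangement, DemoImpl<DemoPath::kNeon, int>::Run() resolves to the generic implementation inherited from kStandardCpp, while DemoImpl<DemoPath::kNeon, float>::Run() resolves to the dedicated one. This mirrors how an optimized PackImpl specialization can override the inherited generic packing for the cases it covers while everything else falls back.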
-template -void RunPack(Tuning tuning, const DMatrix& src_matrix, PMatrix* packed_matrix, - int start_col, int end_col) { - using SumsType = typename PackedMatrix::SumsType; - Matrix src = ToMatrix(src_matrix); - PackedMatrix packed = - ToPackedMatrix(*packed_matrix); - PackImpl::Run( - tuning, src, &packed, start_col, end_col); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_COMMON_H_ diff --git a/pack_sse42.cc b/pack_sse42.cc deleted file mode 100644 index ca59dc7..0000000 --- a/pack_sse42.cc +++ /dev/null @@ -1,471 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "check_macros.h" -#include "matrix.h" -#include "opt_set.h" -#include "pack.h" -#include "path.h" -#include "platform.h" -#include "profiler/instrumentation.h" - -#if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) -#include // IWYU pragma: keep -#endif - -namespace ruy { - -#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)) - -void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - // CPU-ID-based checks should disable the path that would reach this point. - RUY_DCHECK(false); -} - -#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) - -// The first int8_t template parameter is arbitrary: this routine is common to -// all 8-bit source matrix types. -using PackImpl8bitSse42 = - PackImpl, - std::int8_t, std::int8_t, std::int32_t>; - -using PackImplFloatSse42 = - PackImpl, float, - float, float>; - -namespace { - -inline void Pack8bitSse42Packer(const std::int8_t* src_ptr, - std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr, - std::int8_t* trailing_buf) { - using Layout = PackImpl8bitSse42::Layout; - RUY_DCHECK_EQ(Layout::kCols, 8); - RUY_DCHECK_EQ(Layout::kRows, 4); - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. 
- constexpr int kNumRowChunks = 8; - constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; - - std::int8_t in_data[Layout::kCols][kNumRowChunks][Layout::kRows]; - - const std::int8_t* src_ptr0 = src_ptr; - const std::int8_t* src_ptr1 = src_ptr0 + src_stride; - const std::int8_t* src_ptr2 = src_ptr1 + src_stride; - const std::int8_t* src_ptr3 = src_ptr2 + src_stride; - const std::int8_t* src_ptr4 = src_ptr3 + src_stride; - const std::int8_t* src_ptr5 = src_ptr4 + src_stride; - const std::int8_t* src_ptr6 = src_ptr5 + src_stride; - const std::int8_t* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = kNumChunkedSrcRows; - std::int64_t src_inc1 = kNumChunkedSrcRows; - std::int64_t src_inc2 = kNumChunkedSrcRows; - std::int64_t src_inc3 = kNumChunkedSrcRows; - std::int64_t src_inc4 = kNumChunkedSrcRows; - std::int64_t src_inc5 = kNumChunkedSrcRows; - std::int64_t src_inc6 = kNumChunkedSrcRows; - std::int64_t src_inc7 = kNumChunkedSrcRows; - // Handle cases where source does not have Layout::kCols (8) columns. - if (remaining_src_cols < 8) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - const std::int8_t zero_point = zerobuf[0]; - - if (sums_ptr) { - // i: Layout::kCols. - for (int i = 0; i < 8; ++i) { - sums_ptr[i] = 0; - } - } - - // The overall packing effectively pads the source rows to - // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we - // only pack for (src_rows + 31) & ~31. When there is an incomplete - // destination block, this is stored into trailing_buf instead of packed_ptr. - for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) { - // Available source rows. - // If this is less than 0 (for m=1), we skip, having filled trailing - // buffer for m=0. Also, if source rows is zero on m=1, then we filled - // exactly to the end of the column in the packed buffer. - const int available_src_rows = src_rows - k; - // Effectively, - // available rows = std::max(0, std::min(8, src_rows - k)); - // treat each case separately. - if (available_src_rows >= kNumChunkedSrcRows) { - // i: chunks, s: Layout::Rows. - for (int i = 0; i < 8; ++i) { - for (int s = 0; s < 4; ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - } - // i: chunks, j: Layout::kCols, s: Layout::Rows. 
- for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - // 8 * 4 * i is offset for each block, that is - // (Layout::kCols * Layout::kRows * i) - packed_ptr[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor; - } - if (sums_ptr) { - for (int s = 0; s < 4; ++s) { - sums_ptr[j] += in_data[j][i][s] ^ input_xor; - } - } - } - } - } else if (available_src_rows > 0) { - RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows); - int i = 0; - // Consume chunks of 4 rows that are complete. - for (; i < (available_src_rows >> 2); ++i) { - for (int s = 0; s < 4; ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - } - // Consume any incomplete chunk. - if (i < ((available_src_rows + 3) >> 2)) { - int s = 0; - for (; s < (available_src_rows & 3); ++s) { - in_data[0][i][s] = src_ptr0[i * 4 + s]; - in_data[1][i][s] = src_ptr1[i * 4 + s]; - in_data[2][i][s] = src_ptr2[i * 4 + s]; - in_data[3][i][s] = src_ptr3[i * 4 + s]; - in_data[4][i][s] = src_ptr4[i * 4 + s]; - in_data[5][i][s] = src_ptr5[i * 4 + s]; - in_data[6][i][s] = src_ptr6[i * 4 + s]; - in_data[7][i][s] = src_ptr7[i * 4 + s]; - } - RUY_DCHECK_LE(s, 4); - for (; s < 4; ++s) { - // j: Layout::kCols. - for (int j = 0; j < 8; ++j) { - in_data[j][i][s] = zero_point; - } - } - ++i; - } - // We do not care what goes into the trailing buffer, but we want - // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. - // - // It might prove better in optimized code to pad uniformly with - // zero_point, and compensate by initializing the summations with the - // compensating offset, effectively - // ((input_xor - zero_point) ^ input_xor) * - // 4 * (8 - ((available_src_rows + 3) >> 2)). - for (; i < 8; ++i) { - for (int s = 0; s < 4; ++s) { - for (int j = 0; j < 8; ++j) { - in_data[j][i][s] = input_xor; - } - } - } - // We loop through [0, 8) rather than - // [0, (available_src_rows + 3) >> 2), since that emulates what we might - // do in fully-optimized code. - // - // i: chunks, j: Layout::kCols, s: Layout::Rows. - if (sums_ptr) { - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - trailing_buf[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor; - sums_ptr[j] = sums_ptr[j] + (in_data[j][i][s] ^ input_xor); - } - } - } - } else { - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - for (int s = 0; s < 4; ++s) { - trailing_buf[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor; - } - } - } - } - } - - packed_ptr += 8 * kNumChunkedSrcRows; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } -} - -inline void PackFloatSse42Packer(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, - int src_rows, float* packed_ptr, - float* trailing_buf) { - using Layout = PackImplFloatSse42::Layout; - RUY_DCHECK_EQ(Layout::kCols, 8); - RUY_DCHECK_EQ(Layout::kRows, 1); - - // This packing amounts to tranposition of 8x8 blocks. - static constexpr int kPackCols = 8; // Source cols packed together. - static constexpr int kPackRows = 8; // Short input is padded. 
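// Editorial note (not in the original file): the loops below implement the
// 8x8 transposition mentioned above. Element i read from src_ptrN (source
// column N) ends up at packed offset 8 * i + N, so each group of 8
// consecutive packed floats holds one source row across the 8 packed columns.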
- - float in_data[kPackCols][kPackRows]; - - const float* src_ptr0 = src_ptr; - const float* src_ptr1 = src_ptr0 + src_stride; - const float* src_ptr2 = src_ptr1 + src_stride; - const float* src_ptr3 = src_ptr2 + src_stride; - const float* src_ptr4 = src_ptr3 + src_stride; - const float* src_ptr5 = src_ptr4 + src_stride; - const float* src_ptr6 = src_ptr5 + src_stride; - const float* src_ptr7 = src_ptr6 + src_stride; - std::int64_t src_inc0 = 8; - std::int64_t src_inc1 = 8; - std::int64_t src_inc2 = 8; - std::int64_t src_inc3 = 8; - std::int64_t src_inc4 = 8; - std::int64_t src_inc5 = 8; - std::int64_t src_inc6 = 8; - std::int64_t src_inc7 = 8; - // Handle cases where source does not have kPackDim (8) columns. - if (remaining_src_cols < kPackCols) { - if (remaining_src_cols <= 0) { - src_ptr0 = zerobuf; - src_inc0 = 0; - } - if (remaining_src_cols <= 1) { - src_ptr1 = zerobuf; - src_inc1 = 0; - } - if (remaining_src_cols <= 2) { - src_ptr2 = zerobuf; - src_inc2 = 0; - } - if (remaining_src_cols <= 3) { - src_ptr3 = zerobuf; - src_inc3 = 0; - } - if (remaining_src_cols <= 4) { - src_ptr4 = zerobuf; - src_inc4 = 0; - } - if (remaining_src_cols <= 5) { - src_ptr5 = zerobuf; - src_inc5 = 0; - } - if (remaining_src_cols <= 6) { - src_ptr6 = zerobuf; - src_inc6 = 0; - } - src_ptr7 = zerobuf; - src_inc7 = 0; - } - - for (int k = 0; k < src_rows; k += kPackRows) { - const int available_src_rows = src_rows - k; - // Effectively, - // available_src_rows = std::max(0, std::min(kPackDim, src_rows - k)); - // but treat each case separately. - if (available_src_rows >= kPackRows) { - for (int i = 0; i < 8; ++i) { - in_data[0][i] = src_ptr0[i]; - in_data[1][i] = src_ptr1[i]; - in_data[2][i] = src_ptr2[i]; - in_data[3][i] = src_ptr3[i]; - in_data[4][i] = src_ptr4[i]; - in_data[5][i] = src_ptr5[i]; - in_data[6][i] = src_ptr6[i]; - in_data[7][i] = src_ptr7[i]; - } - for (int i = 0; i < 8; ++i) { - for (int j = 0; j < 8; ++j) { - packed_ptr[8 * i + j] = in_data[j][i]; - } - } - } else if (available_src_rows > 0) { - for (int i = 0; i < available_src_rows; ++i) { - in_data[0][i] = src_ptr0[i]; - in_data[1][i] = src_ptr1[i]; - in_data[2][i] = src_ptr2[i]; - in_data[3][i] = src_ptr3[i]; - in_data[4][i] = src_ptr4[i]; - in_data[5][i] = src_ptr5[i]; - in_data[6][i] = src_ptr6[i]; - in_data[7][i] = src_ptr7[i]; - } - for (int i = available_src_rows; i < kPackRows; ++i) { - in_data[0][i] = 0.0f; - in_data[1][i] = 0.0f; - in_data[2][i] = 0.0f; - in_data[3][i] = 0.0f; - in_data[4][i] = 0.0f; - in_data[5][i] = 0.0f; - in_data[6][i] = 0.0f; - in_data[7][i] = 0.0f; - } - // We loop through [0, 7) rather than [0, packed_rows), since that - // emulates what we might do in fully-optimized code. - // i: (kPackRows - 1), j: kPackCols. - for (int i = 0; i < 7; ++i) { - for (int j = 0; j < 8; ++j) { - trailing_buf[kPackRows * i + j] = in_data[j][i]; - } - } - } - - packed_ptr += kPackRows * kPackCols; - src_ptr0 += src_inc0; - src_ptr1 += src_inc1; - src_ptr2 += src_inc2; - src_ptr3 += src_inc3; - src_ptr4 += src_inc4; - src_ptr5 += src_inc5; - src_ptr6 += src_inc6; - src_ptr7 += src_inc7; - } -} - -} // namespace. - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. 
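Before the Pack8bitSse42 wrapper below, it may help to spell out the rounding arithmetic it relies on, since the same masks reappear in several packers. This is an editorial sketch: kNumRowChunks, kNumChunkedSrcRows and kChunkedRowMask mirror the constants used in the packer and wrapper, kLayoutRows stands in for Layout::kRows, and src_rows = 70 is an arbitrarily chosen example.

// Editorial sketch: the trailing-buffer bookkeeping of Pack8bitSse42 for a
// hypothetical src_rows of 70. Rows are consumed in chunks of 32 (8 chunks of
// 4 rows), and the packed destination is padded up to a multiple of 4 rows.
constexpr int kNumRowChunks = 8;
constexpr int kLayoutRows = 4;
constexpr int kNumChunkedSrcRows = kNumRowChunks * kLayoutRows;  // 32
constexpr int kChunkedRowMask = kNumChunkedSrcRows - 1;          // 31
constexpr int src_rows = 70;
// Rows handled entirely by the packer, written straight to packed_ptr:
// clearing the low five bits rounds down to a multiple of 32.
constexpr int non_trailing_rows = src_rows & ~kChunkedRowMask;   // 64
// Destination rows: (x + 3) & ~3 rounds up to a multiple of Layout::kRows.
constexpr int dst_rows = (src_rows + 3) & ~3;                    // 72
// Rows that sit in trailing_buf and are memcpy'd into place afterwards.
constexpr int trailing_rows = dst_rows - non_trailing_rows;      // 8
static_assert(non_trailing_rows == 64 && dst_rows == 72 && trailing_rows == 8,
              "the last partial chunk (source rows 64..69, padded to 72) is "
              "copied from the trailing buffer");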
-void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - profiler::ScopeLabel label("Pack kSse42 8bit (UNFINISHED)"); - - using Layout = PackImpl8bitSse42::Layout; - RUY_DCHECK_EQ(Layout::kCols, 8); - RUY_DCHECK_EQ(Layout::kRows, 4); - - // Each Layout::Rows is 4 contiguous input, contiguous packed elements. - // We process 8 of these chunks at a time, padding short input chunks. - static constexpr int kNumRowChunks = 8; // Short input is padded. - - // Each packed block is 4*8, and there are normally 8. The trailing block is - // only slightly shorter. - constexpr int kTrailingBufSize = - kNumRowChunks * Layout::kCols * Layout::kRows; - std::int8_t trailing_buf[kTrailingBufSize]; - memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); - - Pack8bitSse42Packer(src_ptr, input_xor, zerobuf, src_stride, - remaining_src_cols, src_rows, packed_ptr, sums_ptr, - trailing_buf); - - constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; - const bool trailing_data = (src_rows & kChunkedRowMask) > 0; - // If the number of source rows is not a multiple of kChunkedRowMask, there - // will be data in the trailing buffer, - if (trailing_data > 0) { - const int non_trailing_rows = src_rows & ~kChunkedRowMask; - // Destination "rows" are padded to next highest multiple of Layout::kRows. - const int dst_rows = (src_rows + 3) & ~3; - const int trailing_rows = dst_rows - non_trailing_rows; - memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, - Layout::kCols * trailing_rows * sizeof(std::int8_t)); - } -} - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// When removing this comment, update profiling label below. -void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr) { - profiler::ScopeLabel label("Pack kSse42 float (UNFINISHED)"); - static constexpr int kPackCols = 8; // Source cols packed together. - static constexpr int kPackRows = 8; // Short input is padded. - float trailing_buf[(kPackRows - 1) * kPackCols]; - if (remaining_src_cols < 8) { - memset(trailing_buf, 0, sizeof(trailing_buf)); - } - PackFloatSse42Packer(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_rows, packed_ptr, trailing_buf); - - const int trailing_rows = src_rows & (kPackRows - 1); - if (trailing_rows > 0) { - const int non_trailing_rows = src_rows & ~(kPackRows - 1); - memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf, - kPackCols * trailing_rows * sizeof(float)); - } -} - -#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) - -} // namespace ruy diff --git a/pack_x86.h b/pack_x86.h deleted file mode 100644 index 4998fc6..0000000 --- a/pack_x86.h +++ /dev/null @@ -1,461 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// # What is "packing"?
-//
-// Before feeding data to the gemm kernels (the parts of Ruy that do lots
-// of multiply-add operations), Ruy first performs a data transformation (which
-// we call "packing") on the input matrices. This transformation has two main
-// goals:
-// - rearrange data into blocks that are a convenient size/layout for the gemm
-// kernels to consume. This helps make the memory access pattern of the gemm
-// kernel simpler and more contiguous, and puts the data in a layout most
-// convenient for specific arithmetic instructions in the gemm kernel.
-// - compute row/column sums needed for handling quantization with non-symmetric
-// zero points.
-//
-// # Simplified algorithmic analysis of packing
-//
-// Packing is a relatively simple transformation which does a small constant
-// amount of work on each element of an input matrix, and hence for an NxM
-// matrix performs O(N*M) work. If N and M are of the same order, then this is
-// O(N^2) work.
-//
-// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations.
-// Note that if N, K, and M are all the same order, then the number of
-// multiply-accumulate operations is O(N^3).
-//
-// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the
-// case of all dimensions being roughly the same order.
-//
-// # Packing cost can be significant
-//
-// When matrix * matrix multiplications begin to look more like matrix * vector
-// multiplications, packing cost can become significant. We sometimes call these
-// cases "gemv-like".
-//
-// Continuing the algorithmic analysis above, if we consider a case where an
-// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the
-// situation is different. In this case, the multiply-accumulate work is only
-// quadratic, so the quadratic cost of packing can become significant.
-//
-// Another way to say this is that the cost of packing an input matrix (either
-// the LHS or RHS) is amortized across the non-depth dimension of the opposite
-// input matrix. Thus, when the LHS has very few rows or the RHS has very few
-// columns, the cost of packing the opposite input matrix can become
-// significant.
-//
-// As a rough rule of thumb, the cost of packing starts to become significant
-// when either N or M is below 32 (and other dimensions are hundreds), with very
-// significant packing costs at 8 or below. This varies by data type, Path, and
-// tuning, so these numbers are only rough guides.
-//
-// One practical use case that is affected by this is inference of
-// fully connected neural network layers with a low batch size. The weight
-// matrix (which is a constant for inference) is the one affected by significant
-// packing cost.
-//
-// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack
-// input matrices that are affected by significant packing costs.
-//
-// # Implementation notes
-//
-// Ruy's packing routines always operate on a range of columns and can be
-// applied to either the LHS or RHS. This is possible because Ruy internally
-// implements a TrMul, so the accumulation along depth is done along columns of
-// both the LHS and RHS (whereas for a normal Mul the accumulation along depth
-// for the LHS is along rows).
As another example, we are always computing -// column sums for quantization (and never row sums, since the LHS is -// transposed). - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_X86_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_X86_H_ - -#include -#include -#include - -#include "check_macros.h" -#include "common.h" -#include "internal_matrix.h" -#include "matrix.h" -#include "opt_set.h" -#include "pack_common.h" -#include "path.h" -#include "platform.h" -#include "profiler/instrumentation.h" -#include "tune.h" - -namespace ruy { - -#if RUY_PLATFORM(X86) -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// Note that source and zero buffers can be uint8 type, but in the packing -// function are reinterpreted as int8, and are XOR-ed with input_xor. -void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr); - -template -struct PackImpl, Scalar, - std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - using Layout = FixedKernelLayout; - static constexpr std::int8_t kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (SSE 4.2 8-bit)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[Layout::kCols * Layout::kRows]; - memset(zerobuf, packed_matrix->zero_point ^ kInputXor, - Layout::kCols * Layout::kRows * sizeof(Scalar)); - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - std::int8_t* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - Pack8bitSse42(reinterpret_cast(src_ptr), kInputXor, - reinterpret_cast(zerobuf), src_stride, - remaining_src_cols, src_matrix.layout.rows, packed_ptr, - sums_ptr); - } - } -}; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr); - -template <> -struct PackImpl, float, - float, float> { - using Layout = FixedKernelLayout; - static void Run(Tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (SSE 4.2 float)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - const float zerobuf[Layout::kCols] = { - 0.0f}; // Remainder default inits to 0.0f. 
- for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - float* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - PackFloatSse42(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_matrix.layout.rows, packed_ptr); - } - } -}; - -// Note that source and zero buffers can be uint8 type, but in the packing -// function are reinterpreted as int8, and are XOR-ed with input_xor. -void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, - std::int32_t* sums_ptr); - -template -struct PackImpl, Scalar, - std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - using Layout = FixedKernelLayout; - static constexpr std::int8_t kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX2 8-bit)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[Layout::kCols * Layout::kRows]; - memset(zerobuf, packed_matrix->zero_point ^ kInputXor, - Layout::kCols * Layout::kRows * sizeof(Scalar)); - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - std::int8_t* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - Pack8bitAvx2(reinterpret_cast(src_ptr), kInputXor, - reinterpret_cast(zerobuf), src_stride, - remaining_src_cols, src_matrix.layout.rows, packed_ptr, - sums_ptr); - } - } -}; - -void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr); - -template <> -struct PackImpl, float, - float, float> { - using Layout = FixedKernelLayout; - static void Run(Tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX2 float)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - const float zerobuf[Layout::kCols] = { - 0.0f}; // Remainder default inits to 0.0f. - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. 
- float* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - PackFloatAvx2(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_matrix.layout.rows, packed_ptr); - } - } -}; - -// Note that source and zero buffers can be uint8 type, but in the packing -// function are reinterpreted as int8, and are XOR-ed with input_xor. -void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr); - -template -struct PackImpl, - Scalar, std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - using Layout = FixedKernelLayout; - static constexpr int kHalfLayoutCols = - 8; // Half the number of cols in a block. - static constexpr std::int8_t kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX-512 8-bit)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[kHalfLayoutCols * Layout::kRows]; - memset(zerobuf, packed_matrix->zero_point ^ kInputXor, - kHalfLayoutCols * Layout::kRows * sizeof(Scalar)); - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - std::int8_t* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - Pack8bitAvx512(reinterpret_cast(src_ptr), kInputXor, - reinterpret_cast(zerobuf), src_stride, - remaining_src_cols, src_matrix.layout.rows, packed_ptr, - sums_ptr); - } - } -}; - -void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, float* packed_ptr); - -template <> -struct PackImpl, - float, float, float> { - static void Run(Tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX-512 float)"); - using Layout = FixedKernelLayout; - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - const float zerobuf[Layout::kCols] = { - 0.0f}; // Remainder default inits to 0.0f. - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. 
- float* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - PackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_matrix.layout.rows, packed_ptr); - } - } -}; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// Note that source and zero buffers can be uint8 type, but in the packing -// function are reinterpreted as int8, and are XOR-ed with input_xor. -void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor, - const std::int8_t* zerobuf, int src_stride, - int remaining_src_cols, int src_rows, - std::int8_t* packed_ptr, std::int32_t* sums_ptr); - -template -struct PackImpl, - Scalar, std::int8_t, std::int32_t> { - static_assert(std::is_same::value || - std::is_same::value, - ""); - using Layout = FixedKernelLayout; - static constexpr int kHalfLayoutCols = - 8; // Half the number of cols in a block. - static constexpr std::int8_t kInputXor = - std::is_same::value ? 0 : 0x80; - - static void Run(Tuning tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX-512 8-bit)"); - - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols); - std::int32_t* sums = packed_matrix->sums; - Scalar zerobuf[kHalfLayoutCols * Layout::kRows]; - memset(zerobuf, packed_matrix->zero_point ^ kInputXor, - kHalfLayoutCols * Layout::kRows * sizeof(Scalar)); - for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; - int src_stride = src_matrix.layout.stride; - const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - std::int8_t* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - Pack8bitAvxVnni(reinterpret_cast(src_ptr), kInputXor, - reinterpret_cast(zerobuf), src_stride, - remaining_src_cols, src_matrix.layout.rows, packed_ptr, - sums_ptr); - } - } -}; - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf, - int src_stride, int remaining_src_cols, int src_rows, - float* packed_ptr); - -template <> -struct PackImpl, - float, float, float> { - static void Run(Tuning, const Matrix& src_matrix, - PackedMatrix* packed_matrix, int start_col, - int end_col) { - profiler::ScopeLabel label("Pack (AVX-512 float)"); - - using Layout = FixedKernelLayout; - RUY_DCHECK(IsColMajor(src_matrix.layout)); - RUY_DCHECK(IsColMajor(packed_matrix->layout)); - RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); - RUY_DCHECK_EQ(start_col % Layout::kCols, 0); - const float zerobuf[Layout::kCols] = { - 0.0f}; // Remainder default inits to 0.0f. 
- for (int block_col = start_col; block_col < end_col; - block_col += Layout::kCols) { - int src_stride = src_matrix.layout.stride; - const float* src_ptr = src_matrix.data.get() + src_stride * block_col; - int remaining_src_cols = src_matrix.layout.cols - block_col; - - static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. - float* packed_ptr = - packed_matrix->data + - packed_matrix->layout.stride * (block_col & block_col_mask); - PackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols, - src_matrix.layout.rows, packed_ptr); - } - } -}; -#endif // RUY_PLATFORM(X86) - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PACK_X86_H_ diff --git a/path.h b/path.h deleted file mode 100644 index 3ff0c57..0000000 --- a/path.h +++ /dev/null @@ -1,162 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PATH_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PATH_H_ - -#include - -#include "platform.h" -#include "size_util.h" - -namespace ruy { - -// A Path is a choice of implementation path, e.g. between reference code -// and optimized code, or between different optimized code paths using different -// instruction sets. -// -// It's important that any symbol that depends on such implementation -// details, is somehow templatized in such a Path, so that different Path values -// yield different symbols, so we never have the situation where a symbols has -// multiple inequivalent definitions based on which code paths are compiled. -// That would be a violation of the ODR (One Definition Rule) which is Undefined -// Behavior, and one of the most serious issues plaguing both Eigen and -// gemmlowp. -// -// This enum is actually a bit-field: aside from kNone, all other values are -// powers of two, thus are one bit each. We define bit-wise operators below -// for this enum. Some places in Ruy accept a Path bit-field where multiple -// Paths may be selected, while some other places require a single Path (i.e. -// just one of the enum values here). Typically, user-facing parts of Ruy -// accept arbitrary bit-fields, allowing the user to compile support for -// multiple paths and to inform Ruy of all the paths that are to be enabled -// at runtime; then, typically in dispatch.h, we internally pick one -// specific path and from there on, internal Ruy code deals with only one -// path. -// -// When a user selects a set of compiled paths, Ruy internally dispatches to the -// "best" one, which typically means the newest optimized instructions for a -// given base architecture (such as ARM). Higher values of this enum correspond -// to "better" code paths within a given base architecture for which Ruy has -// optimized code paths. -// -// Values are reused across architectures. 
-// Rationale: Scale better to N architectures, it is good to have small values -// both for the compile-time logic to select paths, and when manually spelling -// out Path values, such as when invoking a test or benchmark. -enum class Path : std::uint8_t { - // This is a special null value, representing the absence of any path. - kNone = 0, - // Reference multiplication code. - // The main purpose of this path is to have a very simple standalone Mul - // implementation to check against. - // This path bypasses almost all of Ruy's internal implementation details. - // - // This is intended for testing/development. - kReference = 0x1, - // Standard C++ implementation of Ruy's architecture-specific parts. - // Unlike Path::kReference, this path exercises most of Ruy's internal logic. - // - // This is intended for testing/development. - kStandardCpp = 0x2, - -#if RUY_PLATFORM(ARM) - // ARM architectures. - // - // Optimized path using a widely available subset of ARM NEON instructions. - kNeon = 0x4, - // Optimized path making use of ARM NEON dot product instructions that are - // available on newer ARM cores. - kNeonDotprod = 0x8, -#endif // RUY_PLATFORM(ARM) - -#if RUY_PLATFORM(X86) - // x86 architectures. - // - // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / - // placeholder. - // Optimization is not finished. In particular the dimensions of the kernel - // blocks can be changed as desired. - // - // Optimized for SSE 4.2. - kSse42 = 0x4, - // Optimized for AVX2. - kAvx2 = 0x8, - // Optimized for AVX-512. - kAvx512 = 0x10, - // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / - // placeholder. - // Optimization is not finished. In particular the dimensions of the kernel - // blocks can be changed as desired. - // - // Optimized for AVX-VNNI. - kAvxVnni = 0x20, -#endif // RUY_PLATFORM(X86) -}; - -inline constexpr Path operator|(Path p, Path q) { - return static_cast(static_cast(p) | - static_cast(q)); -} - -inline constexpr Path operator&(Path p, Path q) { - return static_cast(static_cast(p) & - static_cast(q)); -} - -inline constexpr Path operator^(Path p, Path q) { - return static_cast(static_cast(p) ^ - static_cast(q)); -} - -inline constexpr Path operator~(Path p) { - return static_cast(~static_cast(p)); -} - -inline Path GetMostSignificantPath(Path path_mask) { - return static_cast(round_down_pot(static_cast(path_mask))); -} - -// ruy::kAllPaths represents all Path's that make sense to on a given -// base architecture. -#ifdef __linux__ -#if RUY_PLATFORM(NEON_64) -constexpr Path kAllPaths = - Path::kReference | Path::kStandardCpp | Path::kNeon | Path::kNeonDotprod; -#elif RUY_PLATFORM(NEON_32) -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon; -#elif RUY_PLATFORM(X86) -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | - Path::kSse42 | Path::kAvx2 | Path::kAvx512 | - Path::kAvxVnni; -#else -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp; -#endif -#else // __linux__ -// We don't know how to do runtime dotprod detection outside of linux for now. 
-#if RUY_PLATFORM(NEON) -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon; -#elif RUY_PLATFORM(X86) -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | - Path::kSse42 | Path::kAvx2 | Path::kAvx512 | - Path::kAvxVnni; -#else -constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp; -#endif -#endif // __linux__ - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PATH_H_ diff --git a/platform.h b/platform.h deleted file mode 100644 index d86c957..0000000 --- a/platform.h +++ /dev/null @@ -1,156 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PLATFORM_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PLATFORM_H_ - -#ifdef __ANDROID_NDK__ -#include -#endif - -#define RUY_PLATFORM(X) ((RUY_DONOTUSEDIRECTLY_##X) != 0) - -// Architecture-level platform detection. -// -// Ruy requires these to be mutually exclusive. - -// Detect x86. -#if defined(__x86_64__) || defined(__i386__) || defined(__i386) || \ - defined(__x86__) || defined(__X86__) || defined(_X86_) || \ - defined(_M_IX86) || defined(_M_X64) -#define RUY_DONOTUSEDIRECTLY_X86 1 -#else -#define RUY_DONOTUSEDIRECTLY_X86 0 -#endif - -// Detect ARM 32-bit. -#ifdef __arm__ -#define RUY_DONOTUSEDIRECTLY_ARM_32 1 -#else -#define RUY_DONOTUSEDIRECTLY_ARM_32 0 -#endif - -// Detect ARM 64-bit. -#ifdef __aarch64__ -#define RUY_DONOTUSEDIRECTLY_ARM_64 1 -#else -#define RUY_DONOTUSEDIRECTLY_ARM_64 0 -#endif - -// Combined ARM. -#define RUY_DONOTUSEDIRECTLY_ARM \ - (RUY_DONOTUSEDIRECTLY_ARM_64 || RUY_DONOTUSEDIRECTLY_ARM_32) - -// Feature and capability platform detection. -// -// These are mostly sub-selections of architectures. - -// Detect NEON. Explicitly avoid emulation, or anything like it, on x86. -#if (defined(__ARM_NEON) || defined(__ARM_NEON__)) && !RUY_PLATFORM(X86) -#define RUY_DONOTUSEDIRECTLY_NEON 1 -#else -#define RUY_DONOTUSEDIRECTLY_NEON 0 -#endif - -// Define ARM 32-bit NEON. -#define RUY_DONOTUSEDIRECTLY_NEON_32 \ - (RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_32) - -// Define ARM 64-bit NEON. -// Note: NEON is implied by ARM64, so this define is redundant. -// It still allows some conveyance of intent. -#define RUY_DONOTUSEDIRECTLY_NEON_64 \ - (RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_64) - -// Disable X86 enhancements on __APPLE__ because b/138922878, see comment #8, we -// may only need to disable this on XCode <= 10.2. -// -// Disable when not using Clang-Linux, because too many user issues arise from -// compilation variations. -// -// NOTE: Consider guarding by !defined(__APPLE__) when removing Linux-only -// restriction. -// -// __EMSCRIPTEN__ is checked because the runtime Path resolution can use asm. -// -// The Android NDK logic excludes earlier and very broken versions of intrinsics -// headers. 
-#if defined(RUY_FORCE_ENABLE_X86_ENHANCEMENTS) || \ - (defined(__clang__) && (__clang_major__ >= 8) && defined(__linux__) && \ - !defined(__EMSCRIPTEN__) && \ - (!defined(__ANDROID_NDK__) || \ - (defined(__NDK_MAJOR__) && (__NDK_MAJOR__ >= 20)))) -#define RUY_DONOTUSEDIRECTLY_X86_ENHANCEMENTS 1 -#else -#define RUY_DONOTUSEDIRECTLY_X86_ENHANCEMENTS 0 -#endif - -// These CPU capabilities will all be true when Skylake, etc, are enabled during -// compilation. -#if RUY_PLATFORM(X86_ENHANCEMENTS) && RUY_PLATFORM(X86) && \ - defined(__AVX512F__) && defined(__AVX512DQ__) && defined(__AVX512CD__) && \ - defined(__AVX512BW__) && defined(__AVX512VL__) -#define RUY_DONOTUSEDIRECTLY_AVX512 1 -#else -#define RUY_DONOTUSEDIRECTLY_AVX512 0 -#endif - -#if RUY_PLATFORM(X86_ENHANCEMENTS) && RUY_PLATFORM(X86) && defined(__AVX2__) -#define RUY_DONOTUSEDIRECTLY_AVX2 1 -#else -#define RUY_DONOTUSEDIRECTLY_AVX2 0 -#endif - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// Note does not check for LZCNT or POPCNT. -#if defined(RUY_ENABLE_SSE_ENHANCEMENTS) && RUY_PLATFORM(X86_ENHANCEMENTS) && \ - RUY_PLATFORM(X86) && defined(__SSE4_2__) && defined(__FMA__) -#define RUY_DONOTUSEDIRECTLY_SSE42 1 -#else -#define RUY_DONOTUSEDIRECTLY_SSE42 0 -#endif - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// Note that defined(__AVX512VBMI2__) can be false for compilation with -// -march=cascadelake. -// TODO(b/146646451) Check if we should also gate on defined(__AVX512VBMI2__). -#if defined(RUY_ENABLE_VNNI_ENHANCEMENTS) && RUY_PLATFORM(AVX512) && \ - defined(__AVX512VNNI__) -#define RUY_DONOTUSEDIRECTLY_AVX_VNNI 1 -#else -#define RUY_DONOTUSEDIRECTLY_AVX_VNNI 0 -#endif - -// Detect APPLE. -#ifdef __APPLE__ -#define RUY_DONOTUSEDIRECTLY_APPLE 1 -#else -#define RUY_DONOTUSEDIRECTLY_APPLE 0 -#endif - -// Detect Emscripten, typically Wasm. -#ifdef __EMSCRIPTEN__ -#define RUY_DONOTUSEDIRECTLY_EMSCRIPTEN 1 -#else -#define RUY_DONOTUSEDIRECTLY_EMSCRIPTEN 0 -#endif - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PLATFORM_H_ diff --git a/pmu.cc b/pmu.cc deleted file mode 100644 index 5c87d73..0000000 --- a/pmu.cc +++ /dev/null @@ -1,281 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "pmu.h" - -#include "check_macros.h" - -#ifdef __linux__ -#include -#include -#include -#include -#include - -#include -#endif - -#include -#include -#include -#include - -namespace ruy { - -// Linux-specific. Not ARM-specific. 
-#ifdef __linux__ -class PerfEvent { - public: - PerfEvent(std::uint32_t type, std::uint64_t config) { - perf_event_attr pe; - memset(&pe, 0, sizeof(pe)); - pe.size = sizeof(pe); - pe.type = type; - pe.config = config; - pe.disabled = 1; - pe.exclude_kernel = 1; - pe.exclude_hv = 1; - pe.inherit = 1; - fd_ = syscall(__NR_perf_event_open, &pe, 0, -1, -1, 0); - if (fd_ == -1) { - fprintf(stderr, "perf_event_open failed for config 0x%lx\n", - static_cast(config)); - // abort(); - } - } - - ~PerfEvent() { - RUY_CHECK(!started_); - close(fd_); - } - - void Start() { - RUY_CHECK(!started_); - started_ = true; - ioctl(fd_, PERF_EVENT_IOC_RESET, 0); - ioctl(fd_, PERF_EVENT_IOC_ENABLE, 0); - count_at_start_ = Read(); - } - - void Stop() { - RUY_CHECK(started_); - started_ = false; - ioctl(fd_, PERF_EVENT_IOC_DISABLE, 0); - count_at_stop_ = Read(); - } - - std::int64_t Count() const { - RUY_CHECK(!started_); - return count_at_stop_ - count_at_start_; - } - - private: - std::int64_t Read() const { - std::int64_t count; - RUY_CHECK_NE(read(fd_, &count, sizeof(count)), -1); - return count; - } - std::int64_t count_at_start_ = -1; - std::int64_t count_at_stop_ = -1; - bool started_ = false; - int fd_ = -1; -}; -#else -// Placeholder implementation to at least compile outside of linux. -#define PERF_TYPE_RAW 0 -class PerfEvent { - public: - PerfEvent(std::uint32_t, std::uint64_t) {} - ~PerfEvent() {} - void Start() {} - void Stop() {} - std::int64_t Count() const { return 0; } -}; -#endif - -// ARM-specific. Query ARM PMU counters as Linux perf events using -// PERF_TYPE_RAW. -namespace arm_pmuv3 { - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-const-variable" - -// These event numbers are listed in the ARMv8 architecture reference manual. 
-constexpr std::uint16_t L1I_CACHE_REFILL = 0x01; -constexpr std::uint16_t L1I_TLB_REFILL = 0x02; -constexpr std::uint16_t L1D_CACHE_REFILL = 0x03; -constexpr std::uint16_t L1D_CACHE = 0x04; -constexpr std::uint16_t L1D_TLB_REFILL = 0x05; -constexpr std::uint16_t LD_RETIRED = 0x06; -constexpr std::uint16_t ST_RETIRED = 0x07; -constexpr std::uint16_t INST_RETIRED = 0x08; -constexpr std::uint16_t EXC_TAKEN = 0x09; -constexpr std::uint16_t EXC_RETURN = 0x0A; -constexpr std::uint16_t CID_WRITE_RETIRED = 0x0B; -constexpr std::uint16_t PC_WRITE_RETIRED = 0x0C; -constexpr std::uint16_t BR_IMMED_RETIRED = 0x0D; -constexpr std::uint16_t BR_RETURN_RETIRED = 0x0E; -constexpr std::uint16_t UNALIGNED_LDST_RETIRED = 0x0F; -constexpr std::uint16_t BR_MIS_PRED = 0x10; -constexpr std::uint16_t CPU_CYCLES = 0x11; -constexpr std::uint16_t BR_PRED = 0x12; -constexpr std::uint16_t MEM_ACCESS = 0x13; -constexpr std::uint16_t L1I_CACHE = 0x14; -constexpr std::uint16_t L1D_CACHE_WB = 0x15; -constexpr std::uint16_t L2D_CACHE = 0x16; -constexpr std::uint16_t L2D_CACHE_REFILL = 0x17; -constexpr std::uint16_t L2D_CACHE_WB = 0x18; -constexpr std::uint16_t BUS_ACCESS = 0x19; -constexpr std::uint16_t MEMORY_ERROR = 0x1A; -constexpr std::uint16_t INST_SPEC = 0x1B; -constexpr std::uint16_t TTBR_WRITE_RETIRED = 0x1C; -constexpr std::uint16_t BUS_CYCLES = 0x1D; -constexpr std::uint16_t CHAIN = 0x1E; -constexpr std::uint16_t L1D_CACHE_ALLOCATE = 0x1F; -constexpr std::uint16_t L2D_CACHE_ALLOCATE = 0x20; -constexpr std::uint16_t BR_RETIRED = 0x21; -constexpr std::uint16_t BR_MIS_PRED_RETIRED = 0x22; -constexpr std::uint16_t STALL_FRONTEND = 0x23; -constexpr std::uint16_t STALL_BACKEND = 0x24; -constexpr std::uint16_t L1D_TLB = 0x25; -constexpr std::uint16_t L1I_TLB = 0x26; -constexpr std::uint16_t L2I_CACHE = 0x27; -constexpr std::uint16_t L2I_CACHE_REFILL = 0x28; -constexpr std::uint16_t L3D_CACHE_ALLOCATE = 0x29; -constexpr std::uint16_t L3D_CACHE_REFILL = 0x2A; -constexpr std::uint16_t L3D_CACHE = 0x2B; -constexpr std::uint16_t L3D_CACHE_WB = 0x2C; -constexpr std::uint16_t L2D_TLB_REFILL = 0x2D; -constexpr std::uint16_t L2I_TLB_REFILL = 0x2E; -constexpr std::uint16_t L2D_TLB = 0x2F; -constexpr std::uint16_t L2I_TLB = 0x30; -constexpr std::uint16_t LL_CACHE = 0x32; -constexpr std::uint16_t LL_CACHE_MISS = 0x33; -constexpr std::uint16_t DTLB_WALK = 0x34; -constexpr std::uint16_t LL_CACHE_RD = 0x36; -constexpr std::uint16_t LL_CACHE_MISS_RD = 0x37; - -// Additional implementation-defined events found by googling around. 
-constexpr std::uint16_t L1D_CACHE_RD = 0x40; -constexpr std::uint16_t L1D_CACHE_REFILL_RD = 0x42; -constexpr std::uint16_t L1D_TLB_REFILL_RD = 0x4C; -constexpr std::uint16_t L1D_TLB_RD = 0x4E; -constexpr std::uint16_t L2D_CACHE_RD = 0x50; -constexpr std::uint16_t L2D_CACHE_REFILL_RD = 0x52; -constexpr std::uint16_t BUS_ACCESS_RD = 0x60; -constexpr std::uint16_t MEM_ACCESS_RD = 0x66; -constexpr std::uint16_t L3D_CACHE_RD = 0xA0; -constexpr std::uint16_t L3D_CACHE_REFILL_RD = 0xA2; - -#pragma GCC diagnostic pop - -}; // namespace arm_pmuv3 - -class PmuEventsPrivate { - public: - PmuEventsPrivate() - : l1d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_CACHE_REFILL), - l2d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_CACHE_REFILL), - l3d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L3D_CACHE_REFILL), - ll_cache_miss(PERF_TYPE_RAW, arm_pmuv3::LL_CACHE_MISS), - l1d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_TLB_REFILL), - l2d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_TLB_REFILL), - stall_frontend(PERF_TYPE_RAW, arm_pmuv3::STALL_FRONTEND), - stall_backend(PERF_TYPE_RAW, arm_pmuv3::STALL_BACKEND), - br_mis_pred(PERF_TYPE_RAW, arm_pmuv3::BR_MIS_PRED) {} - - private: - friend class PmuEvents; - PerfEvent l1d_cache_refill; - PerfEvent l2d_cache_refill; - PerfEvent l3d_cache_refill; - PerfEvent ll_cache_miss; - PerfEvent l1d_tlb_refill; - PerfEvent l2d_tlb_refill; - PerfEvent stall_frontend; - PerfEvent stall_backend; - PerfEvent br_mis_pred; -}; - -PmuEvents::PmuEvents() : priv(new PmuEventsPrivate) {} -PmuEvents::~PmuEvents() { delete priv; } - -void PmuEvents::StartRecording() { - priv->l1d_cache_refill.Start(); - priv->l2d_cache_refill.Start(); - priv->l3d_cache_refill.Start(); - priv->ll_cache_miss.Start(); - priv->l1d_tlb_refill.Start(); - priv->l2d_tlb_refill.Start(); - priv->stall_frontend.Start(); - priv->stall_backend.Start(); - priv->br_mis_pred.Start(); -} - -void PmuEvents::StopRecording() { - priv->l1d_cache_refill.Stop(); - priv->l2d_cache_refill.Stop(); - priv->l3d_cache_refill.Stop(); - priv->ll_cache_miss.Stop(); - priv->l1d_tlb_refill.Stop(); - priv->l2d_tlb_refill.Stop(); - priv->stall_frontend.Stop(); - priv->stall_backend.Stop(); - priv->br_mis_pred.Stop(); -} - -float PmuEvents::BranchMispredictionCount() const { - return static_cast(priv->br_mis_pred.Count()); -} - -float PmuEvents::FrontendStallCount() const { - return static_cast(priv->stall_frontend.Count()); -} - -float PmuEvents::BackendStallCount() const { - return static_cast(priv->stall_backend.Count()); -} - -float PmuEvents::L1RefillCount() const { - return static_cast(priv->l1d_cache_refill.Count()); -} - -float PmuEvents::L2RefillCount() const { - return static_cast(priv->l2d_cache_refill.Count()); -} - -float PmuEvents::L3RefillCount() const { - // Important: this was discovered in the context of the above experiments, - // which also tested the _RD variants of these counters. So it's possible that - // it's just not needed here with the default (non _RD) counters. - // - // Some CPUs implement LL_CACHE_MISS[_RD], some implement - // L3D_CACHE_REFILL[_RD]. It seems that either one of these two counters is - // zero, or they roughly both agree with each other. Therefore, taking the max - // of them is a reasonable way to get something more portable across various - // CPUs. 
- return static_cast( - std::max(priv->l3d_cache_refill.Count(), priv->ll_cache_miss.Count())); -} - -float PmuEvents::L1TLBRefillCount() const { - return static_cast(priv->l1d_tlb_refill.Count()); -} - -float PmuEvents::L2TLBRefillCount() const { - return static_cast(priv->l2d_tlb_refill.Count()); -} - -} // namespace ruy diff --git a/pmu.h b/pmu.h deleted file mode 100644 index 03f0cb7..0000000 --- a/pmu.h +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PMU_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PMU_H_ - -namespace ruy { - -class PmuEventsPrivate; - -class PmuEvents { - public: - PmuEvents(); - ~PmuEvents(); - void StartRecording(); - void StopRecording(); - float L1RefillCount() const; - float L2RefillCount() const; - float L3RefillCount() const; - float BranchMispredictionCount() const; - float FrontendStallCount() const; - float BackendStallCount() const; - float L1TLBRefillCount() const; - float L2TLBRefillCount() const; - - private: - PmuEventsPrivate* priv = nullptr; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PMU_H_ diff --git a/prepack.h b/prepack.h deleted file mode 100644 index 138410d..0000000 --- a/prepack.h +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Implementation of low-level pre-packing API. 
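Before moving on to the pre-packing API below: a usage sketch of the `PmuEvents` interface declared in `pmu.h` above. This is not part of ruy; `Workload()` is a placeholder, and the counts are only meaningful on Linux with an ARM PMU (elsewhere the placeholder `PerfEvent` implementation reports zero).

```c++
#include <cstdio>

#include "pmu.h"  // Include path as in this pre-move revision.

void Workload() {
  volatile double acc = 0;
  for (int i = 0; i < 1000000; i++) acc += i;
}

int main() {
  ruy::PmuEvents pmu_events;
  pmu_events.StartRecording();
  Workload();
  pmu_events.StopRecording();
  // Each accessor returns the number of events counted between Start/Stop.
  std::printf("L1D cache refills:   %.0f\n", pmu_events.L1RefillCount());
  std::printf("L2D cache refills:   %.0f\n", pmu_events.L2RefillCount());
  std::printf("L3/LL cache refills: %.0f\n", pmu_events.L3RefillCount());
  std::printf("L1D TLB refills:     %.0f\n", pmu_events.L1TLBRefillCount());
  std::printf("branch mispredicts:  %.0f\n", pmu_events.BranchMispredictionCount());
  std::printf("front-end stalls:    %.0f\n", pmu_events.FrontendStallCount());
  std::printf("back-end stalls:     %.0f\n", pmu_events.BackendStallCount());
  return 0;
}
```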
- -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACK_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACK_H_ - -#include -#include - -#include "check_macros.h" -#include "context.h" -#include "dispatch.h" -#include "internal_matrix.h" -#include "matrix.h" -#include "path.h" -#include "profiler/instrumentation.h" -#include "side_pair.h" -#include "spec.h" -#include "trmul.h" -#include "trmul_params.h" -#include "tune.h" - -namespace ruy { - -template -void PrePackForMulInternal(const Matrix& lhs, - const Matrix& rhs, const Spec& spec, - Context* context, Matrix* dst, - SidePair prepacked, - std::function alloc_fn) { - profiler::ScopeLabel label("PrePackForMul"); - Path the_path = context->GetPathToTake(); - RUY_CHECK_NE(the_path, Path::kReference); - constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; - Matrix transposed_lhs(lhs); - Transpose(&transposed_lhs); - TrMulParams params; - CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, - the_path, ¶ms); - - const SidePair origin{0, 0}; - const SidePair rounded_dims{params.packed[Side::kLhs].layout.cols, - params.packed[Side::kRhs].layout.cols}; - - Tuning tuning = context->GetMainThreadTuning(); - for (Side side : {Side::kLhs, Side::kRhs}) { - if (prepacked[side]) { - prepacked[side]->data_size = DataSize(params.packed[side]); - prepacked[side]->sums_size = SumsSize(params.packed[side]); - prepacked[side]->data = alloc_fn(prepacked[side]->data_size); - prepacked[side]->sums = alloc_fn(prepacked[side]->sums_size); - params.packed[side].data = prepacked[side]->data; - params.packed[side].sums = prepacked[side]->sums; - params.RunPack(side, tuning, origin[side], rounded_dims[side]); - } - } -} - -template -void MulWithPrepackedInternal(const Matrix& lhs, - const Matrix& rhs, const Spec& spec, - Context* context, Matrix* dst, - SidePair prepacked) { - profiler::ScopeLabel label("MulWithPrepacked"); - - EnforceLayoutSupport(lhs.layout, rhs.layout, dst->layout); - EnforceZeroPointSupport(lhs.zero_point, rhs.zero_point, - dst->zero_point); - - Path the_path = context->GetPathToTake(); - RUY_CHECK_NE(the_path, Path::kReference); - constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; - Matrix transposed_lhs(lhs); - Transpose(&transposed_lhs); - TrMulParams params; - CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, - the_path, ¶ms); - - for (Side side : {Side::kLhs, Side::kRhs}) { - if (prepacked[side]) { - params.packed[side].data = prepacked[side]->data; - params.packed[side].sums = prepacked[side]->sums; - params.is_prepacked[side] = true; - } - } - - TrMul(¶ms, context); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACK_H_ diff --git a/prepacked_cache.cc b/prepacked_cache.cc deleted file mode 100644 index 776ef47..0000000 --- a/prepacked_cache.cc +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
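The `alloc_fn` callback in the pre-packing entry points above leaves ownership of the prepacked buffers to the caller, and those buffers presumably must stay alive for any later multiplication that reuses them. A rough sketch of such a callback follows; it is not from the ruy sources, and the size parameter type (`std::size_t`) is an assumption since the callback's exact signature is elided in this hunk.

```c++
#include <cstddef>
#include <memory>
#include <vector>

// Owns the prepacked buffers; must outlive any multiplication that uses them.
std::vector<std::unique_ptr<char[]>> prepacked_storage;

void* AllocForPrepack(std::size_t num_bytes) {
  prepacked_storage.emplace_back(new char[num_bytes]);
  return prepacked_storage.back().get();
}
```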
-==============================================================================*/ - -#include "prepacked_cache.h" - -#include "matrix.h" -#include "profiler/instrumentation.h" - -namespace ruy { - -using CacheIterator = PrepackedCache::CacheIterator; - -// Looks for an entry with `key`. If found, update its time stamp. -CacheIterator PrepackedCache::FindAndUpdate(const CacheKey &key) { - auto itr = cache_.find(key); - // If found, update with new access time for this entry. - if (itr != cache_.end()) { - const TimePoint time = CacheNow(); - itr->second.second = time; - } - return itr; -} - -void PrepackedCache::Insert(const CacheKey &key, - const PrepackedMatrix &matrix) { - // Calculate size of this new item. - const size_t size_bytes = matrix.data_size + matrix.sums_size; - - // While we are above the threshold of ejection, eject the LRU entry. - while (!cache_.empty() && - ((TotalSize() + size_bytes) > ejection_threshold_)) { - EjectOne(); - } - DoInsert(key, matrix); - cache_size_ += matrix.data_size + matrix.sums_size; -} - -void PrepackedCache::EjectOne() { - TimePoint oldest_time = CacheNow(); - auto oldest = cache_.begin(); - { - profiler::ScopeLabel label("PepackedCacheEjection"); - for (auto itr = cache_.begin(); itr != cache_.end(); ++itr) { - if (itr->second.second < oldest_time) { - oldest_time = itr->second.second; - oldest = itr; - } - } - } - PrepackedMatrix &pmatrix = oldest->second.first; - cache_size_ -= pmatrix.data_size; - cache_size_ -= pmatrix.sums_size; - allocator_.Free(pmatrix.data); - allocator_.Free(pmatrix.sums); - cache_.erase(oldest); -} - -void PrepackedCache::AllocatePrepackedMatrix(PrepackedMatrix *pmatrix) { - pmatrix->data = allocator_.Alloc(pmatrix->data_size); - pmatrix->sums = allocator_.Alloc(pmatrix->sums_size); -} - -void PrepackedCache::DoInsert(const CacheKey &key, - const PrepackedMatrix &matrix) { - const TimePoint t = CacheNow(); - const MatrixWithTimeStamp mts({matrix, t}); - cache_.insert({key, mts}); -} - -} // namespace ruy diff --git a/prepacked_cache.h b/prepacked_cache.h deleted file mode 100644 index a47647a..0000000 --- a/prepacked_cache.h +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACKED_CACHE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACKED_CACHE_H_ - -#include -#include -#include -#include -#include - -#include "allocator.h" -#include "matrix.h" -#include "time.h" - -namespace ruy { - -namespace detail { - -// Tracks a set of blocks allocated from the underlying system allocator. 
-class SystemBlockAllocator { - public: - void *Alloc(std::ptrdiff_t num_bytes) { - void *p = detail::SystemAlignedAlloc(num_bytes); - blocks_.push_back(p); - return p; - } - - void Free(void *block) { - for (auto it = blocks_.begin(); it != blocks_.end(); ++it) { - if (*it == block) { - detail::SystemAlignedFree(block); - blocks_.erase(it); - return; - } - } - RUY_DCHECK(false); // Trying to free pointer we did not allocate. - } - - ~SystemBlockAllocator() { - for (void *block : blocks_) { - detail::SystemAlignedFree(block); - } - } - - private: - std::vector blocks_; -}; - -} // namespace detail - -enum CachePolicy { kNoCache, kCacheLHSOnNarrowMul }; - -// "Low effort" Least Recently Used Cache for Prepacked Matrices -// A cache mechanism for prepacked matrices that ejects oldest entries. -// The implementation is "low effort" in the following ways: -// - we just linearly search for the oldest entry when doing an ejection -// - the ejection policy is very simple: if the new size would be above the -// . threshold, we will eject entries until the size is below the threshold. -// Current use cases (RNNs with GEMV operations) indicate that ejection is rare -// and memory constraints are tight, so we devote no additional storage to the -// LRU mechanism and accept O(n) search to eject oldest entry. In practice, -// the number of total entries has not been shown to be large. -// This class is not thread safe. In Ruy, memory allocation for packed matrices -// is done in a single threaded context and the actual packing activity may -// be done in a multi-threaded context. -class PrepackedCache { - public: - static constexpr int kDefaultEjectionThresholdBytes = 1 << 28; - - using CacheKey = std::pair; - - using MatrixWithTimeStamp = std::pair; - - using CacheIterator = std::map::const_iterator; - - using AlignedAllocator = detail::AlignedAllocator; - - explicit PrepackedCache( - int32_t ejection_threshold = kDefaultEjectionThresholdBytes) - : ejection_threshold_(ejection_threshold), cache_size_(0) {} - - // Looks for an entry with `key`. If found, update its time stamp. - CacheIterator FindAndUpdate(const CacheKey &key); - - // Returns end iterator for internal cache. The iterator type is appropriate - // to use with `FindAndUpdate`. - CacheIterator cend() const { return cache_.end(); } - - // Returns the total size (in bytes) of data held in this cache. - int TotalSize() const { return cache_size_; } - - // All calls to get current TimePoints go through here. - // TODO(b/145625614) Profile timestamps on relevant models to see if - // this level of granularity is sufficient. CoarseNow is cheap so - // it would be nice to keep it. - TimePoint CacheNow() const { return CoarseNow(); } - - // Performs the memory allocation for the `data` and `sums` members of a - // PrepackedMatrix. - void AllocatePrepackedMatrix(PrepackedMatrix *pmatrix); - - // Adds the PrepackedMatrix to the cache, possibly ejecting other values. 
- void Insert(const CacheKey &key, const PrepackedMatrix &matrix); - - private: - void EjectOne(); - void DoInsert(const CacheKey &key, const PrepackedMatrix &matrix); - detail::SystemBlockAllocator allocator_; - std::map cache_; - const int32_t ejection_threshold_; - size_t cache_size_; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PREPACKED_CACHE_H_ diff --git a/prepacked_cache_test.cc b/prepacked_cache_test.cc deleted file mode 100644 index fbf4f5a..0000000 --- a/prepacked_cache_test.cc +++ /dev/null @@ -1,210 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "prepacked_cache.h" - -#include // NOLINT(build/c++11) - -#include "testing/base/public/gunit.h" -#include "ruy.h" -#include "time.h" - -namespace ruy { -namespace { - -TEST(PrepackedCacheTest, TestCacheEjection) { - // Create the cache. - PrepackedCache prepacked_cache(32); - // Allocate the prepacked matrix. - PrepackedMatrix mat1; - mat1.data_size = 16; - mat1.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat1); - auto cache_key1 = std::make_pair(nullptr, mat1.data); - prepacked_cache.Insert(cache_key1, mat1); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - // Get a time point after the insertion into the cache. - TimePoint current = CoarseNow(); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - PrepackedCache::CacheIterator itr = prepacked_cache.FindAndUpdate(cache_key1); - EXPECT_NE(itr, prepacked_cache.cend()); - // By finding mat1, we updated its timestamp. Verify that `current` is older - // than the time stamp now associated with mat1. - EXPECT_LT(current, itr->second.second); - PrepackedMatrix mat2; - mat2.data_size = 8; - mat2.sums_size = 4; - prepacked_cache.AllocatePrepackedMatrix(&mat2); - - auto cache_key2 = std::make_pair(nullptr, mat2.data); - prepacked_cache.Insert(cache_key2, mat2); - // The cache size was exceeded by inserting mat2. Ensure that mat1 was - // ejected. - EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); -} - -TEST(PrepackedCacheTest, TestCacheBasic) { - // Create the cache. - PrepackedCache prepacked_cache(48); - // Allocate the prepacked matrix. - PrepackedMatrix mat1; - mat1.data_size = 16; - mat1.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat1); - - auto cache_key1 = std::make_pair(nullptr, mat1.data); - prepacked_cache.Insert(cache_key1, mat1); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); - - PrepackedMatrix mat2; - mat2.data_size = 8; - mat2.sums_size = 4; - prepacked_cache.AllocatePrepackedMatrix(&mat2); - - auto cache_key2 = std::make_pair(nullptr, mat2.data); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - prepacked_cache.Insert(cache_key2, mat2); - // The cache size was not exceeded by inserting mat2. 
Ensure that mat1 was not - // ejected. - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); -} - -TEST(PrepackedCacheTest, TestCacheEjection2) { - // Create the cache. - PrepackedCache prepacked_cache(73); - // Allocate the prepacked matrix 1. - PrepackedMatrix mat1; - mat1.data_size = 16; - mat1.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat1); - auto cache_key1 = std::make_pair(nullptr, mat1.data); - prepacked_cache.Insert(cache_key1, mat1); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - // Allocate the prepacked matrix 2. - PrepackedMatrix mat2; - mat2.data_size = 16; - mat2.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat2); - auto cache_key2 = std::make_pair(nullptr, mat2.data); - prepacked_cache.Insert(cache_key2, mat2); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - // Allocate the prepacked matrix 3. - PrepackedMatrix mat31; - mat31.data_size = 16; - mat31.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat31); - auto cache_key3 = std::make_pair(nullptr, mat31.data); - prepacked_cache.Insert(cache_key3, mat31); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - // The next insertion will cause the cache size to go over the ejection - // threshold. Touch matrix 1 and matrix 3 to make matrix 2 the oldest - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend()); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - // Allocate the prepacked matrix 4. - PrepackedMatrix mat4; - mat4.data_size = 16; - mat4.sums_size = 8; - prepacked_cache.AllocatePrepackedMatrix(&mat4); - auto cache_key4 = std::make_pair(nullptr, mat4.data); - prepacked_cache.Insert(cache_key4, mat4); - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - - // Ensure that mat2 was ejected, but mat1, mat3, and mat4 were not. - EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key2), prepacked_cache.cend()); - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend()); - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); - EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key4), prepacked_cache.cend()); -} - -TEST(PrepackedCacheTest, TestCacheOnCacheable) { - // Create context and set the cache policy - ruy::Context context; - context.cache_policy = ruy::kCacheLHSOnNarrowMul; - PrepackedCache* cache = context.GetPrepackedCache(); - EXPECT_EQ(cache->TotalSize(), 0); - - const float lhs_data[] = {1, 2, 3, 4}; - const float rhs_data[] = {1, 2}; - float dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - - ruy::BasicSpec spec; - // Perform the multiplication and confirm no caching occurred. - ruy::Mul(lhs, rhs, spec, &context, &dst); - EXPECT_EQ(cache->TotalSize(), 0); - - // Set cacheable for the LHS, repeat the multiplication, and see - // that caching did occur. 
- lhs.cacheable = true; - ruy::Mul(lhs, rhs, spec, &context, &dst); - EXPECT_NE(cache->TotalSize(), 0); -} - -TEST(PrepackedCacheTest, TestClearCache) { - // Create context and set the cache policy - ruy::Context context; - context.cache_policy = ruy::kCacheLHSOnNarrowMul; - PrepackedCache* cache = context.GetPrepackedCache(); - EXPECT_EQ(cache->TotalSize(), 0); - - const float lhs_data[] = {1, 2, 3, 4}; - const float rhs_data[] = {1, 2}; - float dst_data[4]; - - ruy::Matrix lhs; - ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); - lhs.data = lhs_data; - ruy::Matrix rhs; - ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout); - rhs.data = rhs_data; - ruy::Matrix dst; - ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout); - dst.data = dst_data; - - ruy::BasicSpec spec; - // Set cacheable for the LHS and see that caching occurs. - lhs.cacheable = true; - ruy::Mul(lhs, rhs, spec, &context, &dst); - EXPECT_NE(cache->TotalSize(), 0); - - // Clear the cache via the Context. - context.ClearPrepackedCache(); - // Verify that the cache is now empty. - cache = context.GetPrepackedCache(); - EXPECT_EQ(cache->TotalSize(), 0); -} - -} // namespace -} // namespace ruy - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/profiler/BUILD b/profiler/BUILD deleted file mode 100644 index b0af802..0000000 --- a/profiler/BUILD +++ /dev/null @@ -1,52 +0,0 @@ -# A minimalistic profiler sampling pseudo-stacks - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], # Apache 2.0 -) - -config_setting( - name = "ruy_profiler", - define_values = {"ruy_profiler": "true"}, -) - -cc_library( - name = "instrumentation", - srcs = ["instrumentation.cc"], - hdrs = ["instrumentation.h"], - defines = select({ - ":ruy_profiler": ["RUY_PROFILER"], - "//conditions:default": [], - }), -) - -cc_library( - name = "profiler", - srcs = [ - "profiler.cc", - "treeview.cc", - ], - hdrs = [ - "profiler.h", - "treeview.h", - ], - deps = [":instrumentation"], -) - -cc_library( - name = "test_instrumented_library", - testonly = True, - srcs = ["test_instrumented_library.cc"], - hdrs = ["test_instrumented_library.h"], - deps = [":instrumentation"], -) - -cc_test( - name = "test", - srcs = ["test.cc"], - deps = [ - ":profiler", - ":test_instrumented_library", - "@com_google_googletest//:gtest", - ], -) diff --git a/profiler/README.md b/profiler/README.md deleted file mode 100644 index 8d79025..0000000 --- a/profiler/README.md +++ /dev/null @@ -1,149 +0,0 @@ -# A minimalistic profiler sampling pseudo-stacks - -## Overview - -The present directory is the "ruy profiler". As a time profiler, it allows to -measure where code is spending time. - -Contrary to most typical profilers, what it samples is not real call stacks, but -"pseudo-stacks" which are just simple data structures constructed from within -the program being profiled. Using this profiler requires manually instrumenting -code to construct such pseudo-stack information. - -Another unusual characteristic of this profiler is that it uses only the C++11 -standard library. It does not use any non-portable feature, in particular it -does not rely on signal handlers. The sampling is performed by a thread, the -"profiler thread". - -A discussion of pros/cons of this approach is appended below. - -## How to use this profiler - -### How to instrument code - -An example of instrumented code is given in `test_instrumented_library.cc`. 
-
-Code is instrumented by constructing `ScopeLabel` objects. These are RAII
-helpers, ensuring that the thread pseudo-stack contains the label during their
-lifetime. In the most common use case, one constructs such an object at the
-start of a function, so that its scope is the function scope and it measures
-how much time is spent in that function.
-
-```c++
-#include "ruy/profiler/instrumentation.h"
-
-...
-
-void SomeFunction() {
-  ruy::profiler::ScopeLabel function_label("SomeFunction");
-  ... do something ...
-}
-```
-
-A `ScopeLabel` may, however, have any scope, for instance:
-
-```c++
-if (some_case) {
-  ruy::profiler::ScopeLabel extra_work_label("Some more work");
-  ... do some more work ...
-}
-```
-
-The string passed to the `ScopeLabel` constructor must be just a pointer to a
-literal string (a `char*` pointer). The profiler will assume that these pointers
-stay valid until the profile is finalized.
-
-However, that literal string may be a `printf` format string, and labels may
-have up to 4 parameters, of type `int`. For example:
-
-```c++
-void SomeFunction(int size) {
-  ruy::profiler::ScopeLabel function_label("SomeFunction (size=%d)", size);
-  ... do something ...
-}
-```
-
-### How to run the profiler
-
-Profiling instrumentation is a no-op unless the preprocessor token
-`RUY_PROFILER` is defined, so defining it is the first step when actually
-profiling. When building with Bazel, the preferred way to enable that is to pass
-this flag on the Bazel command line:
-
-```
---define=ruy_profiler=true
-```
-
-To actually profile a code scope, it is enough to construct a `ScopeProfile`
-object, also a RAII helper. It will start the profiler on construction, and on
-destruction it will terminate the profiler and report the profile treeview on
-standard output by default. Example:
-
-```c++
-void SomeProfiledBenchmark() {
-  ruy::profiler::ScopeProfile profile;
-
-  CallSomeInstrumentedCode();
-}
-```
-
-An example is provided by the `:test` target in the present directory. Run it
-with `--define=ruy_profiler=true` as explained above:
-
-```
-bazel run -c opt \
-  --define=ruy_profiler=true \
-  //tensorflow/lite/experimental/ruy/profiler:test
-```
-
-The default behavior of dumping the treeview on standard output may be
-overridden by giving the `ScopeProfile` object a pointer to a `TreeView` via its
-`SetUserTreeView()` method. This causes the treeview to be stored in that
-`TreeView` object, where it may be accessed and manipulated using the functions
-declared in `treeview.h`. The aforementioned `:test` provides examples for doing
-so.
-
-## Advantages and disadvantages
-
-Compared to a traditional profiler, e.g. Linux's "perf", the present kind of
-profiler has the following disadvantages:
-
-* Requires manual instrumentation of the code being profiled.
-* Substantial overhead, modifying the performance characteristics of the code
-  being measured.
-* Questionable accuracy.
-
-But also the following advantages:
-
-* Profiling can be driven from within a benchmark program, allowing the entire
-  profiling procedure to be a single command line.
-* Not relying on symbol information removes exposure to toolchain
-  details and means less hassle in some build environments, especially
-  embedded/mobile (single command line to run and profile, no symbol files
-  required).
-* Fully portable (all of this is standard C++11).
-* Fully testable (see `:test`). Profiling becomes just another feature of the
-  code like any other.
-* Customized instrumentation can result in easier-to-read treeviews (only
-  relevant functions, and custom labels may be more readable than function
-  names).
-* Parametrized/formatted labels make it possible to do things that aren't
-  possible with call-stack-sampling profilers. For example, a profile where
-  much time is being spent in matrix multiplications can be broken down by the
-  various matrix multiplication shapes involved.
-
-The philosophy underlying this profiler is that software performance depends on
-software engineers profiling often, and that a key factor limiting that in
-practice is the difficulty and cumbersomeness of profiling with more serious
-profilers such as Linux's "perf", especially in embedded/mobile development:
-multiple command lines are involved to copy symbol files to devices, retrieve
-profile data from the device, etc. In that context, it is useful to make
-profiling as easy as benchmarking, even on embedded targets, even if the price
-to pay for that is lower accuracy, higher overhead, and some intrusive
-instrumentation requirement.
-
-Another key aspect determining which profiling approach is suitable for a given
-context is whether one already has a priori knowledge of where much of the time
-is likely being spent. When one has such a priori knowledge, it is feasible to
-instrument the known possibly-critical code as per the present approach. On the
-other hand, in situations where one doesn't have such a priori knowledge, a real
-profiler such as Linux's "perf" makes it possible to get a profile of real call
-stacks right away, from just the symbol information generated by the toolchain.
diff --git a/profiler/instrumentation.cc b/profiler/instrumentation.cc
deleted file mode 100644
index 3ceefb3..0000000
--- a/profiler/instrumentation.cc
+++ /dev/null
@@ -1,130 +0,0 @@
-/* Copyright 2020 Google LLC. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
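Complementing the profiler README above, here is a hedged sketch of capturing the profile into a `TreeView` for programmatic inspection rather than the default printing, using `SetUserTreeView()` from `profiler.h` and `Print()` from `treeview.h` in this same change. It assumes building with `--define=ruy_profiler=true`; the instrumented workload is a placeholder, not part of ruy.

```c++
#include "profiler/profiler.h"
#include "profiler/treeview.h"

void SomeInstrumentedCode() {
  ruy::profiler::ScopeLabel label("SomeInstrumentedCode");
  volatile double acc = 0;
  for (int i = 0; i < 100000000; i++) acc += i;  // Placeholder work to sample.
}

void ProfileIntoTreeView() {
  ruy::profiler::TreeView treeview;
  {
    ruy::profiler::ScopeProfile profile;
    profile.SetUserTreeView(&treeview);
    SomeInstrumentedCode();
  }  // Profiling stops here; the result goes into `treeview`, not stdout.
  // The treeview can now be queried (see treeview.h) or printed explicitly:
  ruy::profiler::Print(treeview);
}
```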
-==============================================================================*/ - -#include "profiler/instrumentation.h" - -#ifdef RUY_PROFILER - -namespace ruy { -namespace profiler { - -void Label::operator=(const Label& other) { - format_ = other.format_; - args_count_ = other.args_count_; - for (int i = 0; i < args_count_; i++) { - args_[i] = other.args_[i]; - } -} - -bool Label::operator==(const Label& other) const { - if (std::string(format_) != std::string(other.format_)) { - return false; - } - if (args_count_ != other.args_count_) { - return false; - } - for (int i = 0; i < args_count_; i++) { - if (args_[i] != other.args_[i]) { - return false; - } - } - return true; -} - -std::string Label::Formatted() const { - static constexpr int kBufSize = 256; - char buf[kBufSize]; - if (args_count_ == 0) { - return format_; - } - if (args_count_ == 1) { - snprintf(buf, kBufSize, format_, args_[0]); - } else if (args_count_ == 2) { - snprintf(buf, kBufSize, format_, args_[0], args_[1]); - } else if (args_count_ == 3) { - snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2]); - } else if (args_count_ == 4) { - snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2], args_[3]); - } else { - abort(); - } - return buf; -} - -namespace detail { - -std::mutex* GlobalsMutex() { - static std::mutex mutex; - return &mutex; -} - -bool& GlobalIsProfilerRunning() { - static bool b; - return b; -} - -std::vector* GlobalAllThreadStacks() { - static std::vector all_stacks; - return &all_stacks; -} - -ThreadStack* ThreadLocalThreadStack() { - thread_local static ThreadStack thread_stack; - return &thread_stack; -} - -ThreadStack::ThreadStack() { - std::lock_guard lock(*GlobalsMutex()); - static std::uint32_t global_next_thread_stack_id = 0; - stack_.id = global_next_thread_stack_id++; - GlobalAllThreadStacks()->push_back(this); -} - -ThreadStack::~ThreadStack() { - std::lock_guard lock(*GlobalsMutex()); - std::vector* all_stacks = GlobalAllThreadStacks(); - for (auto it = all_stacks->begin(); it != all_stacks->end(); ++it) { - if (*it == this) { - all_stacks->erase(it); - return; - } - } -} -int GetBufferSize(const Stack& stack) { - return sizeof(stack.id) + sizeof(stack.size) + - stack.size * sizeof(stack.labels[0]); -} - -void CopyToBuffer(const Stack& stack, char* dst) { - memcpy(dst, &stack.id, sizeof(stack.id)); - dst += sizeof(stack.id); - memcpy(dst, &stack.size, sizeof(stack.size)); - dst += sizeof(stack.size); - memcpy(dst, stack.labels, stack.size * sizeof(stack.labels[0])); -} - -void ReadFromBuffer(const char* src, Stack* stack) { - memcpy(&stack->id, src, sizeof(stack->id)); - src += sizeof(stack->id); - memcpy(&stack->size, src, sizeof(stack->size)); - src += sizeof(stack->size); - memcpy(stack->labels, src, stack->size * sizeof(stack->labels[0])); -} - -} // namespace detail -} // namespace profiler -} // namespace ruy - -#endif diff --git a/profiler/instrumentation.h b/profiler/instrumentation.h deleted file mode 100644 index cb0e702..0000000 --- a/profiler/instrumentation.h +++ /dev/null @@ -1,203 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
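The `GetBufferSize` / `CopyToBuffer` / `ReadFromBuffer` helpers just above define the sample format used by the profiler thread: a sample is the stack `id`, then its `size`, then the raw `Label` array. A hedged round-trip sketch of that format (requires `RUY_PROFILER` to be defined; not part of the ruy tests):

```c++
#include <cassert>
#include <vector>

#include "profiler/instrumentation.h"

void RoundTripOneSample() {
  ruy::profiler::detail::Stack stack;
  stack.id = 123;
  stack.size = 2;
  stack.labels[0].Set("SomeFunction");
  stack.labels[1].Set("SomeFunction (size=%d)", 64);

  // Serialize, as ScopeProfile::Sample does for each observed stack.
  std::vector<char> buf(ruy::profiler::detail::GetBufferSize(stack));
  ruy::profiler::detail::CopyToBuffer(stack, buf.data());

  // Deserialize, as TreeView::Populate does when building the treeview.
  ruy::profiler::detail::Stack read_back;
  ruy::profiler::detail::ReadFromBuffer(buf.data(), &read_back);
  assert(read_back.id == 123u);
  assert(read_back.size == 2);
  assert(read_back.labels[1].Formatted() == "SomeFunction (size=64)");
}
```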
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_INSTRUMENTATION_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_INSTRUMENTATION_H_ - -#ifdef RUY_PROFILER -#include -#include -#include -#endif - -namespace ruy { -namespace profiler { - -#ifdef RUY_PROFILER - -// A label is how a code scope is annotated to appear in profiles. -// The stacks that are sampled by the profiler are stacks of such labels. -// A label consists of a literal string, plus optional integer arguments. -class Label { - public: - Label() {} - template - explicit Label(Args... args) { - Set(args...); - } - void Set(const char* format) { - format_ = format; - args_count_ = 0; - } - template - void Set(const char* format, Args... args) { - format_ = format; - args_count_ = sizeof...(args); - SetArgs(0, args...); - } - - void operator=(const Label& other); - - bool operator==(const Label& other) const; - - std::string Formatted() const; - const char* format() const { return format_; } - - private: - void SetArgs(int position, int arg0) { args_[position] = arg0; } - - template - void SetArgs(int position, int arg0, Args... args) { - SetArgs(position, arg0); - SetArgs(position + 1, args...); - } - - static constexpr int kMaxArgs = 4; - const char* format_ = nullptr; - int args_count_ = 0; - int args_[kMaxArgs]; -}; - -namespace detail { - -// Forward-declaration, see class ThreadStack below. -class ThreadStack; - -bool& GlobalIsProfilerRunning(); - -// Returns the global vector of pointers to all stacks, there being one stack -// per thread executing instrumented code. -std::vector* GlobalAllThreadStacks(); - -// Returns the mutex to be locked around any access to GlobalAllThreadStacks(). -std::mutex* GlobalsMutex(); - -// Returns the thread-local stack, specific to the current thread. -ThreadStack* ThreadLocalThreadStack(); - -// This 'stack' is what may be more appropriately called a 'pseudostack': -// It contains Label entries that are 'manually' entered by instrumentation -// code. It's unrelated to real call stacks. -struct Stack { - std::uint32_t id = 0; - static constexpr int kMaxSize = 64; - int size = 0; - Label labels[kMaxSize]; -}; - -// Returns the buffer byte size required by CopyToSample. -int GetBufferSize(const Stack& stack); - -// Copies this Stack into a byte buffer, called a 'sample'. -void CopyToBuffer(const Stack& stack, char* dst); - -// Populates this Stack from an existing sample buffer, typically -// produced by CopyToSample. -void ReadFromBuffer(const char* src, Stack* stack); - -// ThreadStack is meant to be used as a thread-local singleton, assigning to -// each thread a Stack object holding its pseudo-stack of profile labels, -// plus a mutex allowing to synchronize accesses to this pseudo-stack between -// this thread and a possible profiler thread sampling it. -class ThreadStack { - public: - ThreadStack(); - ~ThreadStack(); - - const Stack& stack() const { return stack_; } - - // Returns the mutex to lock around any access to this stack. 
Each stack is - // accessed by potentially two threads: the thread that it belongs to - // (which calls Push and Pop) and the profiler thread during profiling - // (which calls CopyToSample). - std::mutex& Mutex() const { return mutex_; } - - // Pushes a new label on the top of this Stack. - template - void Push(Args... args) { - // This mutex locking is needed to guard against race conditions as both - // the current thread and the profiler thread may be concurrently accessing - // this stack. In addition to that, this mutex locking also serves the other - // purpose of acting as a barrier (of compiler code reordering, of runtime - // CPU instruction reordering, and of memory access reordering), which - // gives a measure of correctness to this profiler. The downside is some - // latency. As this lock will be uncontended most of the times, the cost - // should be roughly that of an sequentially-consistent atomic access, - // comparable to an access to the level of CPU data cache that is shared - // among all cores, typically 60 cycles on current ARM CPUs, plus side - // effects from barrier instructions. - std::lock_guard lock(mutex_); - // Avoid overrunning the stack, even in 'release' builds. This profiling - // instrumentation code should not ship in release builds anyway, the - // overhead of this check is negligible, and overrunning a stack array would - // be bad. - if (stack_.size >= Stack::kMaxSize) { - abort(); - } - stack_.labels[stack_.size++].Set(args...); - } - - // Pops the top-most label from this Stack. - void Pop() { - // See the comment in Push about this lock. While it would be tempting to - // try to remove this lock and just atomically decrement size_ with a - // store-release, that would not necessarily be a substitute for all of the - // purposes that this lock serves, or if it was done carefully to serve all - // of the same purposes, then that wouldn't be faster than this (mostly - // uncontended) lock. - std::lock_guard lock(mutex_); - stack_.size--; - } - - private: - mutable std::mutex mutex_; - Stack stack_; -}; - -} // namespace detail - -// RAII user-facing way to construct Labels associated with their life scope -// and get them pushed to / popped from the current thread stack. -class ScopeLabel { - public: - template - ScopeLabel(Args... args) : thread_stack_(detail::ThreadLocalThreadStack()) { - thread_stack_->Push(args...); - } - - ~ScopeLabel() { thread_stack_->Pop(); } - - private: - detail::ThreadStack* thread_stack_; -}; - -#else // no RUY_PROFILER - -class ScopeLabel { - public: - template - explicit ScopeLabel(Args...) {} - - // This destructor is needed to consistently silence clang's -Wunused-variable - // which seems to trigger semi-randomly. - ~ScopeLabel() {} -}; - -#endif - -} // namespace profiler -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_INSTRUMENTATION_H_ diff --git a/profiler/profiler.cc b/profiler/profiler.cc deleted file mode 100644 index 8e527ba..0000000 --- a/profiler/profiler.cc +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "profiler/profiler.h" - -#ifdef RUY_PROFILER -#include -#include // NOLINT -#include -#include -#include // NOLINT -#include -#endif - -#include "profiler/instrumentation.h" -#include "profiler/treeview.h" - -namespace ruy { -namespace profiler { - -#ifdef RUY_PROFILER - -ScopeProfile::ScopeProfile() { Start(); } -ScopeProfile::ScopeProfile(bool enable) { - if (enable) { - Start(); - } -} -ScopeProfile::~ScopeProfile() { - if (!thread_) { - return; - } - finishing_.store(true); - thread_->join(); - Finish(); -} - -void ScopeProfile::Start() { - { - std::lock_guard lock(*detail::GlobalsMutex()); - if (detail::GlobalIsProfilerRunning()) { - fprintf(stderr, "FATAL: profiler already running!\n"); - abort(); - } - detail::GlobalIsProfilerRunning() = true; - } - finishing_ = false; - thread_.reset(new std::thread(&ScopeProfile::ThreadFunc, this)); -} - -void ScopeProfile::ThreadFunc() { - while (!finishing_.load()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - std::lock_guard lock(*detail::GlobalsMutex()); - auto* thread_stacks = detail::GlobalAllThreadStacks(); - for (detail::ThreadStack* thread_stack : *thread_stacks) { - Sample(*thread_stack); - } - } -} - -void ScopeProfile::Sample(const detail::ThreadStack& thread_stack) { - std::lock_guard lock(thread_stack.Mutex()); - // Drop empty stacks. - // This ensures that profiles aren't polluted by uninteresting threads. - if (thread_stack.stack().size == 0) { - return; - } - int sample_size = detail::GetBufferSize(thread_stack.stack()); - int old_buf_size = samples_buf_.size(); - samples_buf_.resize(old_buf_size + sample_size); - detail::CopyToBuffer(thread_stack.stack(), - samples_buf_.data() + old_buf_size); -} - -void ScopeProfile::Finish() { - { - std::lock_guard lock(*detail::GlobalsMutex()); - if (!detail::GlobalIsProfilerRunning()) { - fprintf(stderr, "FATAL: profiler is not running!\n"); - abort(); - } - detail::GlobalIsProfilerRunning() = false; - } - if (user_treeview_) { - user_treeview_->Populate(samples_buf_); - } else { - TreeView treeview; - treeview.Populate(samples_buf_); - Print(treeview); - } -} - -#endif // RUY_PROFILER - -} // namespace profiler -} // namespace ruy diff --git a/profiler/profiler.h b/profiler/profiler.h deleted file mode 100644 index caff2d5..0000000 --- a/profiler/profiler.h +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_PROFILER_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_PROFILER_H_ - -#include - -#ifdef RUY_PROFILER -#include -#include -#include -#include -#endif - -#include "profiler/instrumentation.h" -#include "profiler/treeview.h" - -namespace ruy { -namespace profiler { - -#ifdef RUY_PROFILER - -// RAII user-facing way to create a profiler and let it profile a code scope, -// and print out an ASCII/MarkDown treeview upon leaving the scope. -class ScopeProfile { - public: - // Default constructor, unconditionally profiling. - ScopeProfile(); - - // Constructor allowing to choose at runtime whether to profile. - explicit ScopeProfile(bool enable); - - // Destructor. It's where the profile is reported. - ~ScopeProfile(); - - // See treeview_ member. - void SetUserTreeView(TreeView* treeview) { user_treeview_ = treeview; } - - private: - void Start(); - - // Thread entry point function for the profiler thread. This thread is - // created on construction. - void ThreadFunc(); - - // Record a stack as a sample. - void Sample(const detail::ThreadStack& stack); - - // Finalize the profile. Called on destruction. - // If user_treeview_ is non-null, it will receive the treeview. - // Otherwise the treeview will just be printed. - void Finish(); - - // Buffer where samples are recorded during profiling. - std::vector samples_buf_; - - // Used to synchronize thread termination. - std::atomic finishing_; - - // Underlying profiler thread, which will perform the sampling. - // This profiler approach relies on a thread rather than on signals. - std::unique_ptr thread_; - - // TreeView to populate upon destruction. If left null (the default), - // a temporary treeview will be used and dumped on stdout. The user - // may override that by passing their own TreeView object for other - // output options or to directly inspect the TreeView. - TreeView* user_treeview_ = nullptr; -}; - -#else // no RUY_PROFILER - -struct ScopeProfile { - ScopeProfile() { -#ifdef GEMMLOWP_PROFILING - fprintf( - stderr, - "\n\n\n**********\n\nWARNING:\n\nLooks like you defined " - "GEMMLOWP_PROFILING, but this code has been ported to the new ruy " - "profiler replacing the old gemmlowp profiler. You should now be " - "defining RUY_PROFILER and not GEMMLOWP_PROFILING. When building using " - "Bazel, just pass --define=ruy_profiler=true.\n\n**********\n\n\n"); -#endif - } - explicit ScopeProfile(bool) {} -}; - -#endif - -} // namespace profiler -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_PROFILER_H_ diff --git a/profiler/test.cc b/profiler/test.cc deleted file mode 100644 index 6a8fbda..0000000 --- a/profiler/test.cc +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
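One thing the README and tests above do not demonstrate: the `explicit ScopeProfile(bool enable)` constructor declared in `profiler.h` lets the caller decide at runtime whether a given run is profiled. A minimal sketch (the flag plumbing and workload are hypothetical):

```c++
#include "profiler/profiler.h"

void RunOnce(bool profile_this_run) {
  // When profile_this_run is false, the ScopeProfile does not start the
  // profiler thread and the scope runs unprofiled.
  ruy::profiler::ScopeProfile profile(profile_this_run);
  // ... run the instrumented workload here ...
}
```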
-==============================================================================*/ - -#include -#include -#include - -#include "testing/base/public/gunit.h" -#include "profiler/profiler.h" -#include "profiler/test_instrumented_library.h" -#include "profiler/treeview.h" - -namespace ruy { -namespace profiler { -namespace { - -void DoSomeMergeSort(int size) { - std::vector data(size); - - std::default_random_engine engine; - for (auto& val : data) { - val = engine(); - } - - MergeSort(size, data.data()); -} - -// The purpose of this basic test is to cover the basic path that will be taken -// by a majority of users, not inspecting treeviews but just implicitly printing -// them on stdout, and to have this test enabled even when RUY_PROFILER is not -// defined, so that we have coverage for the non-RUY_PROFILER case. -TEST(ProfilerTest, MergeSortSingleThreadBasicTestEvenWithoutProfiler) { - { - ScopeProfile profile; - DoSomeMergeSort(1 << 20); - } -} - -#ifdef RUY_PROFILER - -TEST(ProfilerTest, MergeSortSingleThread) { - TreeView treeview; - { - ScopeProfile profile; - profile.SetUserTreeView(&treeview); - DoSomeMergeSort(1 << 20); - } - Print(treeview); - EXPECT_EQ(treeview.thread_roots().size(), 1); - const auto& thread_root = *treeview.thread_roots().begin()->second; - EXPECT_EQ(DepthOfTreeBelow(thread_root), 22); - EXPECT_GE( - WeightBelowNodeMatchingUnformatted(thread_root, "Merging sorted halves"), - 0.1 * thread_root.weight); - EXPECT_GE(WeightBelowNodeMatchingFormatted( - thread_root, "MergeSortRecurse (level=20, size=1)"), - 0.01 * thread_root.weight); - - TreeView treeview_collapsed; - CollapseNodesMatchingUnformatted(treeview, 5, "MergeSort (size=%d)", - &treeview_collapsed); - Print(treeview_collapsed); - const auto& collapsed_thread_root = - *treeview_collapsed.thread_roots().begin()->second; - EXPECT_EQ(DepthOfTreeBelow(collapsed_thread_root), 6); - EXPECT_EQ( - WeightBelowNodeMatchingUnformatted(thread_root, "MergeSort (size=%d)"), - WeightBelowNodeMatchingUnformatted(collapsed_thread_root, - "MergeSort (size=%d)")); -} - -TEST(ProfilerTest, MemcpyFourThreads) { - TreeView treeview; - { - ScopeProfile profile; - profile.SetUserTreeView(&treeview); - std::vector> threads; - for (int i = 0; i < 4; i++) { - threads.emplace_back(new std::thread([i]() { - ScopeLabel thread_label("worker thread #%d", i); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - ScopeLabel some_more_work_label("some more work"); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - })); - } - for (int i = 0; i < 4; i++) { - threads[i]->join(); - } - } - Print(treeview); - // Since we cleared GlobalAllThreadStacks and the current thread hasn't - // created any ScopeLabel, only the 4 worker threads should be recorded. - EXPECT_EQ(treeview.thread_roots().size(), 4); - for (const auto& thread_root : treeview.thread_roots()) { - const TreeView::Node& root_node = *thread_root.second; - // The root node may have 1 or 2 children depending on whether there is - // an "[other]" child. - EXPECT_GE(root_node.children.size(), 1); - EXPECT_LE(root_node.children.size(), 2); - const TreeView::Node& child_node = *root_node.children[0]; - EXPECT_EQ(child_node.label.format(), "worker thread #%d"); - // There must be 2 children, since roughly half the time will be in - // "some more work" leaving the other half in "[other]". 
- EXPECT_EQ(child_node.children.size(), 2); - const TreeView::Node& child_child_node = *child_node.children[0]; - // Since we sample every millisecond and the threads run for >= 2000 - // milliseconds, the "thread func" label should get roughly 2000 samples. - // Not very rigorous, as we're depending on the profiler thread getting - // scheduled, so to avoid this test being flaky, we use a much more - // conservative value of 500, one quarter of that normal value 2000. - EXPECT_GE(child_node.weight, 500); - // Likewise, allow up to four times more than the normal value 2000. - EXPECT_LE(child_node.weight, 8000); - // Roughly half of time should be spent under the "some more work" label. - float some_more_work_percentage = - 100.f * child_child_node.weight / child_node.weight; - EXPECT_GE(some_more_work_percentage, 40.0f); - EXPECT_LE(some_more_work_percentage, 60.0f); - } -} - -TEST(ProfilerTest, OneThreadAfterAnother) { - TreeView treeview; - { - ScopeProfile profile; - profile.SetUserTreeView(&treeview); - { - std::thread thread([]() { - ScopeLabel thread_label("thread 0"); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - }); - thread.join(); - } - { - std::thread thread([]() { - ScopeLabel thread_label("thread 1"); - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - }); - thread.join(); - } - } - Print(treeview); - EXPECT_EQ(treeview.thread_roots().size(), 2); -} - -#endif // RUY_PROFILER - -} // namespace -} // namespace profiler -} // namespace ruy - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/profiler/test_instrumented_library.cc b/profiler/test_instrumented_library.cc deleted file mode 100644 index 42461c3..0000000 --- a/profiler/test_instrumented_library.cc +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include - -#include "profiler/instrumentation.h" - -namespace { - -void MergeSortRecurse(int level, int size, int* data, int* workspace) { - ruy::profiler::ScopeLabel function_label( - "MergeSortRecurse (level=%d, size=%d)", level, size); - if (size <= 1) { - return; - } - int half_size = size / 2; - MergeSortRecurse(level + 1, half_size, data, workspace); - MergeSortRecurse(level + 1, size - half_size, data + half_size, - workspace + half_size); - - ruy::profiler::ScopeLabel merging_sorted_halves_label( - "Merging sorted halves"); - int dst_index = 0; - int left_index = 0; - int right_index = half_size; - while (dst_index < size) { - int val; - if (left_index < half_size && - ((right_index >= size) || data[left_index] < data[right_index])) { - val = data[left_index++]; - } else { - val = data[right_index++]; - } - workspace[dst_index++] = val; - } - for (int i = 0; i < size; i++) { - data[i] = workspace[i]; - } -} - -} // namespace - -void MergeSort(int size, int* data) { - ruy::profiler::ScopeLabel function_label("MergeSort (size=%d)", size); - std::vector workspace(size); - MergeSortRecurse(0, size, data, workspace.data()); -} diff --git a/profiler/test_instrumented_library.h b/profiler/test_instrumented_library.h deleted file mode 100644 index 03956df..0000000 --- a/profiler/test_instrumented_library.h +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ - -#include "profiler/instrumentation.h" - -void MergeSort(int size, int* data); - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ diff --git a/profiler/treeview.cc b/profiler/treeview.cc deleted file mode 100644 index 64ed05a..0000000 --- a/profiler/treeview.cc +++ /dev/null @@ -1,248 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifdef RUY_PROFILER - -#include "profiler/treeview.h" - -#include -#include -#include -#include -#include - -namespace ruy { -namespace profiler { - -namespace { - -void SortNode(TreeView::Node* node) { - using NodePtr = std::unique_ptr; - std::sort(node->children.begin(), node->children.end(), - [](const NodePtr& n1, const NodePtr& n2) { - return n1->weight > n2->weight; - }); - for (const auto& child : node->children) { - SortNode(child.get()); - } -} - -// Records a stack i.e. a sample in a treeview, by incrementing the weights -// of matching existing nodes and/or by creating new nodes as needed, -// recursively, below the given node. -void AddStack(const detail::Stack& stack, TreeView::Node* node, int level) { - node->weight++; - if (stack.size == level) { - return; - } - TreeView::Node* child_to_add_to = nullptr; - for (const auto& child : node->children) { - if (child->label == stack.labels[level]) { - child_to_add_to = child.get(); - break; - } - } - if (!child_to_add_to) { - child_to_add_to = node->children.emplace_back(new TreeView::Node).get(); - child_to_add_to->label = stack.labels[level]; - } - AddStack(stack, child_to_add_to, level + 1); -} - -// Recursively populates the treeview below the given node with 'other' -// entries documenting for each node the difference between its weight and the -// sum of its children's weight. -void AddOther(TreeView::Node* node) { - int top_level_children_weight = 0; - for (const auto& child : node->children) { - AddOther(child.get()); - top_level_children_weight += child->weight; - } - if (top_level_children_weight != 0 && - top_level_children_weight != node->weight) { - const auto& new_child = node->children.emplace_back(new TreeView::Node); - new_child->label = Label("[other]"); - new_child->weight = node->weight - top_level_children_weight; - } -} - -} // namespace - -void TreeView::Populate(const std::vector& samples_buf_) { - thread_roots_.clear(); - // Populate the treeview with regular nodes coming from samples. - const char* buf_ptr = samples_buf_.data(); - const char* const buf_ptr_end = buf_ptr + samples_buf_.size(); - while (buf_ptr < buf_ptr_end) { - detail::Stack stack; - detail::ReadFromBuffer(buf_ptr, &stack); - // Empty stacks should have been dropped during sampling. - assert(stack.size > 0); - buf_ptr += GetBufferSize(stack); - const int id = stack.id; - if (!thread_roots_[id]) { - thread_roots_[id].reset(new Node); - } - AddStack(stack, thread_roots_[id].get(), 0); - } - // Populate the treeview with additional 'other' nodes, sort, and set - // root labels. - for (const auto& thread_root : thread_roots_) { - std::uint32_t id = thread_root.first; - Node* root = thread_root.second.get(); - AddOther(root); - SortNode(root); - root->label.Set("Thread %x (%d samples)", id, root->weight); - } -} - -// Recursively prints the treeview below the given node. The 'root' node -// argument is only needed to compute weights ratios, with the root ratio -// as denominator. 
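Illustration of the output shape produced by the Print/PrintTreeBelow functions defined next (not part of the sources; the thread id, sample count and percentages are made-up values, but the layout follows the printf format strings used below, and the labels are the ones from the profiler test above):

    Profile (1 threads):

    Thread 3f2a (2000 samples)

    * 100.00% thread func
     * 51.30% some more work
     * 48.70% [other]
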
-void PrintTreeBelow(const TreeView::Node& node, const TreeView::Node& root, - int level) { - if (&node == &root) { - printf("%s\n\n", node.label.Formatted().c_str()); - } else { - for (int i = 1; i < level; i++) { - printf(" "); - } - printf("* %.2f%% %s\n", 100.0f * node.weight / root.weight, - node.label.Formatted().c_str()); - } - for (const auto& child : node.children) { - PrintTreeBelow(*child, root, level + 1); - } -} - -void Print(const TreeView& treeview) { - printf("\n"); - printf("Profile (%d threads):\n\n", - static_cast(treeview.thread_roots().size())); - for (const auto& thread_root : treeview.thread_roots()) { - const TreeView::Node& root = *thread_root.second; - PrintTreeBelow(root, root, 0); - printf("\n"); - } -} - -int DepthOfTreeBelow(const TreeView::Node& node) { - if (node.children.empty()) { - return 0; - } else { - int max_child_depth = 0; - for (const auto& child : node.children) { - max_child_depth = std::max(max_child_depth, DepthOfTreeBelow(*child)); - } - return 1 + max_child_depth; - } -} - -int WeightBelowNodeMatchingFunction( - const TreeView::Node& node, - const std::function& match) { - int weight = 0; - if (match(node.label)) { - weight += node.weight; - } - for (const auto& child : node.children) { - weight += WeightBelowNodeMatchingFunction(*child, match); - } - return weight; -} - -int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node, - const std::string& format) { - return WeightBelowNodeMatchingFunction( - node, [&format](const Label& label) { return label.format() == format; }); -} - -int WeightBelowNodeMatchingFormatted(const TreeView::Node& node, - const std::string& formatted) { - return WeightBelowNodeMatchingFunction( - node, [&formatted](const Label& label) { - return label.Formatted() == formatted; - }); -} - -void CollapseNode(const TreeView::Node& node_in, int depth, - TreeView::Node* node_out) { - node_out->label = node_in.label; - node_out->weight = node_in.weight; - node_out->children.clear(); - if (depth > 0) { - for (const auto& child_in : node_in.children) { - auto* child_out = new TreeView::Node; - node_out->children.emplace_back(child_out); - CollapseNode(*child_in, depth - 1, child_out); - } - } -} - -void CollapseSubnodesMatchingFunction( - const TreeView::Node& node_in, int depth, - const std::function& match, TreeView::Node* node_out) { - if (match(node_in.label)) { - CollapseNode(node_in, depth, node_out); - } else { - node_out->label = node_in.label; - node_out->weight = node_in.weight; - node_out->children.clear(); - - for (const auto& child_in : node_in.children) { - auto* child_out = new TreeView::Node; - node_out->children.emplace_back(child_out); - CollapseSubnodesMatchingFunction(*child_in, depth, match, child_out); - } - } -} - -void CollapseNodesMatchingFunction( - const TreeView& treeview_in, int depth, - const std::function& match, TreeView* treeview_out) { - treeview_out->mutable_thread_roots()->clear(); - for (const auto& thread_root_in : treeview_in.thread_roots()) { - std::uint32_t id = thread_root_in.first; - const auto& root_in = *thread_root_in.second; - auto* root_out = new TreeView::Node; - treeview_out->mutable_thread_roots()->emplace(id, root_out); - CollapseSubnodesMatchingFunction(root_in, depth, match, root_out); - } -} - -void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth, - const std::string& format, - TreeView* treeview_out) { - CollapseNodesMatchingFunction( - treeview_in, depth, - [&format](const Label& label) { return label.format() == format; }, - 
treeview_out); -} - -void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth, - const std::string& formatted, - TreeView* treeview_out) { - CollapseNodesMatchingFunction( - treeview_in, depth, - [&formatted](const Label& label) { - return label.Formatted() == formatted; - }, - treeview_out); -} - -} // namespace profiler -} // namespace ruy - -#endif // RUY_PROFILER diff --git a/profiler/treeview.h b/profiler/treeview.h deleted file mode 100644 index 80d1180..0000000 --- a/profiler/treeview.h +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright 2020 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_TREEVIEW_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_TREEVIEW_H_ - -#ifdef RUY_PROFILER - -#include -#include -#include -#include - -#include "profiler/instrumentation.h" - -namespace ruy { -namespace profiler { - -// A tree view of a profile. -class TreeView { - public: - struct Node { - std::vector> children; - Label label; - int weight = 0; - }; - - void Populate(const std::vector& samples_buf_); - - // Intentionally an *ordered* map so that threads are enumerated - // in an order that's consistent and typically putting the 'main thread' - // first. - using ThreadRootsMap = std::map>; - - const ThreadRootsMap& thread_roots() const { return thread_roots_; } - ThreadRootsMap* mutable_thread_roots() { return &thread_roots_; } - - private: - ThreadRootsMap thread_roots_; -}; - -/* Below are API functions for manipulating and printing treeviews. */ - -// Prints the treeview to stdout. -void Print(const TreeView& treeview); - -// Prints the treeview below the given node on stdout. -void PrintTreeBelow(const TreeView::Node& node); - -// Returns the tree depth below the given node. -int DepthOfTreeBelow(const TreeView::Node& node); - -// Returns the sum of weights of nodes below the given node and filtered by -// the `match` predicate. -int WeightBelowNodeMatchingFunction( - const TreeView::Node& node, const std::function& match); - -// Returns the sum of weights of nodes below the given node and whose -// unformatted label (i.e. raw format string) matches the given `format` string. -// -// This allows to aggregate nodes whose labels differ only by parameter values. -int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node, - const std::string& format); - -// Returns the sum of weights of nodes below the given node and whose formatted -// label matches the `formatted` string. -// -// In the case of nodes with parametrized labels, this allows to count only -// nodes with specific parameter values. For that purpose, one may also instead -// use WeightBelowNodeMatchingFunction directly, with a `match` predicate -// comparing raw integer parameter values directly, instead of going through -// formatted strings. 
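As a rough usage illustration of the weight-query helpers documented here (not from the sources: RUY_PROFILER is assumed to be defined, `treeview` is assumed to have been populated as in the profiler tests, and the size parameter value is made up; the format strings come from the instrumented merge-sort above):

    const ruy::profiler::TreeView::Node& root =
        *treeview.thread_roots().begin()->second;
    // Aggregates all MergeSortRecurse nodes regardless of parameter values.
    int recurse_weight = ruy::profiler::WeightBelowNodeMatchingUnformatted(
        root, "MergeSortRecurse (level=%d, size=%d)");
    // Counts only the node(s) whose formatted label matches exactly.
    int top_call_weight = ruy::profiler::WeightBelowNodeMatchingFormatted(
        root, "MergeSortRecurse (level=0, size=10000)");
    // Arbitrary predicate on the Label.
    int merge_weight = ruy::profiler::WeightBelowNodeMatchingFunction(
        root, [](const ruy::profiler::Label& label) {
          return label.format() == std::string("Merging sorted halves");
        });
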
-int WeightBelowNodeMatchingFormatted(const TreeView::Node& node, - const std::string& formatted); - -// Produces a `node_out` that is a copy of `node_in` but with tree depth below -// it clamped at `depth`, with further subtrees aggregated into single leaf -// nodes. -void CollapseNode(const TreeView::Node& node_in, int depth, - TreeView::Node* node_out); - -// Calls CollapseNode with the given `depth` on every subnode filtered by the -// `match` predicate. Note that this does NOT limit the tree depth below -// `node_out` to `depth`, since each collapsed node below `node_out` may be -// arbitrarily far below it and `depth` is only used as the collapsing depth -// at that point. -void CollapseSubnodesMatchingFunction( - const TreeView::Node& node_in, int depth, - const std::function& match, TreeView::Node* node_out); - -// Calls CollapseNode with the given `depth` on every node filtered by the -// `match` predicate. Note that this does NOT limit the tree depth below -// `node_out` to `depth`, since each collapsed node below `node_out` may be -// arbitrarily far below it and `depth` is only used as the collapsing depth -// at that point. -void CollapseNodesMatchingFunction( - const TreeView& treeview_in, int depth, - const std::function& match, TreeView* treeview_out); - -// Special case of CollapseNodesMatchingFunction matching unformatted labels, -// i.e. raw format strings. -// See the comment on WeightBelowNodeMatchingUnformatted. -void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth, - const std::string& format, - TreeView* treeview_out); - -// Special case of CollapseNodesMatchingFunction matching formatted labels. -// See the comment on WeightBelowNodeMatchingFormatted. -void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth, - const std::string& formatted, - TreeView* treeview_out); - -} // namespace profiler -} // namespace ruy - -#endif // RUY_PROFILER - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_PROFILER_TREEVIEW_H_ diff --git a/ruy.h b/ruy.h deleted file mode 100644 index 8b530c6..0000000 --- a/ruy.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This is the only Ruy header that users should #include. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_H_ - -#include "context.h" -#include "dispatch.h" -#include "matrix.h" -#include "path.h" -#include "spec.h" - -namespace ruy { - -// Performs a multiplication of matrices. This is Ruy's only API entry point. -// Should be self-explanatory given the above documentation for each of Matrix, -// Spec and Context. 
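Rough sketch of calling the entry point defined next (not from the sources: matrix layout/data setup is elided, and the BasicSpec and kAllPaths names are assumptions recalled from ruy's spec.h, path.h and example.cc rather than taken from this patch):

    ruy::Context context;
    ruy::Matrix<float> lhs, rhs, dst;
    // ... set lhs/rhs/dst layouts and data pointers (see example.cc) ...
    ruy::BasicSpec<float, float> spec;  // assumed default spec type
    ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
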
-template -void Mul(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, Context* context, Matrix* dst) { - DispatchMul( - lhs, rhs, spec, context, dst); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_H_ diff --git a/ruy/BUILD b/ruy/BUILD new file mode 100644 index 0000000..0b19193 --- /dev/null +++ b/ruy/BUILD @@ -0,0 +1,954 @@ +# Ruy is not BLAS + +load(":build_defs.bzl", "ruy_copts_avx2", "ruy_copts_avxvnni", "ruy_copts_base", "ruy_copts_skylake", "ruy_copts_sse42") +load(":ruy_test_ext.bzl", "ruy_test_ext_defines", "ruy_test_ext_deps") +load(":ruy_test.bzl", "ruy_benchmark", "ruy_test") + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +config_setting( + name = "windows", + values = {"cpu": "x64_windows"}, +) + +config_setting( + name = "armeabi-v7a", + values = {"cpu": "armeabi-v7a"}, +) + +config_setting( + name = "x86_64", + values = {"cpu": "k8"}, +) + +config_setting( + name = "optimized", + values = { + "compilation_mode": "opt", + }, + visibility = ["//visibility:public"], +) + +cc_library( + name = "platform", + hdrs = ["platform.h"], + copts = ruy_copts_base(), +) + +cc_library( + name = "check_macros", + hdrs = ["check_macros.h"], + copts = ruy_copts_base(), +) + +cc_test( + name = "check_macros_test", + srcs = ["check_macros_test.cc"], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "opt_set", + hdrs = ["opt_set.h"], + copts = ruy_copts_base(), +) + +cc_library( + name = "time", + hdrs = ["time.h"], + copts = ruy_copts_base(), +) + +cc_library( + name = "wait", + srcs = ["wait.cc"], + hdrs = ["wait.h"], + copts = ruy_copts_base(), + deps = [":time"], +) + +cc_test( + name = "wait_test", + srcs = ["wait_test.cc"], + deps = [ + ":platform", + ":wait", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "size_util", + hdrs = ["size_util.h"], + copts = ruy_copts_base(), + deps = [":check_macros"], +) + +cc_test( + name = "size_util_test", + srcs = ["size_util_test.cc"], + deps = [ + ":size_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "tune", + srcs = [ + "tune.cc", + ], + hdrs = [ + "tune.h", + ], + copts = ruy_copts_base(), + deps = [ + ":opt_set", + ":platform", + ":time", + ], +) + +cc_library( + name = "prepacked_cache", + srcs = [ + "prepacked_cache.cc", + ], + hdrs = [ + "prepacked_cache.h", + ], + copts = ruy_copts_base(), + deps = [ + ":allocator", + ":matrix", + ":opt_set", + ":platform", + ":time", + "//ruy/profiler:instrumentation", + ], +) + +cc_test( + name = "tune_test", + srcs = ["tune_test.cc"], + deps = [ + ":tune", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "prepacked_cache_test", + srcs = ["prepacked_cache_test.cc"], + deps = [ + ":prepacked_cache", + ":ruy", + ":time", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( + name = "tune_tool", + srcs = ["tune_tool.cc"], + deps = [ + ":tune", + ], +) + +cc_library( + name = "allocator", + srcs = [ + "allocator.cc", + ], + hdrs = [ + "allocator.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":size_util", + ], +) + +cc_test( + name = "allocator_test", + srcs = ["allocator_test.cc"], + deps = [ + ":allocator", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "side_pair", + hdrs = ["side_pair.h"], + copts = ruy_copts_base(), + deps = [":check_macros"], +) + +cc_library( + name = "block_map", + srcs = [ + "block_map.cc", + ], + hdrs = [ + 
"block_map.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":opt_set", + ":path", + ":side_pair", + ":size_util", + "//ruy/profiler:instrumentation", + ], +) + +cc_test( + name = "block_map_test", + srcs = ["block_map_test.cc"], + deps = [ + ":block_map", + ":cpu_cache_size", + ":path", + ":side_pair", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "blocking_counter", + srcs = [ + "blocking_counter.cc", + ], + hdrs = [ + "blocking_counter.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":wait", + ], +) + +cc_library( + name = "thread_pool", + srcs = [ + "thread_pool.cc", + ], + hdrs = [ + "thread_pool.h", + ], + copts = ruy_copts_base(), + deps = [ + ":blocking_counter", + ":check_macros", + ":wait", + ], +) + +cc_library( + name = "detect_arm", + srcs = [ + "detect_arm.cc", + ], + hdrs = [ + "detect_arm.h", + ], + copts = ruy_copts_base(), +) + +cc_library( + name = "detect_x86", + srcs = [ + "detect_x86.cc", + ], + hdrs = [ + "detect_x86.h", + ], + copts = ruy_copts_base(), + deps = [ + ":platform", + ], +) + +cc_library( + name = "path", + hdrs = ["path.h"], + copts = ruy_copts_base(), + deps = [ + ":platform", + ":size_util", + ], +) + +cc_library( + name = "cpu_cache_size", + hdrs = ["cpu_cache_size.h"], + copts = ruy_copts_base(), + deps = [ + ":path", + ":platform", + ], +) + +cc_library( + name = "trace", + srcs = [ + "trace.cc", + ], + hdrs = [ + "trace.h", + ], + copts = ruy_copts_base(), + deps = [ + ":block_map", + ":check_macros", + ":side_pair", + ":time", + ], +) + +cc_library( + name = "matrix", + hdrs = ["matrix.h"], + copts = ruy_copts_base(), + deps = [":check_macros"], +) + +cc_library( + name = "spec", + hdrs = ["spec.h"], + copts = ruy_copts_base(), + deps = [ + ":cpu_cache_size", + ":matrix", + ], +) + +cc_library( + name = "internal_matrix", + hdrs = ["internal_matrix.h"], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":common", + ":matrix", + ":size_util", + ], +) + +cc_library( + name = "common", + hdrs = [ + "common.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":matrix", + ":opt_set", + ":path", + ":platform", + ], +) + +cc_library( + name = "kernel_common", + hdrs = [ + "kernel.h", + "kernel_arm.h", + "kernel_common.h", + "kernel_x86.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":common", + ":internal_matrix", + ":matrix", + ":opt_set", + ":path", + ":platform", + ":side_pair", + ":size_util", + ":spec", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_common", + hdrs = [ + "pack.h", + "pack_arm.h", + "pack_common.h", + "pack_x86.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":common", + ":internal_matrix", + ":matrix", + ":opt_set", + ":path", + ":platform", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "kernel_arm", + srcs = [ + "kernel_arm32.cc", + "kernel_arm64.cc", + ], + copts = ruy_copts_base(), + deps = [ + ":common", + ":kernel_common", + ":opt_set", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_arm", + srcs = [ + "pack_arm.cc", + ], + copts = ruy_copts_base(), + deps = [ + ":common", + ":opt_set", + ":pack_common", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +# AVX-512 compilation units. +# +# These must use the same compiler options. 
+RUY_COPTS_BUILT_FOR_AVX512 = ruy_copts_base() + ruy_copts_skylake() + +cc_library( + name = "kernel_avx512", + srcs = [ + "kernel_avx512.cc", + ], + copts = RUY_COPTS_BUILT_FOR_AVX512, + deps = [ + ":check_macros", + ":kernel_common", + ":opt_set", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_avx512", + srcs = [ + "pack_avx512.cc", + ], + copts = RUY_COPTS_BUILT_FOR_AVX512, + deps = [ + ":check_macros", + ":matrix", + ":opt_set", + ":pack_common", + ":path", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "have_built_path_for_avx512", + srcs = [ + "have_built_path_for_avx512.cc", + ], + hdrs = [ + "have_built_path_for.h", + ], + copts = RUY_COPTS_BUILT_FOR_AVX512, + deps = [ + ":opt_set", + ":platform", + ], +) +# End: AVX-512 compilation units. + +# AVX2 compilation units. +# +# These must use the same compiler options. +RUY_COPTS_BUILT_FOR_AVX2 = ruy_copts_base() + ruy_copts_avx2() + +cc_library( + name = "kernel_avx2", + srcs = [ + "kernel_avx2.cc", + ], + copts = RUY_COPTS_BUILT_FOR_AVX2, + deps = [ + ":check_macros", + ":kernel_common", + ":opt_set", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_avx2", + srcs = [ + "pack_avx2.cc", + ], + copts = RUY_COPTS_BUILT_FOR_AVX2, + deps = [ + ":check_macros", + ":matrix", + ":opt_set", + ":pack_common", + ":path", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "have_built_path_for_avx2", + srcs = [ + "have_built_path_for_avx2.cc", + ], + hdrs = [ + "have_built_path_for.h", + ], + copts = RUY_COPTS_BUILT_FOR_AVX2, + deps = [ + ":opt_set", + ":platform", + ], +) +# End: AVX2 compilation units. + +# SSE42 compilation units. +# +# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +# Optimization is not finished. In particular the dimensions of the kernel +# blocks can be changed as desired. +# +# These must use the same compiler options. +RUY_COPTS_BUILT_FOR_SSE42 = ruy_copts_base() + ruy_copts_sse42() + +cc_library( + name = "kernel_sse42", + srcs = [ + "kernel_sse42.cc", + ], + copts = RUY_COPTS_BUILT_FOR_SSE42, + deps = [ + ":check_macros", + ":kernel_common", + ":opt_set", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_sse42", + srcs = [ + "pack_sse42.cc", + ], + copts = RUY_COPTS_BUILT_FOR_SSE42, + deps = [ + ":check_macros", + ":matrix", + ":opt_set", + ":pack_common", + ":path", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "have_built_path_for_sse42", + srcs = [ + "have_built_path_for_sse42.cc", + ], + hdrs = [ + "have_built_path_for.h", + ], + copts = RUY_COPTS_BUILT_FOR_SSE42, + deps = [ + ":opt_set", + ":platform", + ], +) +# End: SSE42 compilation units. + +# AVX-VNNI compilation units. +# +# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +# Optimization is not finished. In particular the dimensions of the kernel +# blocks can be changed as desired. +# +# These must use the same compiler options. 
+RUY_COPTS_BUILT_FOR_AVX_VNNI = ruy_copts_base() + ruy_copts_avxvnni() + +cc_library( + name = "kernel_avxvnni", + srcs = [ + "kernel_avxvnni.cc", + ], + copts = RUY_COPTS_BUILT_FOR_AVX_VNNI, + deps = [ + ":check_macros", + ":kernel_common", + ":opt_set", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack_avxvnni", + srcs = [ + "pack_avxvnni.cc", + ], + copts = RUY_COPTS_BUILT_FOR_AVX_VNNI, + deps = [ + ":check_macros", + ":matrix", + ":opt_set", + ":pack_common", + ":path", + ":platform", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "have_built_path_for_avxvnni", + srcs = [ + "have_built_path_for_avxvnni.cc", + ], + hdrs = [ + "have_built_path_for.h", + ], + copts = RUY_COPTS_BUILT_FOR_AVX_VNNI, + deps = [ + ":opt_set", + ":platform", + ], +) +# End: AVX-VNNI compilation units. + +cc_library( + name = "kernel", + hdrs = [ + "kernel.h", + "kernel_common.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":common", + ":internal_matrix", + ":kernel_arm", # fixdeps: keep + ":kernel_avx2", # fixdeps: keep + ":kernel_avx512", # fixdeps: keep + ":kernel_avxvnni", # fixdeps: keep + ":kernel_common", + ":kernel_sse42", # fixdeps: keep + ":matrix", + ":opt_set", + ":path", + ":platform", + ":side_pair", + ":size_util", + ":spec", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "pack", + hdrs = [ + "pack.h", + "pack_common.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":common", + ":internal_matrix", + ":matrix", + ":opt_set", + ":pack_arm", # fixdeps: keep + ":pack_avx2", # fixdeps: keep + ":pack_avx512", # fixdeps: keep + ":pack_avxvnni", # fixdeps: keep + ":pack_common", + ":pack_sse42", # fixdeps: keep + ":path", + ":platform", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +cc_library( + name = "have_built_path_for", + hdrs = [ + "have_built_path_for.h", + ], + deps = [ + ":have_built_path_for_avx2", + ":have_built_path_for_avx512", + ":have_built_path_for_avxvnni", + ":have_built_path_for_sse42", + ":platform", + ], +) + +cc_library( + name = "context", + srcs = [ + "context.cc", + ], + hdrs = [ + "context.h", + ], + copts = ruy_copts_base(), + deps = [ + ":allocator", + ":check_macros", + ":detect_arm", + ":detect_x86", + ":have_built_path_for", + ":path", + ":platform", + ":prepacked_cache", + ":thread_pool", + ":trace", + ":tune", + ], +) + +cc_test( + name = "context_test", + srcs = ["context_test.cc"], + deps = [ + ":context", + ":path", + ":platform", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "trmul_params", + hdrs = ["trmul_params.h"], + copts = ruy_copts_base(), + deps = [ + ":internal_matrix", + ":side_pair", + ":tune", + ], +) + +cc_library( + name = "trmul", + srcs = ["trmul.cc"], + hdrs = ["trmul.h"], + copts = ruy_copts_base(), + deps = [ + ":allocator", + ":block_map", + ":check_macros", + ":common", + ":context", + ":internal_matrix", + ":matrix", + ":opt_set", + ":side_pair", + ":size_util", + ":spec", + ":thread_pool", + ":trace", + ":trmul_params", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +# The main library. 
+cc_library( + name = "ruy", + srcs = [ + "dispatch.h", + "prepack.h", + ], + hdrs = [ + "ruy.h", + "ruy_advanced.h", + ], + copts = ruy_copts_base(), + deps = [ + ":check_macros", + ":common", + ":context", + ":internal_matrix", + ":kernel", + ":matrix", + ":opt_set", + ":pack", + ":path", + ":prepacked_cache", + ":side_pair", + ":size_util", + ":spec", + ":trmul", + ":trmul_params", + ":tune", + "//ruy/profiler:instrumentation", + ], +) + +# Usage examples. +cc_binary( + name = "example", + srcs = ["example.cc"], + deps = [":ruy"], +) + +# Usage examples of the advanced API. +cc_binary( + name = "example_advanced", + srcs = ["example_advanced.cc"], + deps = [":ruy"], +) + +# Small library to query PMU counters, for benchmark only +cc_library( + name = "pmu", + testonly = True, + srcs = ["pmu.cc"], + hdrs = ["pmu.h"], + copts = ruy_copts_base(), + deps = [":check_macros"], +) + +# Testing framework. +cc_library( + name = "test_lib", + testonly = True, + hdrs = ["test.h"], + copts = ruy_copts_base(), + # need defines, not copts, because it's controlling a header, test.h + defines = ruy_test_ext_defines(), + linkopts = select({ + ":windows": [], + "//conditions:default": ["-lm"], + }), + deps = [ + ":matrix", + ":pmu", + ":ruy", + ":spec", + ":time", + "@com_google_googletest//:gtest", + ":platform", + "//ruy/profiler:profiler", + ] + ruy_test_ext_deps(), +) + +ruy_benchmark( + name = "benchmark", + srcs = ["benchmark.cc"], + copts = ruy_copts_base(), + lhs_rhs_accum_dst = [ + ("f32", "f32", "f32", "f32"), + ("u8", "u8", "i32", "u8"), + ("i8", "i8", "i32", "u8"), + ("i8", "i8", "i32", "i8"), + ("u8", "u8", "i32", "i16"), + ("i8", "i8", "i32", "i32"), + ], + deps = [ + "//ruy:test_lib", + "//ruy/profiler:instrumentation", + ], +) + +ruy_test( + name = "test_fast", + srcs = ["test_fast.cc"], + copts = ruy_copts_base(), + lhs_rhs_accum_dst = [ + ("f32", "f32", "f32", "f32"), + ("f64", "f32", "f64", "f32"), + ("f32", "f64", "f64", "f64"), + ("u8", "u8", "i32", "u8"), + ("i8", "i8", "i32", "i8"), + ("i8", "u8", "i32", "i8"), + ("u8", "u8", "i32", "i16"), + ("i8", "i8", "i32", "i32"), + ("i8", "u8", "i32", "i32"), + ], + deps = [ + "@com_google_googletest//:gtest_main", + "//ruy:test_lib", + ], +) + +ruy_test( + name = "test_slow", + srcs = ["test_slow.cc"], + copts = ruy_copts_base(), + lhs_rhs_accum_dst = [ + ("f32", "f32", "f32", "f32"), + ("u8", "u8", "i32", "u8"), + ("i8", "i8", "i32", "i8"), + ("u8", "u8", "i32", "i16"), + ("i8", "i8", "i32", "i32"), + ], + tags = ["slow"], + deps = [ + "@com_google_googletest//:gtest_main", + "//ruy:test_lib", + ], +) + +ruy_test( + name = "test_special_specs", + srcs = ["test_special_specs.cc"], + copts = ruy_copts_base(), + lhs_rhs_accum_dst = [ + ("f32", "f32", "f32", "f32"), + ("u8", "u8", "i32", "u8"), + ("u8", "u8", "i32", "i16"), + ], + deps = [ + "@com_google_googletest//:gtest_main", + "//ruy:test_lib", + ], +) diff --git a/ruy/allocator.cc b/ruy/allocator.cc new file mode 100644 index 0000000..d8fb738 --- /dev/null +++ b/ruy/allocator.cc @@ -0,0 +1,51 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/allocator.h" + +#include +#include + +#ifdef _WIN32 +#include +#endif + +namespace ruy { + +namespace detail { + +void *SystemAlignedAlloc(std::ptrdiff_t num_bytes) { +#ifdef _WIN32 + return _aligned_malloc(num_bytes, kMinimumBlockAlignment); +#else + void *ptr; + if (posix_memalign(&ptr, kMinimumBlockAlignment, num_bytes)) { + return nullptr; + } + return ptr; +#endif +} + +void SystemAlignedFree(void *ptr) { +#ifdef _WIN32 + _aligned_free(ptr); +#else + free(ptr); +#endif +} + +} // namespace detail + +} // namespace ruy diff --git a/ruy/allocator.h b/ruy/allocator.h new file mode 100644 index 0000000..b0379b1 --- /dev/null +++ b/ruy/allocator.h @@ -0,0 +1,185 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_ + +#include +#include +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/size_util.h" + +namespace ruy { + +namespace detail { + +inline void* VoidPtrAdd(void* p, std::ptrdiff_t offset) { + RUY_DCHECK(p); + std::uintptr_t addr = reinterpret_cast(p) + offset; + return reinterpret_cast(addr); +} + +// Minimum alignment for blocks. +// +// Considerations: +// - This needs to be at least the alignment of any usual data type. +// - It's useful that this is at least the size of a cache line to limit +// possible cache side effects (if only on performance behavior). +// - It's useful that this is at least the size of SIMD registers, as +// some SIMD instruction sets have at least performance behavior +// differences (e.g. NEON) or even different requirements (e.g. SSE) +// based on that. +// - It's useful that this is at least the size of an "exclusive reservation +// granule" on ARM, meaning that if we use this Allocator to allocate +// an atomic variable, there will be no side effects from other things +// contending for exclusive/atomic memory accesses to it. While the +// ARM reference manual mentions that this granule size may be as large +// as 2048 bytes, in practice we observe it to be 64 bytes. It can +// be queried cheaply, at runtime, from userspace, if needed. +static constexpr std::ptrdiff_t kMinimumBlockAlignment = 64; + +// Primitive allocation functions obtaining aligned memory from the +// operating system. +void* SystemAlignedAlloc(std::ptrdiff_t num_bytes); +void SystemAlignedFree(void* ptr); + +// Specialized allocator designed to converge to a steady-state where all +// allocations are bump-ptr allocations from an already-allocated buffer. +// +// To support these constraints, this allocator only supports two +// operations. 
+// - AllocateAlignedBytes: allocates a pointer to storage of a specified +// size, which must be aligned to kMinimumBlockAlignment. +// - FreeAll: frees all previous allocations (but retains the internal +// buffer to minimize future calls into the system allocator). +// +// This class is specialized for supporting just those two operations +// under this specific steady-state usage pattern. Extending this class +// with new allocation interfaces that don't fit that pattern is probably not +// the right choice. Instead, build a new class on top of +// SystemAlignedAlloc/SystemAlignedFree. +// +// All operations happen on aligned blocks for simplicity. +class AlignedAllocator { + public: + void operator=(const AlignedAllocator&) = delete; + ~AlignedAllocator() { + FreeAll(); + SystemAlignedFree(ptr_); + } + + void* AllocateAlignedBytes(std::ptrdiff_t num_bytes) { + RUY_DCHECK_GT(num_bytes, 0); + RUY_DCHECK((num_bytes & (kMinimumBlockAlignment - 1)) == 0); + if (void* p = AllocateFast(num_bytes)) { + return p; + } + return AllocateSlow(num_bytes); + } + + void FreeAll() { + current_ = 0; + if (fallback_blocks_.empty()) { + return; + } + + // No rounding-up of the size means linear instead of logarithmic + // bound on the number of allocation in some worst-case calling patterns. + // This is considered worth it because minimizing memory usage is important + // and actual calling patterns in applications that we care about still + // reach the no-further-allocations steady state in a small finite number + // of iterations. + std::ptrdiff_t new_size = size_ + fallback_blocks_total_size_; + SystemAlignedFree(ptr_); + ptr_ = SystemAlignedAlloc(new_size); + size_ = new_size; + + for (void* p : fallback_blocks_) { + SystemAlignedFree(p); + } + fallback_blocks_.clear(); + fallback_blocks_total_size_ = 0; + } + + private: + void* AllocateFast(std::ptrdiff_t num_bytes) { + if (current_ + num_bytes > size_) { + return nullptr; + } + void* ret = VoidPtrAdd(ptr_, current_); + current_ += num_bytes; + return ret; + } + + void* AllocateSlow(std::ptrdiff_t num_bytes) { + void* p = SystemAlignedAlloc(num_bytes); + fallback_blocks_total_size_ += num_bytes; + fallback_blocks_.push_back(p); + return p; + } + + // Theory of operation: + // + // - ptr_, current_, and size_ implement a basic bump-ptr allocator. + // + // - in AllocateAlignedBytes, the fast path is just a bump-ptr + // allocation. If our bump-ptr allocator doesn't have enough space for an + // allocation, then we allocate a block from the system allocator to + // service the allocation request. We save that block in fallback_blocks_ + // and track the total size of the fallback blocks in + // fallback_blocks_total_size_. + // + // - in FreeAll, the fast path just resets the bump-ptr allocator. If + // there are any fallback blocks, we free them and reallocate the + // bump-ptr allocator's buffer so that the next sequence of allocations + // will hopefully not need any fallback blocks. + void* ptr_ = nullptr; + std::ptrdiff_t current_ = 0; + std::ptrdiff_t size_ = 0; + std::vector fallback_blocks_; + std::ptrdiff_t fallback_blocks_total_size_ = 0; +}; + +} // namespace detail + +// The main Allocator class, with a convenient interface for allocating a +// typed buffer. 
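A small usage sketch for the Allocator class defined next (not from the sources; the byte counts are just the arithmetic of rounding each request up to the 64-byte kMinimumBlockAlignment, as AllocateBytes does below):

    ruy::Allocator allocator;
    void* a = allocator.AllocateBytes(1);    // rounded up to 64 bytes
    void* b = allocator.AllocateBytes(100);  // rounded up to 128 bytes
    float* c;
    allocator.Allocate(16, &c);              // 16 * sizeof(float) = 64 bytes
    // The three allocations consume 64 + 128 + 64 = 256 bytes of the
    // underlying bump-ptr buffer (or of fallback blocks on first use).
    allocator.FreeAll();                     // all pointers above are invalidated
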
+class Allocator { + public: + void* AllocateBytes(std::ptrdiff_t num_bytes) { + if (num_bytes == 0) { + return nullptr; + } + return aligned.AllocateAlignedBytes( + round_up_pot(num_bytes, detail::kMinimumBlockAlignment)); + } + template + void Allocate(std::ptrdiff_t count, Pointer* out) { + using T = typename std::pointer_traits::element_type; + *out = static_cast(AllocateBytes(count * sizeof(T))); + } + + void FreeAll() { aligned.FreeAll(); } + + private: + detail::AlignedAllocator aligned; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ALLOCATOR_H_ diff --git a/ruy/allocator_test.cc b/ruy/allocator_test.cc new file mode 100644 index 0000000..7f46a66 --- /dev/null +++ b/ruy/allocator_test.cc @@ -0,0 +1,103 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/allocator.h" + +#include "testing/base/public/gunit.h" + +namespace ruy { +namespace { + +TEST(AllocatorTest, ReturnsValidMemory) { + Allocator allocator; + int *p; + allocator.Allocate(1, &p); + ASSERT_NE(p, nullptr); + + // If this is bogus memory, ASan will cause this test to fail. + *p = 42; + + allocator.FreeAll(); +} + +TEST(AllocatorTest, NoLeak) { + Allocator allocator; + // Allocate and free some ridiculously large total amount of memory, so + // that a leak will hopefully cause some sort of resource exhaustion. + // + // Despite the large number of allocations, this test is actually quite + // fast, since our fast-path allocation logic is very fast. + constexpr int kNumAllocations = 100 * 1024; + constexpr int kAllocationSize = 1024 * 1024; + for (int i = 0; i < kNumAllocations; i++) { + char *p; + allocator.Allocate(kAllocationSize, &p); + allocator.FreeAll(); + } +} + +TEST(AllocatorTest, IncreasingSizes) { + Allocator allocator; + // Allocate sizes that increase by small amounts across FreeAll calls. + for (int i = 1; i < 100 * 1024; i++) { + char *p; + allocator.Allocate(i, &p); + allocator.FreeAll(); + } +} + +TEST(AllocatorTest, ManySmallAllocations) { + Allocator allocator; + // Allocate many small allocations between FreeAll calls. + for (int i = 0; i < 10 * 1024; i += 100) { + for (int j = 0; j < i; j++) { + char *p; + allocator.Allocate(1, &p); + } + allocator.FreeAll(); + } +} + +TEST(AllocatorTest, DestructorHandlesMainBumpPtr) { + // This is a white-box test. + Allocator allocator; + allocator.AllocateBytes(1); + allocator.FreeAll(); + // After the call to FreeAll, the allocator will consolidate all of the memory + // into the main bump-ptr allocator's block, which we then expect to be freed + // in the destructor. + // + // We have no test assertions -- we primarily expect that this trigger a leak + // checker and cause the test to fail. +} + +TEST(AllocatorTest, DestructorHandlesFallbackBlocks) { + // This is a white-box test. 
+ Allocator allocator; + // Since we just created the allocator, this will allocate a fallback block, + // which we then expect to be freed in the destructor. + // + // We have no test assertions -- we primarily expect that this trigger a leak + // checker and cause the test to fail. + allocator.AllocateBytes(1); +} + +} // namespace +} // namespace ruy + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/benchmark.cc b/ruy/benchmark.cc new file mode 100644 index 0000000..6ce0b32 --- /dev/null +++ b/ruy/benchmark.cc @@ -0,0 +1,196 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "ruy/test.h" + +namespace ruy { + +using LhsScalar = RUY_TEST_LHSSCALAR; +using RhsScalar = RUY_TEST_RHSSCALAR; +using AccumScalar = RUY_TEST_ACCUMSCALAR; +using DstScalar = RUY_TEST_DSTSCALAR; +using TestSetType = + TestSet>; + +struct BenchmarkShape { + int rows; + int depth; + int cols; + int symm_lhs; + int symm_rhs; +}; + +template +std::vector>> BenchmarkRCC( + const BenchmarkShape& shape) { + TestSetType test_set; + test_set.rows = shape.rows; + test_set.depth = shape.depth; + test_set.cols = shape.cols; + test_set.lhs_order = Order::kRowMajor; + test_set.rhs_order = Order::kColMajor; + test_set.dst_order = Order::kColMajor; + test_set.layout_style = LayoutStyle::kPackedLinear; + test_set.benchmark = true; + const int asymmetry_lhs = shape.symm_lhs ? 0 : 1; + const int asymmetry_rhs = shape.symm_rhs ? 
0 : 1; + test_set.lhs_zero_point = SymmetricZeroPoint() + asymmetry_lhs; + test_set.rhs_zero_point = SymmetricZeroPoint() + asymmetry_rhs; + test_set.use_specified_zero_points = true; + test_set.perchannel = GetBoolEnvVarOrFalse("PERCHANNEL"); + test_set.benchmark_prepack_lhs = GetBoolEnvVarOrFalse("PREPACK_LHS"); + test_set.benchmark_prepack_rhs = GetBoolEnvVarOrFalse("PREPACK_RHS"); + test_set.Run(); + return std::move(test_set.results); +} + +std::vector ParseCommaSeparatedInts( + const std::string& comma_separated_ints) { + std::vector result; + for (std::size_t pos = 0; pos < comma_separated_ints.size();) { + std::size_t delim_pos = comma_separated_ints.find(',', pos); + if (delim_pos == std::string::npos) { + delim_pos = comma_separated_ints.size(); + } + result.push_back( + std::stoi(comma_separated_ints.substr(pos, delim_pos - pos))); + pos = delim_pos + 1; + } + return result; +} + +void Benchmark() { + const bool symm_lhs = std::is_floating_point::value || + GetBoolEnvVarOrFalse("SYMM_LHS"); + const bool symm_rhs = std::is_floating_point::value || + GetBoolEnvVarOrFalse("SYMM_RHS"); + const bool benchmark_cubic = GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC") || + GetBoolEnvVarOrFalse("RUY_BENCHMARK_CUBIC_LIST"); + const int explicit_rows = GetIntEnvVarOrZero("ROWS"); + const int explicit_cols = GetIntEnvVarOrZero("COLS"); + const int explicit_depth = GetIntEnvVarOrZero("DEPTH"); + + std::vector shapes; + + if (benchmark_cubic) { + std::vector sizes; + const char* benchmark_cubic_list_env = getenv("RUY_BENCHMARK_CUBIC_LIST"); + if (benchmark_cubic_list_env) { + sizes = ParseCommaSeparatedInts(benchmark_cubic_list_env); + } else { + // Often 8 is used for this multiplier, but to check teeny sizes one can + // use 1. + static constexpr int cubic_size_multiplier = 8; + for (int i = 2 * cubic_size_multiplier; + i <= (512 * cubic_size_multiplier); i *= 2) { + sizes.push_back(i); + if (i < (512 * cubic_size_multiplier)) { + sizes.push_back(i * 3 / 2); + } + } + } + for (int i : sizes) { + BenchmarkShape shape; + // Even in cubic mode, one may still override an individual dimension + // to allow testing a batch of rectangular sizes. + shape.rows = explicit_rows ? explicit_rows : i; + shape.cols = explicit_cols ? explicit_cols : i; + shape.depth = explicit_depth ? 
explicit_depth : i; + shape.symm_lhs = symm_lhs; + shape.symm_rhs = symm_rhs; + shapes.push_back(shape); + } + } else { + BenchmarkShape shape; + shape.rows = explicit_rows; + shape.cols = explicit_cols; + shape.depth = explicit_depth; + if (!shape.rows || !shape.depth || !shape.cols) { + fprintf(stderr, + "Please specify positive sizes with these env vars: ROWS, DEPTH, " + "COLS.\n"); + exit(1); + } + shape.symm_lhs = symm_lhs; + shape.symm_rhs = symm_rhs; + shapes.push_back(shape); + } + + for (int i = 0; i < shapes.size(); i++) { + const auto& shape = shapes[i]; + const auto& results = BenchmarkRCC(shape); + if (i == 0) { + if (benchmark_cubic) { + printf("size"); + for (const auto& result : results) { + if (results.size() > 1) { + printf(",%s:Gop/s", PathName(*result).c_str()); + } else { + printf(",Gop/s"); + } + if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { + printf( + ",l1_refill,l2_refill,l3_refill,l1tlb_refill,l2tlb_refill," + "mispred,frontend_stall,backend_stall"); + } + } + printf("\n"); + } else { + printf("path,shape,Gop/s\n"); + } + fflush(stdout); + } + if (benchmark_cubic) { + printf("%d", shape.rows); + for (const auto& result : results) { + printf(",%.4g", 2.0e-9 * shape.rows * shape.cols * shape.depth / + result->latency); + if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { + printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", + result->l1_refill_rate, result->l2_refill_rate, + result->l3_refill_rate, result->l1tlb_refill_rate, + result->l2tlb_refill_rate, result->mispred_rate, + result->frontend_stall_rate, result->backend_stall_rate); + } + } + printf("\n"); + fflush(stdout); + } else { + for (const auto& result : results) { + printf( + "%s,%dx%dx%d,%.4g", PathName(*result).c_str(), shape.rows, + shape.depth, shape.cols, + 2.0e-9 * shape.rows * shape.cols * shape.depth / result->latency); + if (GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU")) { + printf(",%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g,%.3g", + result->l1_refill_rate, result->l2_refill_rate, + result->l3_refill_rate, result->l1tlb_refill_rate, + result->l2tlb_refill_rate, result->mispred_rate, + result->frontend_stall_rate, result->backend_stall_rate); + } + printf("\n"); + } + fflush(stdout); + } + } +} + +} // namespace ruy + +int main() { ruy::Benchmark(); } diff --git a/ruy/block_map.cc b/ruy/block_map.cc new file mode 100644 index 0000000..e1e6166 --- /dev/null +++ b/ruy/block_map.cc @@ -0,0 +1,486 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/block_map.h" + +#include +#include + +#ifdef RUY_MAKEBLOCKMAP_DEBUG +#include +#include +#include +#endif + +#include "ruy/check_macros.h" +#include "ruy/opt_set.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/size_util.h" + +namespace ruy { + +namespace { + +void DecodeTraversalLinear(int size_log2, std::uint32_t square_index, + SidePair* local_pos) { + (*local_pos)[Side::kLhs] = square_index & ((1 << size_log2) - 1); + (*local_pos)[Side::kRhs] = square_index >> size_log2; +} + +void DecodeTraversalFractalZ(std::uint32_t square_index, + SidePair* local_pos) { + const std::uint32_t n1 = square_index; + const std::uint32_t n2 = (n1 & 0x99999999u) | ((n1 & 0x44444444u) >> 1) | + ((n1 & 0x22222222u) << 1); + const std::uint32_t n4 = (n2 & 0xc3c3c3c3u) | ((n2 & 0x30303030u) >> 2) | + ((n2 & 0x0c0c0c0cu) << 2); + const std::uint32_t n8 = (n4 & 0xf00ff00fu) | ((n4 & 0x0f000f00u) >> 4) | + ((n4 & 0x00f000f0u) << 4); + const std::uint32_t n16 = (n8 & 0xff0000ffu) | ((n8 & 0x00ff0000u) >> 8) | + ((n8 & 0x0000ff00u) << 8); + (*local_pos)[Side::kLhs] = n16 & 0xffff; + (*local_pos)[Side::kRhs] = n16 >> 16; +} + +void DecodeTraversalFractalU(std::uint32_t square_index, + SidePair* local_pos) { + DecodeTraversalFractalZ(square_index, local_pos); + // Change fractal z-order to u-order + (*local_pos)[Side::kLhs] ^= (*local_pos)[Side::kRhs]; +} + +// Code inspired by the sample code in +// https://en.wikipedia.org/wiki/Hilbert_curve +// The main optimization is to avoid hard-to-predict conditional branches +// based on the bits of the square_index parameter. +void DecodeTraversalFractalHilbert(int size_log2, std::uint32_t square_index, + SidePair* local_pos) { + std::uint32_t t = square_index; + std::uint32_t x = 0; + std::uint32_t y = 0; + // Easy-to-predict for loop, the number of iterations is the same for + // an entire GEMM. + for (int sb = 0; sb < size_log2; sb++) { + std::uint32_t s = 1 << sb; + bool rx = t & 2; + bool ry = (t & 1) ^ rx; + std::uint32_t tmp = rx ? (s - 1 - x) : x; + x = ry ? x : rx ? (s - 1 - y) : y; + y = ry ? (y + s) : tmp; + x = rx ? 
(x + s) : x; + t >>= 2; + } + (*local_pos)[Side::kLhs] = y; + (*local_pos)[Side::kRhs] = x; +} + +} // end anonymous namespace + +void GetBlockByIndex(const BlockMap& block_map, int index, + SidePair* block) { + profiler::ScopeLabel label("GetBlockByIndex"); + const std::uint32_t index_u32 = index; + + const std::uint32_t num_blocks_per_local_curve = + 1u << (2 * block_map.num_blocks_base_log2); + const std::uint32_t square_index = + index_u32 & (num_blocks_per_local_curve - 1); + + const int size_log2 = block_map.num_blocks_base_log2; + SidePair local_pos; + switch (block_map.traversal_order) { + case BlockMapTraversalOrder::kFractalZ: + DecodeTraversalFractalZ(square_index, &local_pos); + break; + case BlockMapTraversalOrder::kFractalU: + DecodeTraversalFractalU(square_index, &local_pos); + break; + case BlockMapTraversalOrder::kFractalHilbert: + DecodeTraversalFractalHilbert(size_log2, square_index, &local_pos); + break; + default: + RUY_DCHECK(block_map.traversal_order == BlockMapTraversalOrder::kLinear); + DecodeTraversalLinear(size_log2, square_index, &local_pos); + break; + } + + const std::uint32_t rectangular_index = + index_u32 >> 2 * block_map.num_blocks_base_log2; + for (Side side : {Side::kLhs, Side::kRhs}) { + const std::uint32_t mask = (1u << block_map.rectangularness_log2[side]) - 1; + const int rectangular_offset = (rectangular_index & mask) + << block_map.num_blocks_base_log2; + (*block)[side] = local_pos[side] + rectangular_offset; + } +} + +BlockMapTraversalOrder GetTraversalOrder(int rows, int cols, int depth, + int lhs_scalar_size, + int rhs_scalar_size, + int local_data_cache_size, + int shared_data_cache_size) { + const int kFractalOptSets = + RUY_OPT_FRACTAL_Z | RUY_OPT_FRACTAL_U | RUY_OPT_FRACTAL_HILBERT; + const int working_set_size = + (lhs_scalar_size * rows + rhs_scalar_size * cols) * depth; + if (RUY_OPT_ENABLED(kFractalOptSets) && + (working_set_size > local_data_cache_size)) { + if (RUY_OPT_ENABLED(RUY_OPT_FRACTAL_HILBERT) && + (working_set_size > shared_data_cache_size)) { + return BlockMapTraversalOrder::kFractalHilbert; + } else if (RUY_OPT_ENABLED(RUY_OPT_FRACTAL_U)) { + return BlockMapTraversalOrder::kFractalU; + } else { + return BlockMapTraversalOrder::kFractalZ; + } + } else { + return BlockMapTraversalOrder::kLinear; + } +} + +namespace { + +int floor_log2_quotient(int num, int denom) { + if (num <= denom) { + return 0; + } + int log2_quotient = floor_log2(num) - ceil_log2(denom); + if ((denom << (log2_quotient + 1)) <= num) { + log2_quotient++; + } + return log2_quotient; +} + +// Computes the rectangularness of the matrix shape (rows, cols). This is +// essentially just the log2 of the quotient (rows / cols). The kernel_rows and +// kernel_cols only get into the picture for clamping bounds but don't affect +// the generic computation. +void GetRectangularness(int rows, int cols, int kernel_rows, int kernel_cols, + int* rows_rectangularness_log2, + int* cols_rectangularness_log2) { + *rows_rectangularness_log2 = 0; + *cols_rectangularness_log2 = 0; + + // In GEMV-ish cases, that is when kernel blocks are as narrow as the kernel + // itself, we risk having too small kernel blocks for good kernel + // amortization. We avoid that by limiting recangularness so that kernel + // blocks are not too tiny at least in that dimension. Specifically, we try to + // have at least (2^min_kernel_inner_loop_runs_log2) kernels fitting in each + // kernel block along the large dimension. 
+ const int min_kernel_inner_loop_runs_log2 = 3; + if (rows > cols) { + int cols_of_kernel_inner_loop_runs_log2 = + ceil_log2(cols) - pot_log2(kernel_cols); + int min_rows_of_kernel_inner_loop_runs_log2 = + std::max(0, min_kernel_inner_loop_runs_log2 - + cols_of_kernel_inner_loop_runs_log2); + *rows_rectangularness_log2 = + std::min(floor_log2_quotient(rows, cols), + std::max(0, floor_log2(rows) - pot_log2(kernel_rows) - + min_rows_of_kernel_inner_loop_runs_log2)); + // Sanity check that we did not over-estimate rows_rectangularness_log2. + RUY_DCHECK_GE(rows >> *rows_rectangularness_log2, cols); + } else if (cols > rows) { + int rows_of_kernel_inner_loop_runs_log2 = + ceil_log2(rows) - pot_log2(kernel_rows); + int min_cols_of_kernel_inner_loop_runs_log2 = + std::max(0, min_kernel_inner_loop_runs_log2 - + rows_of_kernel_inner_loop_runs_log2); + *cols_rectangularness_log2 = + std::min(floor_log2_quotient(cols, rows), + std::max(0, floor_log2(cols) - pot_log2(kernel_cols) - + min_cols_of_kernel_inner_loop_runs_log2)); + // Sanity check that we did not over-estimate cols_rectangularness_log2. + RUY_DCHECK_GE(cols >> *cols_rectangularness_log2, rows); + } + RUY_DCHECK(!*rows_rectangularness_log2 || !*cols_rectangularness_log2); +} + +// Computes a 'multithreading score'. When multithreading, we need there to +// be at least as many tiles as there are threads, and hopefully +// substantially more than that, so we benefit from ruy's ability to +// dispatch fine-grained workloads to threads. +int GetMultithreadingScore(int block_size_log2, int rows, int cols, + int tentative_thread_count) { + const int num_full_blocks_of_rows = rows >> block_size_log2; + const int num_full_blocks_of_cols = cols >> block_size_log2; + const int candidate_num_full_blocks_log2 = floor_log2( + std::max(1, num_full_blocks_of_rows * num_full_blocks_of_cols)); + + // The values here have been tuned on ARM Cortex-A55. + // We expect this to have to be tuned differently for other CPUs. + if (tentative_thread_count == 1) { + return 0; + } else { + const int blocks_per_thread_log2 = + candidate_num_full_blocks_log2 - ceil_log2(tentative_thread_count); + if (blocks_per_thread_log2 < 0) { + return -64; + } else if (blocks_per_thread_log2 == 0) { + return -16; + } else if (blocks_per_thread_log2 == 1) { + return -8; + } else if (blocks_per_thread_log2 == 2) { + return 0; + } else if (blocks_per_thread_log2 == 3) { + return 8; + } else { + return 16; + } + } +} + +// Computes a 'cache locality score'. +int GetCacheLocalityScore(int block_size_log2, int rows, int cols, int depth, + int kernel_rows_log2, int kernel_cols_log2, + int lhs_scalar_size, int rhs_scalar_size, Path path, + int local_data_cache_size) { + // In the narrow case (e.g. matrix*vector), each byte of the big operand + // matrix (either LHS or RHS) is traversed only once, so any notion of data + // locality is irrelevant. Ignore the 'cache locality score' by forcing it to + // be 0 in that case. + if (rows <= (1 << kernel_rows_log2) || cols <= (1 << kernel_cols_log2)) { + return 0; + } + const int block_rows = std::min(1 << block_size_log2, rows); + const int block_cols = std::min(1 << block_size_log2, cols); + const int total_read_bytes = + (lhs_scalar_size * block_rows + rhs_scalar_size * block_cols) * depth; + const int total_read_bytes_log2 = ceil_log2(total_read_bytes); + const int nonlocality_log2 = + total_read_bytes_log2 - floor_log2(local_data_cache_size); + // The values here have been tuned on ARM Cortex-A55. 
+ // We expect this to have to be tuned differently for other CPUs. + if (nonlocality_log2 < -1) { + return 64; + } else if (nonlocality_log2 == -1) { + return 56; + } else if (nonlocality_log2 == 0) { + return 48; + } else if (nonlocality_log2 == 1) { + return 32; + } else if (nonlocality_log2 == 2) { + return 16; + } else if (nonlocality_log2 == 3) { + return 0; + } else { + return -64; + } +} + +// Compute a 'kernel amortization score'. This is the notion that very small +// tiles result in more overhead outside of kernels, more complex memory +// access patterns and less benefits from ruy's fat kernels, so we reward +// larger blocks more than smaller ones. +int GetKernelAmortizationScore(int block_size_log2, int rows, int cols, + int kernel_rows_log2, int kernel_cols_log2) { + const int block_rows = std::min(1 << block_size_log2, rows); + const int block_cols = std::min(1 << block_size_log2, cols); + const int kernels_per_block_log2 = + floor_log2(block_rows * block_cols) - kernel_rows_log2 - kernel_cols_log2; + RUY_DCHECK_GE(kernels_per_block_log2, 0); + // The values here have been tuned on ARM Cortex-A55. + // We expect this to have to be tuned differently for other CPUs. + if (kernels_per_block_log2 == 0) { + return 0; + } else if (kernels_per_block_log2 == 1) { + return 8; + } else if (kernels_per_block_log2 == 2) { + return 16; + } else if (kernels_per_block_log2 == 3) { + return 24; + } else if (kernels_per_block_log2 == 4) { + return 32; + } else if (kernels_per_block_log2 == 5) { + return 40; + } else if (kernels_per_block_log2 == 6) { + return 48; + } else if (kernels_per_block_log2 == 7) { + return 56; + } else { + return 64; + } +} + +} // namespace + +void MakeBlockMap(int rows, int cols, int depth, int kernel_rows, + int kernel_cols, int lhs_scalar_size, int rhs_scalar_size, + int tentative_thread_count, Path path, + int local_data_cache_size, int shared_data_cache_size, + BlockMap* block_map) { + profiler::ScopeLabel label("MakeBlockMap"); + +#ifdef RUY_MAKEBLOCKMAP_DEBUG +#if RUY_MAKEBLOCKMAP_DEBUG >= 2 + static constexpr bool debug_everytime = true; +#else + static constexpr bool debug_everytime = false; +#endif + static bool firsttime = true; + if (firsttime || debug_everytime) { + fprintf(stderr, + "MakeBlockMap(rows=%d, cols=%d, depth=%d, kernel_rows=%d, " + "kernel_cols=%d, lhs_scalar_size=%d, rhs_scalar_size=%d, " + "tentative_thread_count=%d)\n", + rows, cols, depth, kernel_rows, kernel_cols, lhs_scalar_size, + rhs_scalar_size, tentative_thread_count); + } +#endif + + RUY_DCHECK_GE(rows, kernel_rows); + RUY_DCHECK_GE(cols, kernel_cols); + RUY_DCHECK_EQ(rows % kernel_rows, 0); + RUY_DCHECK_EQ(cols % kernel_cols, 0); + + block_map->traversal_order = + GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size, + local_data_cache_size, shared_data_cache_size); + + int rows_rectangularness_log2 = 0; + int cols_rectangularness_log2 = 0; + GetRectangularness(rows, cols, kernel_rows, kernel_cols, + &rows_rectangularness_log2, &cols_rectangularness_log2); + + const int kernel_rows_log2 = pot_log2(kernel_rows); + const int kernel_cols_log2 = pot_log2(kernel_cols); + const int kernel_size_log2 = std::max(kernel_cols_log2, kernel_rows_log2); + + const int size = std::min(rows, cols); + const int size_log2 = std::max(kernel_size_log2, floor_log2(size)); + + RUY_DCHECK_GE(size_log2, kernel_size_log2); + + // We are going to try candidate values for block_size_log2 ranging from + // kernel_size_log2 to (kernel_size_log2 + kMaxKernelsPerBlockLog2). 
+ // For each of them we will compute a 'score' by adding individual scores + // for a few different considerations, all of which is entirely empirical. + // The values (and possibly the logic) around here are all subject to tuning + // based on benchmarks on different hardware. The current values are based + // on benchmarking on Qualcomm S855 (big and little cores), arm64, + // kNeonDotprod, 8bit quantized path. Don't read too much into it, go ahead + // and tune this as needed to achieve good performance elsewhere. Use + // the unit test, block_map_test, to encode values that should be preserved + // on specific architectures. Use RUY_MAKEBLOCKMAP_DEBUG to help tuning this. + static constexpr int kMaxKernelsPerBlockLog2 = 6; + const int max_block_size_log2 = + std::min(size_log2, kernel_size_log2 + kMaxKernelsPerBlockLog2); + int best_score = std::numeric_limits::min(); + int best_score_block_size_log2 = -1; + for (int block_size_log2 = kernel_size_log2; + block_size_log2 <= max_block_size_log2; block_size_log2++) { + const int multithreading_score = GetMultithreadingScore( + block_size_log2, rows, cols, tentative_thread_count); + const int cache_locality_score = GetCacheLocalityScore( + block_size_log2, rows, cols, depth, kernel_rows_log2, kernel_cols_log2, + lhs_scalar_size, rhs_scalar_size, path, local_data_cache_size); + const int kernel_amortization_score = GetKernelAmortizationScore( + block_size_log2, rows, cols, kernel_rows_log2, kernel_cols_log2); + const int score = + multithreading_score + cache_locality_score + kernel_amortization_score; +#ifdef RUY_MAKEBLOCKMAP_DEBUG + if (firsttime || debug_everytime) { + fprintf(stderr, + "block_size_log2=%d: score=%d multithreading_score=%d " + "cache_locality_score=%d kernel_amortization_score=%d\n", + block_size_log2, score, multithreading_score, + cache_locality_score, kernel_amortization_score); + } +#endif + if (score >= best_score) { + best_score = score; + best_score_block_size_log2 = block_size_log2; + } + } + +#ifdef RUY_MAKEBLOCKMAP_DEBUG + if (firsttime || debug_everytime) { + fprintf(stderr, "best_score_block_size_log2=%d\n", + best_score_block_size_log2); + } + + static const char* explicit_block_size_log2_env = + getenv("RUY_MAKEBLOCKMAP_EXPLICIT_BLOCK_SIZE_LOG2"); + if (explicit_block_size_log2_env) { + best_score_block_size_log2 = std::stoi(explicit_block_size_log2_env); + if (firsttime || debug_everytime) { + fprintf(stderr, "Overridden best_score_block_size_log2=%d\n", + best_score_block_size_log2); + } + } + firsttime = false; +#endif + + int num_blocks_base_log2 = size_log2 - best_score_block_size_log2; + RUY_DCHECK_GE(num_blocks_base_log2, 0); + + const int num_blocks_of_rows_log2 = + num_blocks_base_log2 + rows_rectangularness_log2; + const int num_blocks_of_cols_log2 = + num_blocks_base_log2 + cols_rectangularness_log2; + + const int smallr = + round_down_pot(rows >> num_blocks_of_rows_log2, kernel_rows); + const int smallc = + round_down_pot(cols >> num_blocks_of_cols_log2, kernel_cols); + const int missr = + round_up_pot(rows - (smallr << num_blocks_of_rows_log2), kernel_rows) >> + pot_log2(kernel_rows); + const int missc = + round_up_pot(cols - (smallc << num_blocks_of_cols_log2), kernel_cols) >> + pot_log2(kernel_cols); + + block_map->dims[Side::kLhs] = rows; + block_map->dims[Side::kRhs] = cols; + block_map->kernel_dims[Side::kLhs] = kernel_rows; + block_map->kernel_dims[Side::kRhs] = kernel_cols; + block_map->num_blocks_base_log2 = num_blocks_base_log2; + block_map->rectangularness_log2[Side::kLhs] = 
rows_rectangularness_log2; + block_map->rectangularness_log2[Side::kRhs] = cols_rectangularness_log2; + block_map->small_block_dims[Side::kLhs] = smallr; + block_map->small_block_dims[Side::kRhs] = smallc; + block_map->large_blocks[Side::kLhs] = missr; + block_map->large_blocks[Side::kRhs] = missc; + // Done last: NumBlocks needs some of the block_map fields to be already set. + block_map->thread_count = + std::min(tentative_thread_count, NumBlocks(*block_map)); +} + +void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block, + int* start, int* end) { + profiler::ScopeLabel label("GetBlockMatrixCoords"); + *start = block * block_map.small_block_dims[side] + + std::min(block, block_map.large_blocks[side]) * + block_map.kernel_dims[side]; + *end = + *start + block_map.small_block_dims[side] + + (block < block_map.large_blocks[side] ? block_map.kernel_dims[side] : 0); + + RUY_DCHECK_EQ(0, *start % block_map.kernel_dims[side]); + RUY_DCHECK_EQ(0, *end % block_map.kernel_dims[side]); + RUY_DCHECK_LE(*end, block_map.dims[side]); + RUY_DCHECK_LT(*start, *end); + RUY_DCHECK_GE(*start, 0); +} + +void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair& block, + SidePair* start, SidePair* end) { + for (Side side : {Side::kLhs, Side::kRhs}) { + GetBlockMatrixCoords(side, block_map, block[side], &(*start)[side], + &(*end)[side]); + } +} + +} // namespace ruy diff --git a/ruy/block_map.h b/ruy/block_map.h new file mode 100644 index 0000000..5e1cee0 --- /dev/null +++ b/ruy/block_map.h @@ -0,0 +1,161 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_ + +#include "ruy/path.h" +#include "ruy/side_pair.h" + +namespace ruy { + +enum class BlockMapTraversalOrder { + // Plain old row-by-row or column-by-column traversal. + kLinear, + // Fractal Z-order curve, https://en.wikipedia.org/wiki/Z-order_curve + kFractalZ, + // Variant of Z-order doing a U instead of a Z. + kFractalU, + // Hilbert curve, https://en.wikipedia.org/wiki/Hilbert_curve + kFractalHilbert +}; + +// A BlockMap describes a tiling of a matrix, typically the destination matrix +// of a matrix multiplication computation. As is standard in matrix +// multiplication, a tile is called a "block". +// +// Ruy subdivides work by blocks of the destination matrix: each thread fully +// computes a block at once, then moves on to another block; each block is +// produced by a single thread. +// +// This ensures that the workloads for each block are mutually independent, +// which reduces synchronization requirements. +// +// Typically, a matrix multiplication will early on create a BlockMap by +// calling MakeBlockMap. It will then query the number of blocks in that +// BlockMap by calling NumBlocks. 
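+// For illustration only, the resulting usage pattern is roughly:
+//   BlockMap block_map;
+//   MakeBlockMap(rows, cols, depth, kernel_rows, kernel_cols,
+//                lhs_scalar_size, rhs_scalar_size, tentative_thread_count,
+//                path, local_data_cache_size, shared_data_cache_size,
+//                &block_map);
+//   for (int index = 0; index < NumBlocks(block_map); index++) {
+//     SidePair<int> block, start, end;
+//     GetBlockByIndex(block_map, index, &block);
+//     GetBlockMatrixCoords(block_map, block, &start, &end);
+//     // ... compute the destination block spanning the row range
+//     // [start[Side::kLhs], end[Side::kLhs]) and the column range
+//     // [start[Side::kRhs], end[Side::kRhs]) ...
+//   }
+// except that, as described next, the index values are distributed across
+// threads instead of being iterated by a single loop.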
It will then create a single atomic integer +// counter indexing these blocks, called the 'index', and will distribute +// work to its N threads by ensuring that each thread works on disjoint sets +// of index values. For a given index value, the thread will call +// GetBlockByIndex to get the corresponding block, then GetBlockMatrixCoords +// to find the actual row and column numbers of this block. +// +// There are two nested levels of subdivision. On a local level, the matrix is +// tiled into a square NxN grid where N is a power of two, specifically: +// N = 2^num_blocks_base_log2. +// +// At a larger scale, around these blocks, there may be one further +// level of subdivision, in only one dimension: either along rows or along +// columns. That is used to handle arbitrarily rectangular matrices. The +// aforementioned high-level block grid is square, so it does not readily fit +// well very rectangular matrices. +// +// Taking together these two nested levels of subdivision, the effective +// tiling is by +// 2^(num_blocks_base_log2 + rows_rectangularness_log2) +// blocks in the row dimension, and by +// 2^(num_blocks_base_log2 + cols_rectangularness_log2) +// blocks in the column dimension. See NumBlocksOfRows, NumBlocksOfCols. +// +// Either rows_rectangularness_log2 or cols_rectangularness_log2 must be zero. +// +// Finally, this BlockMap is designed to operate under alignment constraints: +// two fields, kernel_rows and kernel_cols, describe the requested alignment +// of the effective grid in both dimensions. The idea is to feed matrix +// multiplication kernels with tiles that fit their width as much as possible. +// Of course, if rows (resp. cols) is not a multiple of kernel_rows (resp. +// kernel_cols) then some tile will have to have unaligned size. BlockMap +// will only allow that to happen in the last position along each axis, so +// as to minimize the overhead incurred onto the matrix multiplication kernels. +struct BlockMap { + // The number of threads to use (to distribute the blocks to). + int thread_count; + // The order in which to traverse the matrix of which this BlockMap represents + // a tiling (hereafter "the matrix"). + BlockMapTraversalOrder traversal_order; + // The dimensions of the block_map, that is, of the destination + // matrix rounded up to next multiples of kernel_dims. + SidePair dims; + // Log2 of the minimum number of subdivisions of the grid along either axis. + int num_blocks_base_log2; + // Log2 of the additional subdivision of the rows/columns axis. + SidePair rectangularness_log2; + // Requested alignment of the subdivisions of the grid along the rows/columns + // axis. + SidePair kernel_dims; + // Internal helper. Minimum number of rows/columns in each block. + SidePair small_block_dims; + // Internal helper. Number of blocks along each dimension that need to have + // their size in that dimension be given by (small_block_dims + kernel_dims) + // instead of just small_block_dims. + SidePair large_blocks; +}; + +// Returns the traversal order to be used for the given matrix multiplication +// parameters. +BlockMapTraversalOrder GetTraversalOrder(int rows, int cols, int depth, + int lhs_scalar_size, + int rhs_scalar_size, + int local_data_cache_size, + int shared_data_cache_size); + +// Create a BlockMap suitable for tiling the destination matrix in a +// matrix multiplication with the given parameters. 
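+// Note that tentative_thread_count is only an upper bound: the resulting
+// thread_count field is the minimum of tentative_thread_count and the number
+// of blocks, so small matrices may end up using fewer threads than requested.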
+void MakeBlockMap(int rows, int cols, int depth, int kernel_rows, + int kernel_cols, int lhs_scalar_size, int rhs_scalar_size, + int tentative_thread_count, Path path, + int local_data_cache_size, int shared_data_cache_size, + BlockMap* block_map); + +// Maps an integer index to a block position in the grid. +void GetBlockByIndex(const BlockMap& block_map, int index, + SidePair* block); + +// Given a block position in the grid, returns its actual +// position in the matrix that the BlockMap refers to in the dimension +// referred to by `side`: along rows if side==kLhs, along columns if +// side==kRhs. +void GetBlockMatrixCoords(Side side, const BlockMap& block_map, int block, + int* start, int* end); + +// Given a block position in the grid, returns its actual +// position in the matrix that the BlockMap refers to in terms of +// actual row/column indices. +void GetBlockMatrixCoords(const BlockMap& block_map, const SidePair& block, + SidePair* start, SidePair* end); + +// Returns the number of grid subdivisions along the rows dimension (if +// side == kLhs) or columns dimension (if side == kRhs). +inline int NumBlocksPerSide(Side side, const BlockMap& block_map) { + return 1 << (block_map.num_blocks_base_log2 + + block_map.rectangularness_log2[side]); +} + +// Returns the overall number of blocks in +// the BlockMap. The valid index values to pass to GetBlockByIndex are the +// integers from 0 to N-1 where N is the value returned here. +// +// Note that it is always true that +// NumBlocks == NumBlocksOfRows * NumBlocksOfCols +// because either rows_rectangularness_log2 or cols_rectangularness_log2 is 0. +inline int NumBlocks(const BlockMap& block_map) { + return 1 << (2 * block_map.num_blocks_base_log2 + + block_map.rectangularness_log2[Side::kLhs] + + block_map.rectangularness_log2[Side::kRhs]); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCK_MAP_H_ diff --git a/ruy/block_map_test.cc b/ruy/block_map_test.cc new file mode 100644 index 0000000..24646cf --- /dev/null +++ b/ruy/block_map_test.cc @@ -0,0 +1,263 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/block_map.h" + +#include +#include +#include +#include +#include + +#include "testing/base/public/gunit.h" +#include "ruy/cpu_cache_size.h" +#include "ruy/path.h" +#include "ruy/side_pair.h" + +namespace ruy { +namespace { + +#if RUY_PLATFORM(NEON_64) + +// Unless otherwise specified, these tests have been tuned on ARM Cortex-A55. 
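+// The helper below runs MakeBlockMap with the cache sizes of the given path
+// and checks the resulting num_blocks_base_log2 and rectangularness_log2
+// fields against expected values.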
+void MakeBlockMapTuningTest(int rows, int cols, int depth, int kernel_rows, + int kernel_cols, int lhs_scalar_size, + int rhs_scalar_size, int tentative_thread_count, + Path path, int expected_num_blocks_base_log2, + int expected_rectangularness_log2) { + BlockMap block_map; + MakeBlockMap(rows, cols, depth, kernel_rows, kernel_cols, lhs_scalar_size, + rhs_scalar_size, tentative_thread_count, path, + LocalDataCacheSize(path), SharedDataCacheSize(path), &block_map); + EXPECT_EQ(block_map.num_blocks_base_log2, expected_num_blocks_base_log2); + EXPECT_EQ(std::min(block_map.rectangularness_log2[Side::kLhs], + block_map.rectangularness_log2[Side::kRhs]), + 0); + EXPECT_EQ(std::max(block_map.rectangularness_log2[Side::kLhs], + block_map.rectangularness_log2[Side::kRhs]), + expected_rectangularness_log2); +} + +TEST(BlockMapTest, MakeBlockMapTuningTest8bitCubicShapesOneThreadNeonDotprod) { + MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 1, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 1, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 1, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 1, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1, + /* tentative_thread_count */ 1, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1, + /* tentative_thread_count */ 1, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1, + /* tentative_thread_count */ 1, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1, + /* tentative_thread_count */ 1, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); +} + +TEST(BlockMapTest, + MakeBlockMapTuningTest8bitCubicShapesFourThreadsNeonDotprod) { + MakeBlockMapTuningTest(32, 32, 32, 8, 8, 1, 1, /* tentative_thread_count */ 4, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(48, 48, 48, 8, 8, 1, 1, /* tentative_thread_count */ 4, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(64, 64, 64, 8, 8, 1, 1, /* tentative_thread_count */ 4, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(96, 96, 96, 8, 8, 1, 1, /* tentative_thread_count */ 4, + Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(128, 128, 128, 8, 8, 1, 1, + /* tentative_thread_count */ 4, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(192, 192, 192, 8, 8, 1, 1, + /* tentative_thread_count */ 4, Path::kNeonDotprod, + /* 
expected_num_blocks_base_log2 */ 1, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(256, 256, 256, 8, 8, 1, 1, + /* tentative_thread_count */ 4, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 2, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(384, 384, 384, 8, 8, 1, 1, + /* tentative_thread_count */ 4, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 2, + /* expected_rectangularness_log2 */ 0); +} + +TEST(BlockMapTest, MakeBlockMapTuningTest32bit) { + MakeBlockMapTuningTest(256, 256, 256, 8, 8, 4, 4, + /* tentative_thread_count */ 4, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 3, + /* expected_rectangularness_log2 */ 0); + MakeBlockMapTuningTest(4096, 4096, 4096, 8, 8, 4, 4, + /* tentative_thread_count */ 4, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 7, + /* expected_rectangularness_log2 */ 0); +} + +TEST(BlockMapTest, MakeBlockMapTuningTestRectangular) { + MakeBlockMapTuningTest(256, 16, 256, 8, 8, 1, 1, + /* tentative_thread_count */ 1, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 3); + MakeBlockMapTuningTest(24, 2400, 256, 8, 8, 1, 1, + /* tentative_thread_count */ 1, Path::kNeonDotprod, + /* expected_num_blocks_base_log2 */ 0, + /* expected_rectangularness_log2 */ 6); +} + +#endif + +int L1Distance(const SidePair& a, const SidePair& b) { + return std::abs(a[Side::kLhs] - b[Side::kLhs]) + + std::abs(a[Side::kRhs] - b[Side::kRhs]); +} + +void GetBlockByIndexSquareTest(int num_blocks_base_log2, + BlockMapTraversalOrder traversal_order) { + // Arbitrary, does not affect this test. 3 is just a typical value. + constexpr int kKernelSizeLog2 = 3; + + const int size_log2 = num_blocks_base_log2 + kKernelSizeLog2; + BlockMap block_map; + block_map.thread_count = 1; + block_map.traversal_order = traversal_order; + block_map.num_blocks_base_log2 = num_blocks_base_log2; + for (Side side : {Side::kLhs, Side::kRhs}) { + block_map.dims[side] = 1 << size_log2; + block_map.rectangularness_log2[side] = 0; + block_map.kernel_dims[side] = 1 << kKernelSizeLog2; + block_map.small_block_dims[side] = block_map.kernel_dims[side]; + block_map.large_blocks[side] = 0; + } + + const int num_blocks_per_side = 1 << num_blocks_base_log2; + const int num_blocks = num_blocks_per_side * num_blocks_per_side; + EXPECT_EQ(num_blocks, NumBlocks(block_map)); + + // Perform a full traversal of all blocks, as if computing a whole matrix + // multiplication. + // + // Used to record how many times each block was hit by the traversal. + std::vector block_hit_counts(num_blocks); + // Here we guard an assumption that all traversal orders start at (0, 0). + SidePair previous_block_coords(0, 0); + // Sum of L1 norm of the coordinate change at every step of the traversal. + std::int64_t total_l1_distance = 0; + // Number of jumps i.e. traversal steps with a L1 norm greater than 1. + int discontinuity_count = 0; + for (int block_index = 0; block_index < num_blocks; block_index++) { + SidePair block_coords; + GetBlockByIndex(block_map, block_index, &block_coords); + ++block_hit_counts[block_coords[Side::kLhs] + + num_blocks_per_side * block_coords[Side::kRhs]]; + int distance = L1Distance(block_coords, previous_block_coords); + total_l1_distance += distance; + discontinuity_count += (distance > 1); + previous_block_coords = block_coords; + } + + // Verify that each block was traversed exactly once. 
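+  // (For reference, regarding the distance checks further below: a linear
+  // traversal of an NxN grid walks each of the N lines of blocks
+  // contiguously, and each of the N-1 jumps from the end of one line to the
+  // start of the next has an L1 distance of N, hence a total L1 distance of
+  // N*(N-1) + N*(N-1) = 2*N*(N-1).)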
+ for (int l = 0; l < num_blocks_per_side; l++) { + for (int r = 0; r < num_blocks_per_side; r++) { + EXPECT_EQ(block_hit_counts[l + num_blocks_per_side * r], 1); + } + } + + // Verify that the discontinuity_count and total_l1_distance are as expected + // for the given traversal_order. + switch (traversal_order) { + case BlockMapTraversalOrder::kFractalHilbert: + // No discontinuity at all with this space-filling continuous curve! + EXPECT_EQ(discontinuity_count, 0); + // Therefore, total_l1_distance has to be the number of blocks minus one. + EXPECT_EQ(total_l1_distance, num_blocks - 1); + break; + case BlockMapTraversalOrder::kLinear: + EXPECT_EQ(discontinuity_count, num_blocks_per_side - 1); + EXPECT_EQ(total_l1_distance, + 2 * num_blocks_per_side * (num_blocks_per_side - 1)); + break; + case BlockMapTraversalOrder::kFractalZ: + EXPECT_EQ(discontinuity_count, num_blocks > 1 ? (num_blocks / 2 - 1) : 0); + EXPECT_EQ(total_l1_distance, + 2 * num_blocks_per_side * (num_blocks_per_side - 1)); + break; + case BlockMapTraversalOrder::kFractalU: { + if (num_blocks_base_log2 == 0) { + EXPECT_EQ(discontinuity_count, 0); + EXPECT_EQ(total_l1_distance, 0); + } else { + int expected_discontinuity_count = 0; + int expected_total_l1_distance = 3; + for (int i = 2; i <= num_blocks_base_log2; i++) { + expected_discontinuity_count = 4 * expected_discontinuity_count + 2; + expected_total_l1_distance = + 4 * expected_total_l1_distance + (1 << (i + 1)) - 1; + } + EXPECT_EQ(discontinuity_count, expected_discontinuity_count); + EXPECT_EQ(total_l1_distance, expected_total_l1_distance); + } + break; + } + default: + abort(); + } +} + +TEST(BlockMapTest, GetBlockByIndexSquare) { + for (int num_blocks_base_log2 = 0; num_blocks_base_log2 <= 10; + num_blocks_base_log2++) { + for (BlockMapTraversalOrder traversal_order : + {BlockMapTraversalOrder::kLinear, BlockMapTraversalOrder::kFractalZ, + BlockMapTraversalOrder::kFractalU, + BlockMapTraversalOrder::kFractalHilbert}) { + GetBlockByIndexSquareTest(num_blocks_base_log2, traversal_order); + } + } +} + +} // namespace +} // namespace ruy + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/blocking_counter.cc b/ruy/blocking_counter.cc new file mode 100644 index 0000000..ffa7ac0 --- /dev/null +++ b/ruy/blocking_counter.cc @@ -0,0 +1,49 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/blocking_counter.h" + +#include "ruy/check_macros.h" +#include "ruy/wait.h" + +namespace ruy { + +void BlockingCounter::Reset(int initial_count) { + int old_count_value = count_.load(std::memory_order_relaxed); + RUY_DCHECK_EQ(old_count_value, 0); + (void)old_count_value; + count_.store(initial_count, std::memory_order_release); +} + +bool BlockingCounter::DecrementCount() { + int old_count_value = count_.fetch_sub(1, std::memory_order_acq_rel); + RUY_DCHECK_GT(old_count_value, 0); + int count_value = old_count_value - 1; + bool hit_zero = (count_value == 0); + if (hit_zero) { + std::lock_guard lock(count_mutex_); + count_cond_.notify_all(); + } + return hit_zero; +} + +void BlockingCounter::Wait() { + const auto& condition = [this]() { + return count_.load(std::memory_order_acquire) == 0; + }; + ruy::Wait(condition, &count_cond_, &count_mutex_); +} + +} // namespace ruy diff --git a/ruy/blocking_counter.h b/ruy/blocking_counter.h new file mode 100644 index 0000000..878f0e7 --- /dev/null +++ b/ruy/blocking_counter.h @@ -0,0 +1,62 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_ + +#include +#include // NOLINT(build/c++11) // IWYU pragma: keep +#include // NOLINT(build/c++11) // IWYU pragma: keep + +namespace ruy { + +// A BlockingCounter lets one thread to wait for N events to occur. +// This is how the master thread waits for all the worker threads +// to have finished working. +// The waiting is done using a naive spinlock waiting for the atomic +// count_ to hit the value 0. This is acceptable because in our usage +// pattern, BlockingCounter is used only to synchronize threads after +// short-lived tasks (performing parts of the same GEMM). It is not used +// for synchronizing longer waits (resuming work on the next GEMM). +class BlockingCounter { + public: + BlockingCounter() : count_(0) {} + + // Sets/resets the counter; initial_count is the number of + // decrementing events that the Wait() call will be waiting for. + void Reset(int initial_count); + + // Decrements the counter; if the counter hits zero, signals + // the threads that were waiting for that, and returns true. + // Otherwise (if the decremented count is still nonzero), + // returns false. + bool DecrementCount(); + + // Waits for the N other threads (N having been set by Reset()) + // to hit the BlockingCounter. + void Wait(); + + private: + std::atomic count_; + + // The condition variable and mutex allowing to passively wait for count_ + // to reach the value zero, in the case of longer waits. 
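+  // (DecrementCount() notifies through them once count_ reaches zero, and
+  // Wait() hands them to ruy::Wait together with the condition on count_.)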
+ std::condition_variable count_cond_; + std::mutex count_mutex_; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_BLOCKING_COUNTER_H_ diff --git a/ruy/build_defs.bzl b/ruy/build_defs.bzl new file mode 100644 index 0000000..964ede3 --- /dev/null +++ b/ruy/build_defs.bzl @@ -0,0 +1,54 @@ +"""Build definitions for Ruy. + +In some cases these are used to configure specific targets for +specific platforms, and dispatch is based on runtime capability detection. +""" + +# 1. Enable -mfpu=neon unconditionally on ARM32. If it turns out that we need to support +# ARM32 without NEON then we'll implement runtime detection and dispatch at that point. +# 2. Explicitly pass -O3 on optimization configs where just "-c opt" means "optimize for code size". + +def ruy_copts_base(): + return select({ + ":armeabi-v7a": [ + "-mfpu=neon", + ], + "//conditions:default": [], + }) + select({ + ":optimized": ["-O3"], + "//conditions:default": [], + }) + +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. +def ruy_copts_skylake(): + return select({ + ":x86_64": ["-march=skylake-avx512"], + "//conditions:default": [], + }) + +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. +def ruy_copts_avx2(): + return select({ + ":x86_64": ["-mavx2", "-mfma"], + "//conditions:default": [], + }) + +# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +# Optimization is not finished. In particular the dimensions of the kernel +# blocks can be changed as desired. +# +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. +def ruy_copts_sse42(): + return [] + +# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +# Optimization is not finished. In particular the dimensions of the kernel +# blocks can be changed as desired. +# +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. +def ruy_copts_avxvnni(): + return select({ + # TODO(b/146494398): Reinstate flag, something like "-march=cascadelake". + ":x86_64": [], + "//conditions:default": [], + }) diff --git a/ruy/build_defs.bzl.opensource b/ruy/build_defs.bzl.opensource new file mode 100644 index 0000000..9bccccf --- /dev/null +++ b/ruy/build_defs.bzl.opensource @@ -0,0 +1,40 @@ +"""Build definitions for Ruy.""" + +# 1. Enable -mfpu=neon unconditionally on ARM32. If it turns out that we need to support +# ARM32 without NEON then we'll implement runtime detection and dispatch at that point. +# 2. Explicitly pass -O3 on optimization configs where just "-c opt" means "optimize for code size". + +def ruy_copts_base(): + return select({ + ":armeabi-v7a": [ + "-mfpu=neon", + ], + "//conditions:default": [], + }) + select({ + ":optimized": ["-O3"], + "//conditions:default": [], + }) + +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. +def ruy_copts_skylake(): + return [] + +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. +def ruy_copts_avx2(): + return [] + +# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +# Optimization is not finished. In particular the dimensions of the kernel +# blocks can be changed as desired. +# +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. 
+def ruy_copts_sse42(): + return [] + +# TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +# Optimization is not finished. In particular the dimensions of the kernel +# blocks can be changed as desired. +# +# Used for targets that are compiled with extra features that are skipped at runtime if unavailable. +def ruy_copts_avxvnni(): + return [] diff --git a/ruy/check_macros.h b/ruy/check_macros.h new file mode 100644 index 0000000..773f37d --- /dev/null +++ b/ruy/check_macros.h @@ -0,0 +1,138 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_ + +#include +#include +#include + +namespace ruy { +namespace check_macros { + +constexpr int kValueBufSize = 32; + +template +struct ToString { + static void Run(const T& value, char* buf) { + snprintf(buf, kValueBufSize, "(?)"); + } +}; + +template <> +struct ToString { + static void Run(float value, char* buf) { + snprintf(buf, kValueBufSize, "%.9g", static_cast(value)); + } +}; + +template <> +struct ToString { + static void Run(double value, char* buf) { + snprintf(buf, kValueBufSize, "%.16g", value); + } +}; + +template +struct ToString::value>::type> { + static void Run(const T& value, char* buf) { + snprintf(buf, kValueBufSize, "%lld", static_cast(value)); + } +}; + +template +struct ToString { + static void Run(T* value, char* buf) { + snprintf(buf, kValueBufSize, "%p", value); + } +}; + +template +struct ToString::value>::type> { + static void Run(const T& value, char* buf) { + snprintf(buf, kValueBufSize, "(enum value %d)", static_cast(value)); + } +}; + +inline void Failure(const char* file, int line, const char* macro, + const char* condition) { + fprintf(stderr, "%s:%d: %s condition not satisfied: %s\n", file, line, macro, + condition); + abort(); +} + +template +inline void Failure(const char* file, int line, const char* macro, + const char* lhs, const LhsType& lhs_value, const char* op, + const char* rhs, const RhsType& rhs_value) { + char lhs_value_buf[kValueBufSize]; + ToString::Run(lhs_value, lhs_value_buf); + char rhs_value_buf[kValueBufSize]; + ToString::Run(rhs_value, rhs_value_buf); + fprintf(stderr, + "%s:%d: %s condition not satisfied: [ %s %s %s ] with values [ " + "%s %s %s ].\n", + file, line, macro, lhs, op, rhs, lhs_value_buf, op, rhs_value_buf); + abort(); +} + +#define RUY_CHECK_IMPL(macro, condition) \ + do { \ + if (!(condition)) { \ + ruy::check_macros::Failure(__FILE__, __LINE__, #macro, #condition); \ + } \ + } while (false) + +#define RUY_CHECK_OP_IMPL(macro, lhs, op, rhs) \ + do { \ + const auto& lhs_value = (lhs); \ + const auto& rhs_value = (rhs); \ + if (!(lhs_value op rhs_value)) { \ + ruy::check_macros::Failure(__FILE__, __LINE__, #macro, #lhs, lhs_value, \ + #op, #rhs, rhs_value); \ + } \ + } while (false) + +#define RUY_CHECK(condition) 
RUY_CHECK_IMPL(RUY_CHECK, condition) +#define RUY_CHECK_EQ(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_EQ, x, ==, y) +#define RUY_CHECK_NE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_NE, x, !=, y) +#define RUY_CHECK_GE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_GE, x, >=, y) +#define RUY_CHECK_GT(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_GT, x, >, y) +#define RUY_CHECK_LE(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_LE, x, <=, y) +#define RUY_CHECK_LT(x, y) RUY_CHECK_OP_IMPL(RUY_CHECK_LT, x, <, y) + +#ifdef NDEBUG +#define RUY_DCHECK(condition) +#define RUY_DCHECK_EQ(x, y) +#define RUY_DCHECK_NE(x, y) +#define RUY_DCHECK_GE(x, y) +#define RUY_DCHECK_GT(x, y) +#define RUY_DCHECK_LE(x, y) +#define RUY_DCHECK_LT(x, y) +#else +#define RUY_DCHECK(condition) RUY_CHECK(condition) +#define RUY_DCHECK_EQ(x, y) RUY_CHECK_EQ(x, y) +#define RUY_DCHECK_NE(x, y) RUY_CHECK_NE(x, y) +#define RUY_DCHECK_GE(x, y) RUY_CHECK_GE(x, y) +#define RUY_DCHECK_GT(x, y) RUY_CHECK_GT(x, y) +#define RUY_DCHECK_LE(x, y) RUY_CHECK_LE(x, y) +#define RUY_DCHECK_LT(x, y) RUY_CHECK_LT(x, y) +#endif + +} // end namespace check_macros +} // end namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CHECK_MACROS_H_ diff --git a/ruy/check_macros_test.cc b/ruy/check_macros_test.cc new file mode 100644 index 0000000..7e47e7f --- /dev/null +++ b/ruy/check_macros_test.cc @@ -0,0 +1,153 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/check_macros.h" + +#include "testing/base/public/gunit.h" + +namespace { + +#define TEST_CONDITION_FOR_FAMILY(family, vacuously_succeeds, condition) \ + do { \ + if (vacuously_succeeds || (condition)) { \ + RUY_##family(condition); \ + } \ + } while (false) + +#define TEST_COMPARISON_FOR_FAMILY(family, vacuously_succeeds, op_name, x, op, \ + y) \ + do { \ + if (vacuously_succeeds || ((x)op(y))) { \ + RUY_##family##_##op_name(x, y); \ + } \ + } while (false) + +#ifdef NDEBUG +#define TEST_CONDITION(condition) \ + do { \ + TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \ + } while (false) +#define TEST_COMPARISON(op_name, x, op, y) \ + do { \ + TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \ + } while (false) +#else +#define TEST_CONDITION(condition) \ + do { \ + TEST_CONDITION_FOR_FAMILY(CHECK, false, condition); \ + TEST_CONDITION_FOR_FAMILY(DCHECK, false, condition); \ + } while (false) +#define TEST_COMPARISON(op_name, x, op, y) \ + do { \ + TEST_COMPARISON_FOR_FAMILY(CHECK, false, op_name, x, op, y); \ + TEST_COMPARISON_FOR_FAMILY(DCHECK, false, op_name, x, op, y); \ + } while (false) + +#endif + +template +void TestEqualityComparisons(const LhsType& lhs, const RhsType& rhs) { + RUY_CHECK_EQ(lhs, lhs); + TEST_COMPARISON(EQ, lhs, ==, lhs); + RUY_CHECK_EQ(lhs, lhs); + RUY_CHECK_EQ(lhs, lhs); + if (lhs == rhs) { + RUY_CHECK_EQ(lhs, rhs); + } + if (lhs != rhs) { + RUY_CHECK_NE(lhs, rhs); + } +} + +template +void TestComparisons(const LhsType& lhs, const RhsType& rhs) { + TestEqualityComparisons(lhs, rhs); + if (lhs > rhs) { + RUY_CHECK_GT(lhs, rhs); + } + if (lhs >= rhs) { + RUY_CHECK_GE(lhs, rhs); + } + if (lhs < rhs) { + RUY_CHECK_LT(lhs, rhs); + } + if (lhs <= rhs) { + RUY_CHECK_LE(lhs, rhs); + } +} + +TEST(CheckMacrosTest, IntInt) { + TestComparisons(0, 0); + TestComparisons(0, 1); + TestComparisons(1, -1); + TestComparisons(-1, 0); + TestComparisons(123, -456); + TestComparisons(std::numeric_limits::min(), + std::numeric_limits::max()); + TestComparisons(123, std::numeric_limits::max()); + TestComparisons(123, std::numeric_limits::min()); +} + +TEST(CheckMacrosTest, Uint8Uint8) { + TestComparisons(0, 0); + TestComparisons(255, 0); + TestComparisons(0, 255); + TestComparisons(12, 34); +} + +TEST(CheckMacrosTest, Uint8Int) { + TestComparisons(0, std::numeric_limits::min()); + TestComparisons(255, std::numeric_limits::min()); + TestComparisons(0, std::numeric_limits::max()); + TestComparisons(255, std::numeric_limits::max()); +} + +TEST(CheckMacrosTest, FloatFloat) { + TestComparisons(0.f, 0.f); + TestComparisons(0.f, 1.f); + TestComparisons(1.f, -1.f); + TestComparisons(-1.f, 0.f); + TestComparisons(123.f, -456.f); + TestComparisons(std::numeric_limits::lowest(), + std::numeric_limits::max()); + TestComparisons(123.f, std::numeric_limits::max()); + TestComparisons(123.f, std::numeric_limits::lowest()); +} + +TEST(CheckMacrosTest, IntFloat) { + TestComparisons(0, 0.f); + TestComparisons(0, 1.f); + TestComparisons(1, -1.f); + TestComparisons(-1, 0.f); + TestComparisons(123, -456.f); + TestComparisons(std::numeric_limits::lowest(), + std::numeric_limits::max()); + TestComparisons(123, std::numeric_limits::max()); + TestComparisons(123, std::numeric_limits::lowest()); +} + +TEST(CheckMacrosTest, EnumClass) { + enum class SomeEnumClass { kA, kB, kC }; + TestEqualityComparisons(SomeEnumClass::kA, SomeEnumClass::kA); + TestEqualityComparisons(SomeEnumClass::kA, 
SomeEnumClass::kB); + TestEqualityComparisons(SomeEnumClass::kC, SomeEnumClass::kB); +} + +} // namespace + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/common.h b/ruy/common.h new file mode 100644 index 0000000..1cd40fe --- /dev/null +++ b/ruy/common.h @@ -0,0 +1,73 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Miscellaneous helpers internal library. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/path.h" +#include "ruy/platform.h" + +#if RUY_OPT_ENABLED(RUY_OPT_PREFETCH_LOAD) +#define RUY_PREFETCH_LOAD(X) X +#else +#define RUY_PREFETCH_LOAD(X) +#endif + +#if RUY_OPT_ENABLED(RUY_OPT_PREFETCH_STORE) +#define RUY_PREFETCH_STORE(X) X +#else +#define RUY_PREFETCH_STORE(X) +#endif + +#define RUY_STR(s) RUY_STR_UNEXPANDED(s) +#define RUY_STR_UNEXPANDED(s) #s + +namespace ruy { + +// Helper for type-erasing a pointer. +// +// Often inside Ruy, a template parameter holds type information statically, but +// we would like to have a function signature that doesn't depend on the +// template parameters, so that we can dispatch indirectly across multiple +// implementations. This helper is at the core of such type-erasure. +// +// The opposite of this operation is just `static_cast(void_ptr)`. +template +void* ToVoidPtr(T* p) { + return const_cast(static_cast(p)); +} + +template +Scalar SymmetricZeroPoint() { + if (std::is_floating_point::value) { + return 0; + } + if (std::is_signed::value) { + return 0; + } + return std::numeric_limits::max() / 2 + 1; +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_COMMON_H_ diff --git a/ruy/context.cc b/ruy/context.cc new file mode 100644 index 0000000..1a70303 --- /dev/null +++ b/ruy/context.cc @@ -0,0 +1,109 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/context.h" + +#include "ruy/check_macros.h" +#include "ruy/detect_arm.h" +#include "ruy/detect_x86.h" +#include "ruy/have_built_path_for.h" +#include "ruy/platform.h" + +namespace ruy { + +void Context::SetRuntimeEnabledPaths(Path paths) { + runtime_enabled_paths_ = paths; +} + +Path Context::GetRuntimeEnabledPaths() { + // This function should always return the same value on a given machine. + // When runtime_enabled_paths_ has its initial value kNone, it performs + // some platform detection to resolve it to specific Path values. + + // Fast path: already resolved. + if (runtime_enabled_paths_ != Path::kNone) { + return runtime_enabled_paths_; + } + + // Need to resolve now. Start by considering all paths enabled. + runtime_enabled_paths_ = kAllPaths; + + // This mechanism is intended to be used for testing and benchmarking. For + // example, one can set RUY_FORCE_DISABLE_PATHS to Path::kAvx512 in order to + // evaluate AVX2 performance on an AVX-512 machine. +#ifdef RUY_FORCE_DISABLE_PATHS + runtime_enabled_paths_ = runtime_enabled_paths_ & ~(RUY_FORCE_DISABLE_PATHS); +#endif + +#if RUY_PLATFORM(ARM) + // Now selectively disable paths that aren't supported on this machine. + if ((runtime_enabled_paths_ & Path::kNeonDotprod) != Path::kNone) { + if (!DetectDotprod()) { + runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kNeonDotprod; + // Sanity check. + RUY_DCHECK((runtime_enabled_paths_ & Path::kNeonDotprod) == Path::kNone); + } + } +#endif // RUY_PLATFORM(ARM) + +#if RUY_PLATFORM(X86) + // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / + // placeholder. Optimization is not finished. In particular the dimensions of + // the kernel blocks can be changed as desired. + // + if ((runtime_enabled_paths_ & Path::kSse42) != Path::kNone) { + if (!(HaveBuiltPathForSse42() && DetectCpuSse42())) { + runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kSse42; + // Sanity check. + RUY_DCHECK((runtime_enabled_paths_ & Path::kSse42) == Path::kNone); + } + } + + if ((runtime_enabled_paths_ & Path::kAvx2) != Path::kNone) { + if (!(HaveBuiltPathForAvx2() && DetectCpuAvx2())) { + runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvx2; + // Sanity check. + RUY_DCHECK((runtime_enabled_paths_ & Path::kAvx2) == Path::kNone); + } + } + + if ((runtime_enabled_paths_ & Path::kAvx512) != Path::kNone) { + if (!(HaveBuiltPathForAvx512() && DetectCpuAvx512())) { + runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvx512; + // Sanity check. + RUY_DCHECK((runtime_enabled_paths_ & Path::kAvx512) == Path::kNone); + } + } + + // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / + // placeholder. Optimization is not finished. In particular the dimensions of + // the kernel blocks can be changed as desired. + // + if ((runtime_enabled_paths_ & Path::kAvxVnni) != Path::kNone) { + if (!(HaveBuiltPathForAvxVnni() && DetectCpuAvxVnni())) { + runtime_enabled_paths_ = runtime_enabled_paths_ & ~Path::kAvxVnni; + // Sanity check. + RUY_DCHECK((runtime_enabled_paths_ & Path::kAvxVnni) == Path::kNone); + } + } +#endif // RUY_PLATFORM(X86) + + // Sanity check. We can't possibly have disabled all paths, as some paths + // are universally available (kReference, kStandardCpp). 
+ RUY_DCHECK_NE(runtime_enabled_paths_, Path::kNone); + return runtime_enabled_paths_; +} + +} // namespace ruy diff --git a/ruy/context.h b/ruy/context.h new file mode 100644 index 0000000..330a7e7 --- /dev/null +++ b/ruy/context.h @@ -0,0 +1,109 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_ + +#include +#include +#include + +#include "ruy/allocator.h" +#include "ruy/path.h" +#include "ruy/prepacked_cache.h" +#include "ruy/thread_pool.h" +#include "ruy/trace.h" +#include "ruy/tune.h" + +namespace ruy { + +// The state private to each Ruy thread. +struct PerThreadState { + // Each thread may be running on a different microarchitecture. For example, + // some threads may be on big cores, while others are on little cores. Thus, + // it's best for the tuning to be per-thread. + TuningResolver tuning_resolver; + // Each thread has its own local allocator. + Allocator allocator; +}; + +// A Context holds runtime information used by Ruy. It holds runtime resources +// such as the workers thread pool and the allocator (which holds buffers for +// temporary data), as well as runtime options controlling which Paths are +// enabled (typically based on which instruction sets are detected) and how +// many threads to use. +struct Context final { + Path last_taken_path = Path::kNone; + Tuning explicit_tuning = Tuning::kAuto; + // TODO(benoitjacob) rename that thread_pool. Current name is gemmlowp legacy. + ThreadPool workers_pool; + int max_num_threads = 1; + // State for each thread in the thread pool. Entry 0 is the main thread. + std::vector> per_thread_states; + TracingContext tracing; + CachePolicy cache_policy = CachePolicy::kNoCache; + + Allocator* GetMainAllocator() { + if (!main_allocator_) { + main_allocator_.reset(new Allocator); + } + return main_allocator_.get(); + } + + PrepackedCache* GetPrepackedCache() { + if (!prepacked_cache_) { + prepacked_cache_.reset(new PrepackedCache); + } + return prepacked_cache_.get(); + } + + void ClearPrepackedCache() { prepacked_cache_ = nullptr; } + + void EnsureNPerThreadStates(int thread_count) { + while (per_thread_states.size() < static_cast(thread_count)) { + per_thread_states.emplace_back(new PerThreadState); + } + } + + Tuning GetMainThreadTuning() { + EnsureNPerThreadStates(1); + TuningResolver* tuning_resolver = &per_thread_states[0]->tuning_resolver; + tuning_resolver->SetTuning(explicit_tuning); + return tuning_resolver->Resolve(); + } + + template + Path GetPathToTake() { + last_taken_path = + GetMostSignificantPath(CompiledPaths & GetRuntimeEnabledPaths()); + return last_taken_path; + } + + void SetRuntimeEnabledPaths(Path paths); + Path GetRuntimeEnabledPaths(); + + private: + // Allocator for main thread work before invoking the threadpool. 
+ // Our simple Allocator does not allow reserving/allocating more blocks + // while it's already in committed state, so the main thread needs both + // this allocator, and its per-thread allocator. + std::unique_ptr main_allocator_; + std::unique_ptr prepacked_cache_; + Path runtime_enabled_paths_ = Path::kNone; +}; + +} // end namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CONTEXT_H_ diff --git a/ruy/context_test.cc b/ruy/context_test.cc new file mode 100644 index 0000000..c189030 --- /dev/null +++ b/ruy/context_test.cc @@ -0,0 +1,63 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/context.h" + +#include "testing/base/public/gunit.h" +#include "ruy/path.h" +#include "ruy/platform.h" + +namespace ruy { +namespace { + +TEST(ContextTest, EnabledPathsGeneral) { + ruy::Context ruy_context; + const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths(); + const auto ruy_paths_repeat = ruy_context.GetRuntimeEnabledPaths(); + ASSERT_EQ(ruy_paths, ruy_paths_repeat); + EXPECT_NE(ruy_paths, Path::kNone); + EXPECT_EQ(ruy_paths & Path::kReference, Path::kReference); + EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kStandardCpp); +} + +#if RUY_PLATFORM(X86) +TEST(ContextTest, EnabledPathsX86) { + ruy::Context ruy_context; + ruy_context.SetRuntimeEnabledPaths(Path::kSse42 | Path::kAvx2 | + Path::kAvx512 | Path::kAvxVnni); + const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths(); + EXPECT_EQ(ruy_paths & Path::kReference, Path::kNone); + EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kNone); +} +#endif // RUY_PLATFORM(X86) + +#if RUY_PLATFORM(ARM) +TEST(ContextTest, EnabledPathsArm) { + ruy::Context ruy_context; + ruy_context.SetRuntimeEnabledPaths(Path::kNeon | Path::kNeonDotprod); + const auto ruy_paths = ruy_context.GetRuntimeEnabledPaths(); + EXPECT_EQ(ruy_paths & Path::kReference, Path::kNone); + EXPECT_EQ(ruy_paths & Path::kStandardCpp, Path::kNone); + EXPECT_EQ(ruy_paths & Path::kNeon, Path::kNeon); +} +#endif // RUY_PLATFORM(ARM) + +} // namespace +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/cpu_cache_size.h b/ruy/cpu_cache_size.h new file mode 100644 index 0000000..82f41cc --- /dev/null +++ b/ruy/cpu_cache_size.h @@ -0,0 +1,81 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_ + +#include "ruy/path.h" +#include "ruy/platform.h" + +namespace ruy { + +// LocalDataCacheSize returns a sane default size for each CPU core's local +// data cache, i.e. the largest data cache that is local to that CPU core, not +// shared with other cores. That allows coarse tuning of code that aims for +// most of its memory accesses to hit such a typically fast data cache. +// +// SharedDataCacheSize returns a sane default size of the total data cache +// accessible to each CPU, including any shared cache. +// +// For example, if we design tune this code for a ARM Cortex-A55 with a local L1 +// cache of 32k, a local L2 cache of 128k and a shared L3 cache of 1M, +// LocalDataCacheSize should return 128k and SharedDataCacheSize +// should return 1M. +// +// Ideally these values would be queried at runtime, and we should probably +// do that on x86, but that is hard to do on ARM. +#if RUY_PLATFORM(ARM_64) +inline int LocalDataCacheSize() { return 1 << 15; } +inline int SharedDataCacheSize() { return 1 << 19; } +#elif RUY_PLATFORM(ARM_32) +inline int LocalDataCacheSize() { return 1 << 14; } +inline int SharedDataCacheSize() { return 1 << 18; } +#elif RUY_PLATFORM(X86) +inline int LocalDataCacheSize() { return 1 << 17; } +inline int SharedDataCacheSize() { return 1 << 21; } +#else +inline int LocalDataCacheSize() { return 1 << 14; } +inline int SharedDataCacheSize() { return 1 << 18; } +#endif +// Variants taking a Path argument which acts +// as a hint telling whether we're targeting more or less recent/powerful CPUs. +inline int LocalDataCacheSize(Path path) { +#if RUY_PLATFORM(ARM_64) + if (path == Path::kNeonDotprod) { + // At the moment, the smallest CPU with dotprod is probably Cortex-A55 with + // 128k L2 local cache. + return 1 << 17; + } +#else + (void)path; +#endif + return LocalDataCacheSize(); +} +inline int SharedDataCacheSize(Path path) { +#if RUY_PLATFORM(ARM_64) + if (path == Path::kNeonDotprod) { + // At the moment, the smallest CPU with dotprod is probably Cortex-A55 with + // 1M L3 shared cache. + return 1 << 20; + } +#else + (void)path; +#endif + return SharedDataCacheSize(); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_CPU_CACHE_SIZE_H_ diff --git a/ruy/detect_arm.cc b/ruy/detect_arm.cc new file mode 100644 index 0000000..85f7156 --- /dev/null +++ b/ruy/detect_arm.cc @@ -0,0 +1,73 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* Detection of dotprod instructions on ARM. + * The current Linux-specific code relies on sufficiently new Linux kernels: + * At least Linux 4.15 in general; on Android, at least Linux 4.14.111 thanks to + * a late backport. 
This was backported just before the Android 10 release, so + * this is leaving out pre-release Android 10 builds as well as earlier Android + * versions. + * + * It is possible to detect instructions in other ways that don't rely on + * an OS-provided feature identification mechanism: + * + * (A) We used to have a SIGILL-handler-based method that worked at least + * on Linux. Its downsides were (1) crashes on a few devices where + * signal handler installation didn't work as intended; (2) additional + * complexity to generalize to other Unix-ish operating systems including + * iOS; (3) source code complexity and fragility of anything installing + * and restoring signal handlers; (4) confusing behavior under a debugger. + * + * (B) We also experimented with a fork-ing approach where a subprocess + * tries the instruction. Compared to (A), this is much simpler and more + * reliable and portable, but also much higher latency on Android where + * an uncaught signal typically causes a 100 ms latency. + * + * Should there be interest in either technique again in the future, + * code implementing both (A) and (B) can be found in earlier revisions of this + * file - in actual code for (A) and in a comment for (B). + */ + +#include "ruy/detect_arm.h" + +#if defined __linux__ && defined __aarch64__ +#include +#endif + +namespace ruy { + +namespace { + +#if defined __linux__ && defined __aarch64__ +bool DetectDotprodByLinuxAuxvMethod() { + // This is the value of HWCAP_ASIMDDP in sufficiently recent Linux headers, + // however we need to support building against older headers for the time + // being. + const int kLocalHwcapAsimddp = 1 << 20; + return getauxval(AT_HWCAP) & kLocalHwcapAsimddp; +} +#endif + +} // namespace + +bool DetectDotprod() { +#if defined __linux__ && defined __aarch64__ + return DetectDotprodByLinuxAuxvMethod(); +#endif + + return false; +} + +} // namespace ruy diff --git a/ruy/detect_arm.h b/ruy/detect_arm.h new file mode 100644 index 0000000..9a1542d --- /dev/null +++ b/ruy/detect_arm.h @@ -0,0 +1,29 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Temporary dotprod-detection code until we can rely on getauxval. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_ + +namespace ruy { + +// On A64, returns true if the dotprod extension is present. +// On other architectures, returns false unconditionally. +bool DetectDotprod(); + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_ARM_H_ diff --git a/ruy/detect_x86.cc b/ruy/detect_x86.cc new file mode 100644 index 0000000..ded37b1 --- /dev/null +++ b/ruy/detect_x86.cc @@ -0,0 +1,101 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/detect_x86.h" + +#include + +#if RUY_PLATFORM(X86) && RUY_PLATFORM(X86_ENHANCEMENTS) +#include // IWYU pragma: keep + +#endif + +namespace ruy { +#if RUY_PLATFORM(X86) && RUY_PLATFORM(X86_ENHANCEMENTS) + +namespace { + +// See Intel docs, such as http://goo.gl/c6IkGX. +inline void RunCpuid(std::uint32_t eax, std::uint32_t ecx, + std::uint32_t abcd[4]) { + std::uint32_t ebx, edx; +#if defined(__i386__) && defined(__PIC__) + /* in case of PIC under 32-bit EBX cannot be clobbered */ + asm volatile("movl %%ebx, %%edi \n\t cpuid \n\t xchgl %%ebx, %%edi" + : "=D"(ebx), +#else + asm volatile("cpuid" + : "+b"(ebx), +#endif + "+a"(eax), "+c"(ecx), "=d"(edx)); + abcd[0] = eax; + abcd[1] = ebx; + abcd[2] = ecx; + abcd[3] = edx; +} + +} // namespace + +bool DetectCpuSse42() { + std::uint32_t abcd[4]; + + constexpr std::uint32_t kEcxSse42 = 1u << 20; + RunCpuid(1, 0, abcd); + const bool has_sse4_2_base = (abcd[2] & kEcxSse42) == kEcxSse42; + +#ifdef RUY_ENABLE_AMD_CPUID_CHECKS + constexpr std::uint32_t kEcxAbm = 1u << 5; + RunCpuid(0x80000001, 0, abcd); + const bool has_extras = (abcd[2] & kEcxAbm) == kEcxAbm; +#else + constexpr std::uint32_t kEcxPopcnt = 1u << 23; + RunCpuid(1, 0, abcd); + const bool has_extras = (abcd[2] & kEcxPopcnt) == kEcxPopcnt; +#endif + + return has_sse4_2_base && has_extras; +} + +bool DetectCpuAvx2() { + constexpr std::uint32_t kEbxAvx2 = 1u << 5; + constexpr std::uint32_t kEcxFma = 1u << 12; + + std::uint32_t abcd[4]; + + RunCpuid(7, 0, abcd); + const bool has_avx2 = (abcd[1] & kEbxAvx2) == kEbxAvx2; + RunCpuid(1, 0, abcd); + const bool has_fma = (abcd[2] & kEcxFma) == kEcxFma; + + return has_avx2 && has_fma; +} + +bool DetectCpuAvx512() { + constexpr std::uint32_t kEbxAvx512F = 1u << 16; + constexpr std::uint32_t kEbxAvx512Dq = 1u << 17; + constexpr std::uint32_t kEbxAvx512Cd = 1u << 28; + constexpr std::uint32_t kEbxAvx512Bw = 1u << 30; + constexpr std::uint32_t kEbxAvx512Vl = 1u << 31; + + constexpr std::uint32_t kEbxAvx512Mask = + kEbxAvx512F | kEbxAvx512Dq | kEbxAvx512Cd | kEbxAvx512Bw | kEbxAvx512Vl; + std::uint32_t abcd[4]; + RunCpuid(7, 0, abcd); + + return (abcd[1] & kEbxAvx512Mask) == kEbxAvx512Mask; +} + +#endif +} // namespace ruy diff --git a/ruy/detect_x86.h b/ruy/detect_x86.h new file mode 100644 index 0000000..fede7c7 --- /dev/null +++ b/ruy/detect_x86.h @@ -0,0 +1,49 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_ + +#include "ruy/platform.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +#if RUY_PLATFORM(X86_ENHANCEMENTS) + +// This also checks ABM support, which implies LZCNT and POPCNT. +bool DetectCpuSse42(); +bool DetectCpuAvx2(); +bool DetectCpuAvx512(); +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// TODO(b/146646451): Introduce and activate. +inline bool DetectCpuAvxVnni() { return false; } + +#else // RUY_PLATFORM(X86_ENHANCEMENTS) + +inline bool DetectCpuSse42() { return false; } +inline bool DetectCpuAvx2() { return false; } +inline bool DetectCpuAvx512() { return false; } +inline bool DetectCpuAvxVnni() { return false; } + +#endif // !RUY_PLATFORM(X86_ENHANCEMENTS) +#endif // RUY_PLATFORM(X86) + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DETECT_X86_H_ diff --git a/ruy/dispatch.h b/ruy/dispatch.h new file mode 100644 index 0000000..2fd50d0 --- /dev/null +++ b/ruy/dispatch.h @@ -0,0 +1,482 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements the translation between Ruy's entry point (ruy::Mul) and +// the internal implementation of matrix multiplication. +// +// The primary elements of this dispatch are: +// - pick suitable gemm kernel and packing routines for the user-specified +// CompiledPaths based on the current CPU. +// - decide on the structure of the packed matrices needed by the internal +// implementation (see pack.h for more information on packing). +// - translate the Mul operation into TrMul (see trmul.h for why that is +// useful). This is done by changing the matrix Layout -- no matrix data is +// actually moved. +// +// This file is also factored to serve as a building block for the advanced API +// as well. +// +// This file also performs some checking of invariants to catch user errors. 
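The "no matrix data is actually moved" remark above deserves a concrete illustration. Below is a self-contained sketch with simplified stand-ins for Order/Layout (not ruy's own types), showing that transposing only the layout metadata reinterprets the same buffer as the transposed matrix:

```cpp
#include <cassert>
#include <cstdio>

enum class Order { kColMajor, kRowMajor };
struct Layout {
  int rows;
  int cols;
  int stride;
  Order order;
};

int Offset(const Layout& layout, int row, int col) {
  return layout.order == Order::kColMajor ? row + col * layout.stride
                                          : row * layout.stride + col;
}

void Transpose(Layout* layout) {
  layout->order = layout->order == Order::kColMajor ? Order::kRowMajor
                                                    : Order::kColMajor;
  const int tmp = layout->rows;
  layout->rows = layout->cols;
  layout->cols = tmp;
}

int main() {
  const float data[6] = {1, 2, 3, 4, 5, 6};
  Layout a{2, 3, 3, Order::kRowMajor};  // 2x3 row-major view of `data`
  Layout a_t = a;
  Transpose(&a_t);  // 3x2 column-major view of the very same buffer
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      // A(i, j) == transposed(j, i) without touching `data`.
      assert(data[Offset(a, i, j)] == data[Offset(a_t, j, i)]);
    }
  }
  std::printf("layout-only transpose OK\n");
}
```

This is also why DispatchMul at the end of this header can transpose the LHS before calling TrMul without copying any element data: Transpose only touches the Layout inside the Matrix struct.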
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_ + +#include +#include +#include // IWYU pragma: keep +#include + +#include "ruy/check_macros.h" +#include "ruy/common.h" +#include "ruy/context.h" +#include "ruy/internal_matrix.h" +#include "ruy/kernel.h" +#include "ruy/kernel_common.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/pack.h" +#include "ruy/pack_common.h" +#include "ruy/path.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/side_pair.h" +#include "ruy/size_util.h" +#include "ruy/spec.h" +#include "ruy/trmul.h" +#include "ruy/trmul_params.h" + +namespace ruy { + +// If the Spec's LayoutSupport covers only some special cases, +// this function enforces that the matrix multiplication at hand falls into +// that special case. +template +void EnforceLayoutSupport(const Layout& lhs_layout, const Layout& rhs_layout, + const Layout& dst_layout) { + if (Spec::kLayoutSupport == LayoutSupport::kRCC) { + RUY_DCHECK(IsRowMajor(lhs_layout)); + RUY_DCHECK(IsColMajor(rhs_layout)); + RUY_DCHECK(IsColMajor(dst_layout)); + } +} + +template +bool IsSymmetricZeroPoint(Scalar zero_point) { + return zero_point == SymmetricZeroPoint(); +} + +template +void CheckZeroPoint(Scalar zero_point) { + if (std::is_floating_point::value || + Spec::kZeroPointSupport == ZeroPointSupport::kSymmetric) { + RUY_DCHECK(IsSymmetricZeroPoint(zero_point)); + } +} + +template +void EnforceZeroPointSupport(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point, + DstScalar dst_zero_point) { + // If the Spec's ZeroPointSupport covers only some special cases, + // this function enforces that the matrix multiplication at hand falls into + // that special case. + CheckZeroPoint(lhs_zero_point); + CheckZeroPoint(rhs_zero_point); + CheckZeroPoint(dst_zero_point); + + // Guard against the case when both LHS and RHS zero_point's are equal to + // the minimum representable value. In that case, padding with zero_point + // values will generate the bad case for fast int8 kernels on NEON + // (pre-dotprod) which attempt to multiply-accumulate two pairs of int8 + // into a int16: this is safe except in the bad case -128*-128 + -128*-128. + // See b/131609283. This only affects the kNeon path but we ban this for all + // paths in order for ruy to have the same supported parameter space + // on all paths. + RUY_DCHECK(lhs_zero_point != std::numeric_limits::lowest() || + rhs_zero_point != std::numeric_limits::lowest()); +} + +template +void EnforceDstSpecSupport(const Spec& spec, DstScalar dst_zero_point) { + static_assert(std::is_same::value, ""); + if (!std::is_same::value) return; + + // If user is looking for the raw accumulator, zero_point and all the other + // dequantize fields don't make sense and should not be set. 
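To make the overflow remark in EnforceZeroPointSupport above concrete, here is the arithmetic of the banned corner case (a standalone illustration, not ruy code): the pre-dotprod int8 kernels accumulate two int8*int8 products into one int16 lane, which only overflows when all four factors are -128.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Two int8 * int8 products accumulated into one int16 lane.
  // The banned corner case: both padded operands equal to -128 on both sides.
  const int bad = (-128) * (-128) + (-128) * (-128);  // 32768
  // For comparison, if even one of the four factors is -127 instead,
  // the sum still fits in an int16.
  const int ok = (-128) * (-128) + (-128) * (-127);   // 32640
  std::printf("%d > %d  -> overflows int16\n", bad, INT16_MAX);
  std::printf("%d <= %d -> fits\n", ok, INT16_MAX);
}
```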
+ RUY_DCHECK_EQ(dst_zero_point, 0); + RUY_DCHECK_EQ(spec.clamp_max, std::numeric_limits::max()); + RUY_DCHECK_EQ(spec.clamp_min, std::numeric_limits::min()); + RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0); + RUY_DCHECK_EQ(spec.multiplier_exponent, 0); + RUY_DCHECK_EQ(spec.multiplier_fixedpoint_perchannel, nullptr); + RUY_DCHECK_EQ(spec.multiplier_exponent_perchannel, nullptr); +} + +inline bool IsColMajorTrMul(const TrMulParams& params) { + return IsColMajor(params.src[Side::kLhs].layout) && + IsColMajor(params.src[Side::kRhs].layout) && + IsColMajor(params.dst.layout); +} + +inline void CreatePackedLayout(const Layout& src, const Type& scalar, + const KernelLayout& kernel_layout, + PackedLayout* packed) { + packed->order = Order::kColMajor; + packed->rows = round_up_pot(src.rows, kernel_layout.rows); + packed->cols = round_up_pot(src.cols, kernel_layout.cols); + packed->kernel = kernel_layout; + int inner_size = packed->rows; + if (RUY_OPT_ENABLED(RUY_OPT_AVOID_ALIASING)) { + packed->stride = + (inner_size * scalar.size) % 1024 ? inner_size : inner_size + 64; + } else { + packed->stride = inner_size; + } +} + +template +void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout, + TrMulParams* params) { + // Ruy always uses 32-bit signed accumulators for quantized + // matrix multiplication, so we would like to always use std::int32_t + // unconditionally for SumsType. + // However, for floating point types, we still need a reasonable type here to + // avoid tripping assertions elsewhere in the code. + using SumsType = + typename std::conditional::value, Scalar, + std::int32_t>::type; + + const DMatrix& src = params->src[side]; + PMatrix* packed = ¶ms->packed[side]; + packed->data_type = Type::Create(); + packed->sums_type = Type::Create(); + CreatePackedLayout(src.layout, packed->data_type, kernel_layout, + &packed->layout); + packed->zero_point = Pack(src.zero_point); +} + +template +void PopulateTrMulParams(TrMulParams* params) { + static_assert((ThePath & Path::kReference) == Path::kNone, + "Path::kReference should not do TrMul"); + // The optimized code paths don't handle the full generality of Ruy's API. + // Fall back to Path::kStandardCpp if necessary. + bool fallback_to_standard_cpp = false; + if (ThePath != Path::kStandardCpp) { + // The optimized code paths currently only handle the case of all matrices + // being column major. + if (!IsColMajorTrMul(*params)) { + fallback_to_standard_cpp = true; + } + } + + if (fallback_to_standard_cpp) { + PopulateTrMulParams(params); + return; + } + + using PackedLhsScalar = PackedType; + using PackedRhsScalar = PackedType; + using Kernel = + Kernel; + using LhsKernelLayout = typename Kernel::LhsLayout; + using RhsKernelLayout = typename Kernel::RhsLayout; + + params->path = ThePath; + + params->local_data_cache_size = Spec::local_data_cache_size(); + params->shared_data_cache_size = Spec::shared_data_cache_size(); + + CreatePackedMatrix( + Side::kLhs, ToKernelLayout(), params); + CreatePackedMatrix( + Side::kRhs, ToKernelLayout(), params); + params->run_pack[Side::kLhs] = + &RunPack; + params->run_pack[Side::kRhs] = + &RunPack; + params->run_kernel = + &RunKernel; + + return; +} + +// PopulateTrMulParamsAllCompiledPaths calls into one of multiple +// instantiations of PopulateTrMulParams. For each bit that is set in +// CompiledPaths, it statically instantiates PopulateTrMulParams with a Path +// corresponding to that single bit. 
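Returning to CreatePackedLayout above: under RUY_OPT_AVOID_ALIASING, a column stride whose size in bytes is an exact multiple of 1024 gets padded by 64 elements, presumably so that consecutive packed columns do not map onto the same cache sets. A small standalone illustration of that formula (simplified helpers, not ruy's round_up_pot):

```cpp
#include <cstdio>

// Simplified stand-ins, not ruy's helpers.
int RoundUpPot(int value, int modulo) {  // modulo must be a power of two
  return (value + modulo - 1) & ~(modulo - 1);
}

int PackedStrideElems(int src_rows, int kernel_rows, int scalar_size) {
  const int inner_size = RoundUpPot(src_rows, kernel_rows);
  // Mirrors the RUY_OPT_AVOID_ALIASING branch above: pad by 64 elements
  // whenever the column stride in bytes is a whole multiple of 1024.
  return (inner_size * scalar_size) % 1024 ? inner_size : inner_size + 64;
}

int main() {
  // 1000 int8 rows packed in 16-row kernel blocks: padded to 1008, kept as-is.
  std::printf("%d\n", PackedStrideElems(1000, 16, 1));  // 1008
  // 1024 int8 rows: 1024 bytes is a multiple of 1024, so bumped to 1088.
  std::printf("%d\n", PackedStrideElems(1024, 16, 1));  // 1088
}
```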
The call to PopulateTrMulParams is +// guarded by a runtime check that it is in fact the dynamically selected path. +// +// PopulateTrMulParamsAllCompiledPaths is implemented with template +// metaprogramming by mutual recursion between PathSearchCountdown and +// PathSearchCompiledPaths. +// +// PopulateTrMulParamsAllCompiledPaths is logically implementing the following +// computation: +// +// template +// void PopulateTrMulParamsAllCompiledPaths(Path the_path, +// TrMulParams* params) { +// for (int bit = 8 * sizeof(Path) - 1; bit != -1; bit--) { // [1] +// Path current_path = static_cast(1 << bit); +// if ((CompiledPaths & current_path) != Path::kNone) { // [2] +// if (current_path == the_path) { // [3] +// PopulateTrMulParams(the_path, params); +// return; +// } +// } +// } +// } +// +// +// +// [1] - Done by the main definition of PathSearchCountdown. The `bit--` is +// done in the recursion of PathSearchOnlyCompiledPaths. +// [2] - Done by PathSearchOnlyCompiledPaths's partial template +// specialization on InCompiledPaths. This is the check which necessitates +// doing the whole computation at C++ compile time. +// [3] - Done by the `if` in the main definition of +// PathSearchOnlyCompiledPaths. +// +// The template metaprogramming is necessary because: +// - In `PopulateTrMulParams`, current_path must be a C++ +// compile-time constant. +// - PopulateTrMulParamsAllCompiledPaths must not instantiate +// inner loops for paths that are not in CompiledPaths, since that can result in +// bogus instantiations which cause a compile time failure. +template +struct PathSearchCountdown; + +template +struct PathSearchOnlyCompiledPaths { + static constexpr Path kCurrentPath = static_cast(1 << BitNumber); + static void Search(Path the_path, TrMulParams* params) { + if (kCurrentPath == the_path) { + PopulateTrMulParams( + params); + return; + } + PathSearchCountdown::Search(the_path, params); + } +}; + +// Skip this iteration if CompiledPaths doesn't contain the specified path. +template +struct PathSearchOnlyCompiledPaths { + static void Search(Path the_path, TrMulParams* params) { + PathSearchCountdown::Search(the_path, params); + } +}; + +template +struct PathSearchCountdown { + static constexpr Path kCurrentPath = static_cast(1 << BitNumber); + static void Search(Path the_path, TrMulParams* params) { + PathSearchOnlyCompiledPaths< + CompiledPaths, (CompiledPaths & kCurrentPath) != Path::kNone, BitNumber, + LhsScalar, RhsScalar, DstScalar, Spec>::Search(the_path, params); + } +}; + +// Termination of the countdown. If the counter reaches -1, then we haven't +// found the specified path. +template +struct PathSearchCountdown { + static void Search(Path the_path, TrMulParams* params) { RUY_DCHECK(false); } +}; + +template +void PopulateTrMulParamsAllCompiledPaths(Path the_path, TrMulParams* params) { + return PathSearchCountdown::Search(the_path, + params); +} + +template +void CreateTrMulParams(const Matrix& lhs, + const Matrix& rhs, const Spec& spec, + Context* context, Matrix* dst, Path the_path, + TrMulParams* params) { + // Fill in the fields we already know. + params->src[Side::kLhs] = ToDMatrix(lhs); + params->src[Side::kRhs] = ToDMatrix(rhs); + params->dst = ToDMatrix(*dst); + params->spec = ToVoidPtr(&spec); + + // Create inner loops and packed matrices based on the Path. 
+ PopulateTrMulParamsAllCompiledPaths(the_path, params); +} + +template +void ReferenceMul(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, Matrix* dst) { + profiler::ScopeLabel label("ReferenceMul"); + for (int i = 0; i < lhs.layout.rows; i++) { + for (int j = 0; j < rhs.layout.cols; j++) { + using AccumScalar = typename Spec::AccumScalar; + AccumScalar accum = 0; + for (int k = 0; k < lhs.layout.cols; k++) { + AccumScalar lhs_val = Element(lhs, i, k); + AccumScalar rhs_val = Element(rhs, k, j); + accum += (lhs_val - lhs.zero_point) * (rhs_val - rhs.zero_point); + } + if (spec.bias) { + accum += spec.bias[i]; + } + ApplyMultiplier(spec, i, &accum); + accum += dst->zero_point; + accum = std::min(accum, spec.clamp_max); + accum = std::max(accum, spec.clamp_min); + *ElementPtr(dst, i, j) = static_cast(accum); + } + } +} + +// Compile-time dispatch to ReferenceMul. This allows us to statically ensure +// that there is no call to ReferenceMul in the user's binary. +template +struct CompileTimeEnabledReferenceMul { + template + static void Run(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, Matrix* dst) { + ReferenceMul(lhs, rhs, spec, dst); + } +}; + +// When this partial specialization is chosen, it ensures that ReferenceMul +// is never compiled. +template <> +struct CompileTimeEnabledReferenceMul { + template + static void Run(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, Matrix* dst) { + RUY_DCHECK(false); + } +}; + +inline void HandlePrepackedCaching(TrMulParams* params, + const SidePair& cacheable, + Context* context) { + if (context->cache_policy == CachePolicy::kNoCache) { + return; + } + + if (context->cache_policy == CachePolicy::kCacheLHSOnNarrowMul) { + // TODO(b/149304278) Cache on dst.cols <= selected kernel width. + if (!cacheable[Side::kLhs] || params->dst.layout.cols > 4) { + return; + } + PrepackedCache* prepacked_cache = context->GetPrepackedCache(); + auto cache_key = std::make_pair(reinterpret_cast(params->run_kernel), + params->src[Side::kLhs].data); + auto it = prepacked_cache->FindAndUpdate(cache_key); + if (it != prepacked_cache->cend()) { + params->packed[Side::kLhs].data = it->second.first.data; + params->packed[Side::kLhs].sums = it->second.first.sums; + params->is_prepacked[Side::kLhs] = true; + return; + } + + // Allocate the prepacked matrix. 
+ PrepackedMatrix prepacked_lhs; + prepacked_lhs.data_size = DataSize(params->packed[Side::kLhs]); + prepacked_lhs.sums_size = SumsSize(params->packed[Side::kLhs]); + prepacked_cache->AllocatePrepackedMatrix(&prepacked_lhs); + params->packed[Side::kLhs].data = prepacked_lhs.data; + params->packed[Side::kLhs].sums = prepacked_lhs.sums; + params->is_prepacked[Side::kLhs] = true; + Tuning tuning = context->GetMainThreadTuning(); + params->RunPack(Side::kLhs, tuning, 0, + params->packed[Side::kLhs].layout.cols); + prepacked_cache->Insert(cache_key, prepacked_lhs); + return; + } +} + +template +void DispatchMul(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, Context* context, Matrix* dst) { + static_assert(CompiledPaths != Path::kNone, "Must compile at least one Path"); + static_assert((CompiledPaths & ~kAllPaths) == Path::kNone, + "CompiledPaths must be a subset of ruy::kAllPaths"); + + profiler::ScopeLabel mul_label("Mul"); + profiler::ScopeLabel shape_specific_label("matmul shape: %dx%dx%d", + lhs.layout.rows, lhs.layout.cols, + rhs.layout.cols); + + EnforceLayoutSupport(lhs.layout, rhs.layout, dst->layout); + EnforceZeroPointSupport(lhs.zero_point, rhs.zero_point, + dst->zero_point); + EnforceDstSpecSupport(spec, dst->zero_point); + + // This should be a constant, for a given machine and CompiledPaths. + // There is a back door to override it for testing, but in production it will + // always be the "best" Path. I.e. the one with the newest SIMD instructions + // available on the present machine, and avoiding Path::kReference unless + // no other path is compiled. + // + // Unfortunately, it is not a *static* constant, since it depends on runtime + // detection of the available SIMD instructions. + Path the_path = context->GetPathToTake(); + + // Production code should probably never execute Path::kReference. + // Path::kReference implements a Mul, not a TrMul like the rest of Ruy, so if + // that's what we need to do, then get it out of the way before going down the + // TrMul path. + if (the_path == Path::kReference) { + constexpr bool ReferenceMulIsEnabled = + (CompiledPaths & Path::kReference) != Path::kNone; + CompileTimeEnabledReferenceMul::Run(lhs, rhs, spec, + dst); + return; + } + + // As described in the comment at the top of this file, Ruy internally + // converts Mul into TrMul. We handle that here. + // + // This is Ruy's main code path. + constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; + Matrix transposed_lhs(lhs); + Transpose(&transposed_lhs); + TrMulParams params; + CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, + the_path, ¶ms); + SidePair cacheable(lhs.cacheable, rhs.cacheable); + HandlePrepackedCaching(¶ms, cacheable, context); + TrMul(¶ms, context); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_DISPATCH_H_ diff --git a/ruy/example.cc b/ruy/example.cc new file mode 100644 index 0000000..3b42c97 --- /dev/null +++ b/ruy/example.cc @@ -0,0 +1,136 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "ruy/ruy.h" + +void ExampleMulFloat(ruy::Context *context) { + const float lhs_data[] = {1, 2, 3, 4}; + const float rhs_data[] = {1, 2, 3, 4}; + float dst_data[4]; + + ruy::Matrix lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); + lhs.data = lhs_data; + ruy::Matrix rhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); + rhs.data = rhs_data; + ruy::Matrix dst; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); + dst.data = dst_data; + + ruy::BasicSpec spec; + ruy::Mul(lhs, rhs, spec, context, &dst); + + std::cout << "Example Mul, float:\n"; + std::cout << "LHS:\n" << lhs; + std::cout << "RHS:\n" << rhs; + std::cout << "Result:\n" << dst << "\n"; +} + +void ExampleMulFloatWithBiasAddAndClamp(ruy::Context *context) { + const float lhs_data[] = {1, 2, 3, 4}; + const float rhs_data[] = {1, 2, 3, 4}; + const float bias_data[] = {1, 0}; + float dst_data[4]; + + ruy::Matrix lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); + lhs.data = lhs_data; + ruy::Matrix rhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); + rhs.data = rhs_data; + ruy::Matrix dst; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); + dst.data = dst_data; + + ruy::BasicSpec spec; + spec.bias = bias_data; + spec.clamp_min = 0; + spec.clamp_max = 15; + ruy::Mul(lhs, rhs, spec, context, &dst); + + std::cout << "Example Mul, float with bias addition and clamp:\n"; + std::cout << "LHS:\n" << lhs; + std::cout << "RHS:\n" << rhs; + std::cout << "Result:\n" << dst << "\n"; +} + +void ExampleMulUint8AsymmetricQuantized(ruy::Context *context) { + const std::uint8_t lhs_data[] = {124, 125, 126, 127}; + const std::uint8_t rhs_data[] = {129, 130, 131, 132}; + std::uint8_t dst_data[4]; + + ruy::Matrix lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); + lhs.data = lhs_data; + lhs.zero_point = 125; + ruy::Matrix rhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); + rhs.data = rhs_data; + rhs.zero_point = 132; + ruy::Matrix dst; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); + dst.data = dst_data; + dst.zero_point = 129; + + ruy::BasicSpec spec; + spec.multiplier_fixedpoint = 1 << 30; + + spec.multiplier_exponent = 0; + ruy::Mul(lhs, rhs, spec, context, &dst); + + std::cout << "Example Mul, uint8 quantized with asymmetric zero points:\n"; + std::cout << "LHS:\n" << lhs; + std::cout << "RHS:\n" << rhs; + std::cout << "Result:\n" << dst << "\n"; +} +void ExampleMulInt8PerChannelQuantized(ruy::Context *context) { + const std::int8_t lhs_data[] = {1, 2, 3, 4}; + const std::int8_t rhs_data[] = {1, 2, 3, 4}; + const std::int32_t multiplier_data[] = {3 << 28, 5 << 28}; + const int exponent_data[] = {1, -2}; + std::int8_t dst_data[4]; + + ruy::Matrix lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); + lhs.data = lhs_data; + ruy::Matrix rhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); + rhs.data = rhs_data; + ruy::Matrix dst; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); + dst.data = dst_data; + + ruy::BasicSpec spec; + spec.multiplier_fixedpoint_perchannel = multiplier_data; + spec.multiplier_exponent_perchannel = exponent_data; + ruy::Mul(lhs, rhs, spec, context, &dst); + + std::cout << "Example Mul, 
int8 quantized with per-channel multipliers\n"; + std::cout << "LHS:\n" << lhs; + std::cout << "RHS:\n" << rhs; + std::cout << "Result:\n" << dst << "\n"; +} + +int main() { + ruy::Context context; + ExampleMulFloat(&context); + ExampleMulFloatWithBiasAddAndClamp(&context); + ExampleMulUint8AsymmetricQuantized(&context); + ExampleMulInt8PerChannelQuantized(&context); +} diff --git a/ruy/example_advanced.cc b/ruy/example_advanced.cc new file mode 100644 index 0000000..9041bdb --- /dev/null +++ b/ruy/example_advanced.cc @@ -0,0 +1,83 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "ruy/ruy_advanced.h" + +// Simple allocator for allocating pre-packed matrices. +class SimpleAllocator { + public: + void* AllocateBytes(std::size_t num_bytes) { + char* p = new char[num_bytes]; + buffers_.emplace_back(p); + return static_cast(p); + } + + private: + std::vector> buffers_; +}; + +void ExamplePrepack(ruy::Context* context) { + const float lhs_data[] = {1, 2, 3, 4}; + const float rhs_data[] = {1, 2, 3, 4}; + float dst_data[4]; + + // Set up the matrix layouts and spec. + ruy::Matrix lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); + ruy::Matrix rhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout); + ruy::Matrix dst; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout); + ruy::BasicSpec spec; + + SimpleAllocator allocator; + auto alloc_fn = [&allocator](std::size_t num_bytes) -> void* { + return allocator.AllocateBytes(num_bytes); + }; + + // In this example, we pre-pack only the RHS, but either will work. + // Note that we only need to set the data pointer for the matrix we are + // pre-packing. + ruy::PrepackedMatrix prepacked_rhs; + rhs.data = rhs_data; + ruy::PrePackForMul(lhs, rhs, spec, context, &dst, + /*prepacked_lhs=*/nullptr, &prepacked_rhs, + alloc_fn); + + // No data will be read from the RHS input matrix when using a pre-packed RHS. + rhs.data = nullptr; + lhs.data = lhs_data; + dst.data = dst_data; + ruy::MulWithPrepacked(lhs, rhs, spec, context, &dst, + /*prepacked_lhs=*/nullptr, + &prepacked_rhs); + rhs.data = rhs_data; + + // Print out the results. + std::cout << "Example Mul with pre-packing RHS, float:\n"; + std::cout << "LHS:\n" << lhs; + std::cout << "RHS:\n" << rhs; + std::cout << "Result:\n" << dst << "\n"; +} + +int main() { + ruy::Context context; + ExamplePrepack(&context); +} diff --git a/ruy/have_built_path_for.h b/ruy/have_built_path_for.h new file mode 100644 index 0000000..8913965 --- /dev/null +++ b/ruy/have_built_path_for.h @@ -0,0 +1,32 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_ + +#include "ruy/platform.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +bool HaveBuiltPathForSse42(); +bool HaveBuiltPathForAvx2(); +bool HaveBuiltPathForAvx512(); +bool HaveBuiltPathForAvxVnni(); +#endif // RUY_PLATFORM(X86) + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_HAVE_BUILT_PATH_FOR_H_ diff --git a/ruy/have_built_path_for_avx2.cc b/ruy/have_built_path_for_avx2.cc new file mode 100644 index 0000000..ceca8a4 --- /dev/null +++ b/ruy/have_built_path_for_avx2.cc @@ -0,0 +1,35 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/have_built_path_for.h" +#include "ruy/opt_set.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +// IMPORTANT: +// These patterns must match those in the pack and kernel cc files. +#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +bool HaveBuiltPathForAvx2() { return false; } + +#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +bool HaveBuiltPathForAvx2() { return true; } + +#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) +#endif // RUY_PLATFORM(X86) + +} // namespace ruy diff --git a/ruy/have_built_path_for_avx512.cc b/ruy/have_built_path_for_avx512.cc new file mode 100644 index 0000000..15fba62 --- /dev/null +++ b/ruy/have_built_path_for_avx512.cc @@ -0,0 +1,35 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/have_built_path_for.h" +#include "ruy/opt_set.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +// IMPORTANT: +// These patterns must match those in the pack and kernel cc files. 
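Taken together with detect_x86.h earlier in this patch, these HaveBuiltPathFor* functions let a caller combine a compile-time answer ("was this path built into the binary?") with a runtime answer ("does this CPU support it?"). The sketch below is a hypothetical caller-side helper, not ruy's actual selection logic (ruy's own choice goes through Context::GetPathToTake(), as seen later in dispatch.h), and it assumes the Path enumerators kSse42, kAvx2 and kAvx512 from ruy/path.h:

```cpp
#include "ruy/detect_x86.h"
#include "ruy/have_built_path_for.h"
#include "ruy/path.h"
#include "ruy/platform.h"

// Hypothetical helper, not part of ruy: picks the most capable x86 Path that
// is both compiled into this binary and supported by the CPU at runtime.
inline ruy::Path BestX86PathSketch() {
#if RUY_PLATFORM(X86)
  if (ruy::HaveBuiltPathForAvx512() && ruy::DetectCpuAvx512()) {
    return ruy::Path::kAvx512;
  }
  if (ruy::HaveBuiltPathForAvx2() && ruy::DetectCpuAvx2()) {
    return ruy::Path::kAvx2;
  }
  if (ruy::HaveBuiltPathForSse42() && ruy::DetectCpuSse42()) {
    return ruy::Path::kSse42;
  }
#endif
  return ruy::Path::kStandardCpp;
}
```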
+#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +bool HaveBuiltPathForAvx512() { return false; } + +#else // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +bool HaveBuiltPathForAvx512() { return true; } + +#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) +#endif // RUY_PLATFORM(X86) + +} // namespace ruy diff --git a/ruy/have_built_path_for_avxvnni.cc b/ruy/have_built_path_for_avxvnni.cc new file mode 100644 index 0000000..68ef2a2 --- /dev/null +++ b/ruy/have_built_path_for_avxvnni.cc @@ -0,0 +1,39 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/have_built_path_for.h" +#include "ruy/opt_set.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +// IMPORTANT: +// These patterns must match those in the pack and kernel cc files. +#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +bool HaveBuiltPathForAvxVnni() { return false; } + +#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +bool HaveBuiltPathForAvxVnni() { return true; } + +#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) +#endif // RUY_PLATFORM(X86) + +} // namespace ruy diff --git a/ruy/have_built_path_for_sse42.cc b/ruy/have_built_path_for_sse42.cc new file mode 100644 index 0000000..2141b75 --- /dev/null +++ b/ruy/have_built_path_for_sse42.cc @@ -0,0 +1,39 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/have_built_path_for.h" +#include "ruy/opt_set.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +// IMPORTANT: +// These patterns must match those in the pack and kernel cc files. +#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +bool HaveBuiltPathForSse42() { return false; } + +#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. 
+// +bool HaveBuiltPathForSse42() { return true; } + +#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) +#endif // RUY_PLATFORM(X86) + +} // namespace ruy diff --git a/ruy/internal_matrix.h b/ruy/internal_matrix.h new file mode 100644 index 0000000..7fe13be --- /dev/null +++ b/ruy/internal_matrix.h @@ -0,0 +1,388 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Internal types and helpers for matrices. +// +// Ruy has a couple slightly different notions of matrices, besides the +// Matrix class that we expose to the user-facing API. +// +// TODO(silvasean): Put parts of this architecture description somewhere more +// prominent. +// +// The 4 main matrix types are: +// - Matrix: This is a user-facing type on Ruy's external API boundary. It is +// also used internally. +// - DMatrix: This is a type-erased version of Matrix. "D" = "dynamic". +// - PMatrix: This represents a packed matrix, which requires tracking kernel +// layout and row/column sums for quantization. It is type-erased. +// - PackedMatrix: This is a statically typed variant of PMatrix for +// convenience inside typed routines. +// +// Note that Matrix is *not* implemented in terms of the internal types. It +// is an independent, simple, and user-facing type. +// +// The use of type-erasure might seem surprising for a library like Ruy with a +// heavily-templated entry point, but it is motivated by the desire for most of +// Ruy's "middle-end" to be non-templated. Ruy can be thought of as having 3 +// main parts: +// - "front-end" (dispatch.h) - this is the highly templated ruy::Mul entry +// point, along with routines that select RunKernel and RunPack implementations +// statically based on those template parameters. +// - "back-end" (kernel.h, pack.h)- this consists of the implementations of +// RunKernel and RunPack, often in assembly code, which are the building blocks +// that Ruy calls to perform matrix multiplication. These are templated so that +// only the requested types/Path's are actually emitted by the compiler. +// - "middle-end" (trmul.h) - this is the part of Ruy that orchestrates the +// calls to the "back-end" optimized building blocks. This layer has to deal +// with issues like cache locality and low-overhead multi-threading. +// +// There is a desire for the "middle-end" to be non-templated in order to +// simplify the implementation and reduce code-size. We type-erase when going +// from the "front-end" to the "middle-end", and un-type-erase going from the +// "middle-end" to the "back-end". The un-type-erasure is possible because the +// "front-end" is responsible for instantiating the needed "back-end" templates, +// and thus the static type information is still present. 
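The erase / un-erase round trip described above boils down to a few lines. The following is a simplified standalone sketch (not ruy's DMatrix/Type machinery, whose real definitions follow): the middle layer sees only a void pointer plus a small runtime tag, and the typed layer re-supplies the static type and validates it against that tag.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Simplified stand-ins for the idea, not ruy's types.
struct ErasedBuffer {
  const void* data = nullptr;
  std::uint8_t elem_size = 0;
  bool is_floating_point = false;
};

template <typename T>
ErasedBuffer Erase(const T* data) {
  ErasedBuffer erased;
  erased.data = data;
  erased.elem_size = sizeof(T);
  erased.is_floating_point = std::is_floating_point<T>::value;
  return erased;
}

template <typename T>
const T* Unerase(const ErasedBuffer& erased) {
  // Same role as Type::AssertIs below: the caller re-supplies the static
  // type, and the runtime tag merely validates it in debug builds.
  assert(erased.elem_size == sizeof(T));
  assert(erased.is_floating_point == std::is_floating_point<T>::value);
  return static_cast<const T*>(erased.data);
}

int main() {
  const float buffer[4] = {1, 2, 3, 4};
  ErasedBuffer erased = Erase(buffer);          // "front-end" erases
  const float* typed = Unerase<float>(erased);  // "back-end" un-erases
  assert(typed[2] == 3);
}
```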
+// +// Each layer of Ruy uses matrix types: +// - "front-end": Matrix +// - "middle-end": DMatrix, PMatrix +// - "back-end": Matrix, PackedMatrix +// +// The use of separate types for packed matrices is not essential, but makes it +// obvious at a glance whether a matrix is a packed matrix or not. We would +// reconsider this decision if there was significant duplication between packed +// and unpacked matrices, but that doesn't seem to be the case at the moment. +// +// Another goal is to keep the user-facing Matrix as simple and +// understandable as possible. Ideally, a user should be able to read the struct +// definition for Matrix and see a very simple definition with no internal +// details like sums and kernel block layout. +// +// To present another structured view of our various matrix types, here's a +// table: +// Plain matrices Packed matrices +// +---------------------------------- +// Templated | Matrix PackedMatrix +// Type-erased | DMatrix PMatrix +// +// +// There is 1 additional matrix type not mentioned above, due to its low +// importance: +// - PrepackedMatrix: This is a user-facing version of PMatrix. It has the bare +// minimum of fields needed for representing the raw data and sums buffers of a +// packed matrix for the "advanced" explicit pre-packing API. This type plays no +// role in Ruy's internals and can generally by ignored. The only reason it +// exists is so that PMatrix is not exposed to users -- we prefer to keep the +// internal matrix types hidden, even from "advanced" users. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_ + +#include +#include +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/common.h" +#include "ruy/matrix.h" +#include "ruy/size_util.h" + +namespace ruy { + +// KernelLayout describes small-scale block structure in a packed matrix layout. +// It's a runtime (as opposed to compile-time-constant) version of the +// FixedKernelLayout struct used to declare kernel layouts. +// +// This is is sometimes known as "tiling" in other contexts. +// +// For example, consider a packed matrix in column-major format with a +// column-major KernelLayout. The matrix logically has a shape of +// `[cols, rows]`. However, the matrix is laid out as though it were a 4D array +// of shape `[cols / kcols, rows / krows, kcols, krows]`. +// +// Note that in the case of kcols=1, krows=1, this degenerates to +// `[cols, rows, 1, 1]` which is equivalent to having no small-scale block +// structure. +struct KernelLayout { + Order order = Order::kColMajor; + std::uint8_t rows = 1; + std::uint8_t cols = 1; +}; + +// A packed matrix has a small-scale block structure that is not present in in +// the input matrices. This block structure is necessary for the kernels to +// process data efficiently. +// +// This struct is very similar to Layout, but has the extra KernelLayout field. +struct PackedLayout { + std::int32_t rows = 0; + std::int32_t cols = 0; + // Stride is the offset between two adjacent matrix elements + // in the non-contiguous direction. + std::int32_t stride = 0; + Order order = Order::kColMajor; + // Small scale layout shuffling, potentially departing from + // linear row-major or column-major storage. See KernelLayout. + KernelLayout kernel; +}; + +// Dynamic representation for a type. +// +// The most important field in this struct is the size, which Ruy uses to know +// how much memory to allocate without having to be templated on a type. 
+// Signed-ness and floating-point-ness are mainly present as debugging checks. +// +// Note: Ruy does not use this struct to to dynamically dispatch between +// different typed implementations. As described in the comment at the top of +// this file, Ruy's "front-end", which is templated, instantiates all the +// necessary "back-end" routines with complete static knowledge of all the +// types. +struct Type { + template + static Type Create() { + Type ret; + ret.is_signed = std::is_signed::value; + ret.is_floating_point = std::is_floating_point::value; + ret.size = sizeof(T); + return ret; + } + + template + void AssertIs() const { + RUY_DCHECK_EQ(is_signed, Create().is_signed); + RUY_DCHECK_EQ(is_floating_point, Create().is_floating_point); + RUY_DCHECK_EQ(size, Create().size); + } + + bool is_signed = false; + bool is_floating_point = false; + std::uint8_t size = 0; +}; + +// Type-erased matrix. +struct DMatrix { + Type data_type; + void* data = nullptr; + Layout layout; + std::int32_t zero_point = 0; +}; + +// Type-erased packed matrix. +struct PMatrix { + Type data_type; + void* data = nullptr; + Type sums_type; + void* sums = nullptr; + PackedLayout layout; + std::int32_t zero_point = 0; +}; + +// Convenient typed helper for packed matrices. +template +struct PackedMatrix { + // The row/column sums needed for quantized matrix multiplication when + // the opposite operand of the multiplication uses a non-symmetric zero + // point. + // This member is only relevant for packed matrices. + // Additionally, Ruy always uses 32-bit signed accumulators for quantized + // matrix multiplication. + // For floating point types, there is no quantization, so this pointer + // will always be null. We still need code referencing it to compile + // though, even if it is always branched around. Hence we use Scalar* + // itself as the type in that case. + using SumsType = + typename std::conditional::value, Scalar, + std::int32_t>::type; + + Scalar* data = nullptr; + SumsType* sums = nullptr; + PackedLayout layout; + std::int32_t zero_point = 0; +}; + +template +DMatrix ToDMatrix(const Matrix& matrix) { + DMatrix ret; + ret.data_type = Type::Create(); + ret.data = ToVoidPtr(matrix.data.get()); + ret.layout = matrix.layout; + ret.zero_point = matrix.zero_point; + return ret; +} + +template +Matrix ToMatrix(const DMatrix& dmatrix) { + dmatrix.data_type.AssertIs(); + Matrix ret; + ret.data = static_cast(dmatrix.data); + ret.layout = dmatrix.layout; + ret.zero_point = dmatrix.zero_point; + return ret; +} + +template +PackedMatrix ToPackedMatrix(const PMatrix& pmatrix) { + using SumsType = typename PackedMatrix::SumsType; + pmatrix.data_type.AssertIs(); + pmatrix.sums_type.AssertIs(); + PackedMatrix ret; + ret.data = static_cast(pmatrix.data); + ret.sums = static_cast(pmatrix.sums); + ret.layout = pmatrix.layout; + ret.zero_point = pmatrix.zero_point; + return ret; +} + +// Helpers for Layout / PackedLayout. + +inline bool IsPacked(const Layout& layout) { + if (layout.order == Order::kColMajor) { + return layout.stride == layout.rows; + } else { + return layout.stride == layout.cols; + } +} + +inline bool IsRowMajor(const Layout& layout) { + return layout.order == Order::kRowMajor; +} + +template +inline bool IsColMajor(const LayoutOrPackedLayout& layout) { + return layout.order == Order::kColMajor; +} + +template +inline int FlatSize(const LayoutOrPackedLayout& layout) { + const int outerdim = + layout.order == Order::kColMajor ? 
layout.cols : layout.rows; + return layout.stride * outerdim; +} + +// TODO(b/130417400) add a unit test +inline int Offset(const Layout& layout, int row, int col) { + // TODO(benoitjacob) - should check this but this make the _slow tests take + // 5x longer. Find a mitigation like in Eigen with an 'internal' variant + // bypassing the check? + // RUY_DCHECK_GE(row, 0); + // RUY_DCHECK_GE(col, 0); + // RUY_DCHECK_LT(row, layout.rows); + // RUY_DCHECK_LT(col, layout.cols); + int row_stride = layout.order == Order::kColMajor ? 1 : layout.stride; + int col_stride = layout.order == Order::kRowMajor ? 1 : layout.stride; + return row * row_stride + col * col_stride; +} + +// TODO(b/130417400) add a unit test +inline int Offset(const PackedLayout& layout, int row, int col) { + RUY_DCHECK(is_pot(layout.kernel.rows)); + RUY_DCHECK(is_pot(layout.kernel.cols)); + int row_outer = row & ~(layout.kernel.rows - 1); + int col_outer = col & ~(layout.kernel.cols - 1); + int row_stride_outer = + layout.order == Order::kColMajor ? layout.kernel.cols : layout.stride; + int col_stride_outer = + layout.order == Order::kRowMajor ? layout.kernel.rows : layout.stride; + int offset_outer = + row_outer * row_stride_outer + col_outer * col_stride_outer; + int row_inner = row - row_outer; + int col_inner = col - col_outer; + int row_stride_inner = + layout.kernel.order == Order::kColMajor ? 1 : layout.kernel.cols; + int col_stride_inner = + layout.kernel.order == Order::kRowMajor ? 1 : layout.kernel.rows; + int offset_inner = + row_inner * row_stride_inner + col_inner * col_stride_inner; + return offset_outer + offset_inner; +} + +// Helpers for Matrix. + +template +const Scalar* ElementPtr(const Matrix& mat, int row, int col) { + return mat.data.get() + Offset(mat.layout, row, col); +} + +template +Scalar* ElementPtr(Matrix* mat, int row, int col) { + return mat->data.get() + Offset(mat->layout, row, col); +} + +template +Scalar Element(const Matrix& mat, int row, int col) { + return *ElementPtr(mat, row, col); +} + +// Helpers for PackedMatrix. +// Duplicated from Matrix, but the duplication seems acceptable. + +template +const Scalar* ElementPtr(const PackedMatrix& mat, int row, int col) { + return mat.data + Offset(mat.layout, row, col); +} + +template +Scalar* ElementPtr(PackedMatrix* mat, int row, int col) { + return mat->data + Offset(mat->layout, row, col); +} + +template +Scalar Element(const PackedMatrix& mat, int row, int col) { + return *ElementPtr(mat, row, col); +} + +// Helpers for PMatrix. + +inline std::size_t DataSize(const PMatrix& packed) { + return FlatSize(packed.layout) * packed.data_type.size; +} + +inline std::size_t SumsSize(const PMatrix& packed) { + // Packed matrices are only relevant for Ruy's TrMul implementations. For + // TrMul, the number of sums is always equal to the number of columns. + return packed.layout.cols * packed.sums_type.size; +} + +// Transpose helpers. + +inline void Transpose(Order* order) { + *order = *order == Order::kColMajor ? Order::kRowMajor : Order::kColMajor; +} + +inline void Transpose(Layout* layout) { + Transpose(&layout->order); + std::swap(layout->rows, layout->cols); +} + +template +inline void Transpose(Matrix* matrix) { + Transpose(&matrix->layout); +} + +// Helpers for KernelLayout. 
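Before the KernelLayout helper that follows, the kernel-tiled Offset overload above is easiest to follow on concrete numbers. Below is a small standalone check (test-style code, not part of the patch) exercising the outer/inner decomposition with the PackedLayout and Offset definitions from this header:

```cpp
#include <cassert>

#include "ruy/internal_matrix.h"

int main() {
  // A packed column-major matrix whose columns are padded to a stride of 16
  // elements, stored as 4x2 row-major kernel blocks.
  ruy::PackedLayout layout;
  layout.rows = 16;
  layout.cols = 8;
  layout.stride = 16;
  layout.order = ruy::Order::kColMajor;
  layout.kernel.order = ruy::Order::kRowMajor;
  layout.kernel.rows = 4;
  layout.kernel.cols = 2;

  // Element (row=5, col=3) lives in the block whose corner is (4, 2):
  //   outer offset = 4 * 2 + 2 * 16 = 40
  //   inner offset = (5 - 4) * 2 + (3 - 2) * 1 = 3
  assert(ruy::Offset(layout, 5, 3) == 43);
  return 0;
}
```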
+ +template +KernelLayout ToKernelLayout() { + KernelLayout ret; + ret.order = FixedKernelLayout::kOrder; + ret.rows = FixedKernelLayout::kRows; + ret.cols = FixedKernelLayout::kCols; + return ret; +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_INTERNAL_MATRIX_H_ diff --git a/ruy/kernel.h b/ruy/kernel.h new file mode 100644 index 0000000..d7930b4 --- /dev/null +++ b/ruy/kernel.h @@ -0,0 +1,31 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_ + +#include "ruy/platform.h" + +// IWYU pragma: begin_exports +#if RUY_PLATFORM(NEON) +#include "ruy/kernel_arm.h" +#elif RUY_PLATFORM(X86) +#include "ruy/kernel_x86.h" +#else +#include "ruy/kernel_common.h" +#endif +// IWYU pragma: end_exports + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_H_ diff --git a/ruy/kernel_arm.h b/ruy/kernel_arm.h new file mode 100644 index 0000000..408c23a --- /dev/null +++ b/ruy/kernel_arm.h @@ -0,0 +1,211 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_ + +#include +#include + +#include "ruy/common.h" +#include "ruy/internal_matrix.h" +#include "ruy/kernel_common.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/side_pair.h" +#include "ruy/size_util.h" +#include "ruy/spec.h" +#include "ruy/tune.h" + +namespace ruy { + +#if RUY_PLATFORM(NEON) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if RUY_PLATFORM(NEON_64) +void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params); +void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 4>& params); +#elif RUY_PLATFORM(NEON_32) +void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params); +void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 2>& params); +#endif +void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params); +void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params); +void Kernel8bitNeonDotprodOutOfOrder1Col(const KernelParams8bit<8, 8>& params); +void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params); + +#if RUY_PLATFORM(NEON_64) +template +struct Kernel> { + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + Tuning tuning = Tuning::kAuto; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, + const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, + int start_col, int end_row, int end_col, + Matrix* dst) const { + KernelParams8bit params; + MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, + dst, ¶ms); + if (dst->layout.cols == 1) { + Kernel8bitNeonOutOfOrder1Col(params); + return; + } + if (__builtin_expect(tuning == Tuning::kInOrder, true)) { + Kernel8bitNeonInOrder(params); + } else { + Kernel8bitNeonOutOfOrder(params); + } + } +}; +#endif + +#if RUY_PLATFORM(NEON_32) +template +struct Kernel> { + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + Tuning tuning = Tuning::kAuto; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, + const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, + int start_col, int end_row, int end_col, + Matrix* dst) const { + KernelParams8bit params; + MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, + dst, ¶ms); + if (dst->layout.cols == 1) { + Kernel8bitNeonOutOfOrder1Col(params); + return; + } + Kernel8bitNeonOutOfOrder(params); + } +}; +#endif + +#if RUY_PLATFORM(NEON_64) +template +struct Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, + const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, + int start_col, int end_row, int end_col, + Matrix* dst) const { + KernelParams8bit params; + MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row, end_col, + dst, ¶ms); + if (dst->layout.cols == 1) { + Kernel8bitNeonDotprodOutOfOrder1Col(params); + } else if (__builtin_expect(tuning == Tuning::kInOrder, true)) { + Kernel8bitNeonDotprodInOrder(params); + } else { + Kernel8bitNeonDotprodOutOfOrder(params); + } + } +}; +#endif + +void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params); +void 
KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params); +void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params); +void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params); + +#if RUY_PLATFORM(NEON_64) +// A Float kernel for ARM64 Neon. +template <> +struct Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, int start_col, + int end_row, int end_col, Matrix* dst) const { + KernelParamsFloat params; + MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (__builtin_expect(tuning == Tuning::kInOrder, true)) { + KernelFloatNeonInOrder(params); + } else { + KernelFloatNeonOutOfOrder(params); + } + } +}; +#endif + +#if RUY_PLATFORM(NEON_32) +// A Float kernel for ARM32 Neon. +template <> +struct Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, int start_col, + int end_row, int end_col, Matrix* dst) const { + KernelParamsFloat<8, 4> params; + + MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, + end_col, dst, ¶ms); + + KernelFloat32NeonOutOfOrder(params); + } +}; +#endif + +// While the dotprod NEON extension does not concern floating-point arithmetic, +// its presence allows us to distinguish, in the in-order tuning case, between +// A53 and A55r1. TODO: should this be folded into tuning? +template <> +struct Kernel> { + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout; + using RhsLayout = FixedKernelLayout; + using Base = + Kernel>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PackedMatrix& lhs, const PackedMatrix& rhs, + const BasicSpec& spec, int start_row, int start_col, + int end_row, int end_col, Matrix* dst) const { + KernelParamsFloat params; + MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row, + end_col, dst, ¶ms); + if (__builtin_expect(tuning == Tuning::kInOrder, true)) { + KernelFloatNeonDotprodInOrder(params); + } else { + KernelFloatNeonOutOfOrder(params); + } + } +}; + +#endif // RUY_PLATFORM(NEON) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_ARM_H_ diff --git a/ruy/kernel_arm32.cc b/ruy/kernel_arm32.cc new file mode 100644 index 0000000..d537cfe --- /dev/null +++ b/ruy/kernel_arm32.cc @@ -0,0 +1,2499 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/kernel.h" +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +namespace ruy { + +#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +#define RUY_ASM_LABEL_STORE_UINT8 91 +#define RUY_ASM_LABEL_STORE_INT8 92 +#define RUY_ASM_LABEL_STORE_INT16 93 +#define RUY_ASM_LABEL_STORE_INT32 94 +#define RUY_ASM_LABEL_AFTER_STORE 99 + +#define RUY_OFFSET_LHS_BASE_PTR 0 +#define RUY_OFFSET_RHS_BASE_PTR 4 +#define RUY_OFFSET_DST_BASE_PTR 8 +#define RUY_OFFSET_BIAS 12 +#define RUY_OFFSET_START_ROW 16 +#define RUY_OFFSET_START_COL 20 +#define RUY_OFFSET_LAST_ROW 24 +#define RUY_OFFSET_LAST_COL 28 +#define RUY_OFFSET_DST_ROWS 32 +#define RUY_OFFSET_DST_COLS 36 +#define RUY_OFFSET_LHS_STRIDE 40 +#define RUY_OFFSET_RHS_STRIDE 44 +#define RUY_OFFSET_DST_STRIDE 48 +#define RUY_OFFSET_DEPTH 52 +#define RUY_OFFSET_CLAMP_MIN 56 +#define RUY_OFFSET_CLAMP_MAX 60 +#define RUY_OFFSET_FLAGS 64 + +#define RUY_STACK_OFFSET_SIZE 96 +#define RUY_STACK_OFFSET_DST_COL_PTR 0 +#define RUY_STACK_OFFSET_DST_PTR 16 +#define RUY_STACK_OFFSET_ROW 32 +#define RUY_STACK_OFFSET_COL 48 +#define RUY_STACK_OFFSET_LHS_COL_PTR 64 +#define RUY_STACK_OFFSET_RHS_COL_PTR 80 + +template +void CheckOffsetsInKernelParamsFloat32(const Params&) { + static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); + static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, ""); + static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, ""); + static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); + static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); + static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, ""); + static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); + static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); + static_assert(offsetof(Params, dst_rows) == RUY_OFFSET_DST_ROWS, ""); + static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); + static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); + static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); + static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); + static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); + static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); + static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); +} + +// Float kernel for ARM32 out-of-order cores. +// Just like Float 64 version, except accumulate in to 8x4 block to only +// use 16 128-bit NEON registers. This is a "first pass" kernel and not +// tuned. It is meant to run on out-of-order CPUs like the Krait 400 or A9. +void KernelFloat32NeonOutOfOrder(const KernelParamsFloat<8, 4>& params) { + CheckOffsetsInKernelParamsFloat32(params); + profiler::ScopeLabel label( + "Kernel (kNeon, optimized for out-of-order cores)"); + + const float* lhs_ptr = params.lhs_base_ptr; + const float* rhs_ptr = params.rhs_base_ptr; + // In ARM32 NEON, there are 16 128-bit "q" registers. These registers are + // each composed of two 64-bit "d" registers. The asm kernel below has the + // following NEON register allocation: + // Registers q3 -- q10 are accumulators. During accumulation, + // q0 -- q2 (d0 -- d5) are used to load data from LHS and RHS. 
q0 and q1 + // are used to load a 8x1 block of LHS, and q2 is used to load a 1x4 block + // of RHS, like this: + + // Register layout in "q" registers: + // RHS 1x4 block + // /--------------------------\ + // |q2.s[0] ... q2.s[3] | + // \--------------------------/ + // LHS 8x1 block + // /---------------------\ /--------------------- \ + // | q0.s[0] | | q3.s[0] ... q9.s[0] | + // | ... | | ... ... | + // | q0.s[3] | | q3.s[3] q9.s[3] | + // | q1.s[0] | | q4.s[0] q10.s[0] | + // | ... | | ... ... ... | + // | q1.s[3] | | q4.s[3] .. q10.s[3] | + // \---------------------/ \--------------------------/ + // accumulators 8x4 block + // q11, q14, q15 currently unused. q12 and q13 are used to load + // parameters used for the post-accumulation part of the kernel. + // For completeness, here is the register layout in "d" registers: + // RHS 1x4 block + // /--------------------------\ + // |d4[0] ... d5[1] | + // \--------------------------/ + // LHS 8x1 block + // /---------------------\ /--------------------------\ + // | d0[0] | | d6[0] ... d18[0] | + // | ... | | ... ... | + // | d1[1] | | d7[1] d19[1] | + // | d2[0] | | d8[0] d20[0] | + // | ... | | ... ... ... | + // | d3[1] | | d9[1] ... d21[1] | + // \---------------------/ \--------------------------/ + // accumulators 8x4 block + asm volatile( +#define RUY_MAKE_ZERO(reg) "vmov.f32 " #reg ", #0.0\n" + + // clang-format off + + // Load the first 32 bytes of LHS and RHS data. + // Load q0, q1 + "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" + "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + // Load q2 + "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" + "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + // Clear accumulators. + RUY_MAKE_ZERO(q3) + RUY_MAKE_ZERO(q4) + RUY_MAKE_ZERO(q5) + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + + // r1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 1. + "mov r1, #1\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. 
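As a reading aid for the register-layout comments above, the 8x4 block that this kernel produces corresponds to the following scalar reference (a minimal sketch, not part of the patch; the packed layouts and the RefFloatKernel8x4 name are assumptions drawn from those comments):

#include <algorithm>

// dst[r][c] = clamp(bias[r] + sum_d lhs[d][r] * rhs[d][c]) for an 8x4 block.
// Assumes the packed LHS stores 8 values per depth step and the packed RHS 4,
// and that dst is column-major with the given stride.
void RefFloatKernel8x4(const float* lhs, const float* rhs, const float* bias,
                       int depth, float clamp_min, float clamp_max, float* dst,
                       int dst_stride) {
  for (int c = 0; c < 4; ++c) {
    for (int r = 0; r < 8; ++r) {
      float acc = bias[r];
      for (int d = 0; d < depth; ++d) {
        // One depth step: 8x1 LHS slab (q0, q1) times 1x4 RHS slab (q2),
        // accumulated into q3 -- q10 in the asm above.
        acc += lhs[d * 8 + r] * rhs[d * 4 + c];
      }
      dst[c * dst_stride + r] = std::min(clamp_max, std::max(clamp_min, acc));
    }
  }
}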
+ "1:\n" + + // Accumulation loop + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + "cmp r1, r2\n" + "beq 79f\n" + + "2:\n" + + "vmla.f32 q3, q0, d4[0]\n" + "vmla.f32 q5, q0, d4[1]\n" + "vmla.f32 q7, q0, d5[0]\n" + "vmla.f32 q9, q0, d5[1]\n" + "vld1.32 {d0, d1}, [%[lhs_ptr]]!\n" // Reload LHS + + "vmla.f32 q4, q1, d4[0]\n" + "vmla.f32 q6, q1, d4[1]\n" + "vmla.f32 q8, q1, d5[0]\n" + "vmla.f32 q10, q1, d5[1]\n" + "vld1.32 {d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" // Reload RHS + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + "add r1, r1, #1\n" + "cmp r1, r2\n" + + "blt 2b\n" + + "79:\n" + + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last level of depth, for which the LHS + // and RHS data is already loaded. + + "vmla.f32 q3, q0, d4[0]\n" + "vmla.f32 q5, q0, d4[1]\n" + "vmla.f32 q7, q0, d5[0]\n" + "vmla.f32 q9, q0, d5[1]\n" + + "vmla.f32 q4, q1, d4[0]\n" + "vmla.f32 q6, q1, d4[1]\n" + "vmla.f32 q8, q1, d5[0]\n" + "vmla.f32 q10, q1, d5[1]\n" + + // End of accumulation. The registers q3 -- q10 contain the final + // float32 accumulator values of the current 8x8 destination block. + // We now have to compute the final values from these accumulators + // and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r1, r3\n" // Have we finished the last row? + + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "add r4, r4, r1, lsl #3\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "b 5f\n" + "4:\n" // Finished last row... + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + // Go back to first row + "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. 
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "add r10, r10, r1, lsl #2\n" + "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "mov %[lhs_ptr], r4\n" + "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "mov %[rhs_ptr], r5\n" + + // Load some parameters needed for the end work on current block. + "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Offset these base pointers as needed given the current row, col. + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "add r5, r1, r8, lsl #2\n" + + "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "it ne\n" + "movne r1, r5\n" + + // Load 8 bias values. + "vld1.32 {d24, d25, d26, d27}, [r1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into q0 -- q2, as we don't need q0 -- q2 anymore + // in the rest of the work on the current block. + // Load q0, q1 + "vld1.32 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + // Load q2 + "vld1.32 {d4, d5}, [%[rhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "vadd.f32 q3, q3, q12\n" + "vadd.f32 q4, q4, q13\n" + "vadd.f32 q5, q5, q12\n" + "vadd.f32 q6, q6, q13\n" + "vadd.f32 q7, q7, q12\n" + "vadd.f32 q8, q8, q13\n" + "vadd.f32 q9, q9, q12\n" + "vadd.f32 q10, q10, q13\n" + + // Load the clamp_min, clamp_max bounds + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.32 q12, r2\n" // clamp_min + "vdup.32 q13, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.f32 q3, q3, q12\n" + "vmax.f32 q4, q4, q12\n" + "vmax.f32 q5, q5, q12\n" + "vmax.f32 q6, q6, q12\n" + "vmax.f32 q7, q7, q12\n" + "vmax.f32 q8, q8, q12\n" + "vmax.f32 q9, q9, q12\n" + "vmax.f32 q10, q10, q12\n" + + // Apply the clamp_max bound + "vmin.f32 q3, q3, q13\n" + "vmin.f32 q4, q4, q13\n" + "vmin.f32 q5, q5, q13\n" + "vmin.f32 q6, q6, q13\n" + "vmin.f32 q7, q7, q13\n" + "vmin.f32 q8, q8, q13\n" + "vmin.f32 q9, q9, q13\n" + "vmin.f32 q10, q10, q13\n" + + // Compute how much of the 8x4 block of destination values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x4, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #8\n" + "mov r5, #4\n" + "cmp r1, #8\n" + // Compute r1 = how many rows of the 8x4 block fit + "it gt\n" + "movgt r1, r3\n" + "cmp r2, #4\n" + // Compute r2 = how many cols of the 8x4 block fit + "it gt\n" + "movgt r2, r5\n" + + // Test if r1==8 && r2 == 4, i.e. if all of the 8x4 block fits. + "cmp r1, r3\n" + "it eq\n" + "cmpeq r2, r5\n" + // Yes, all of the 8x4 block fits, go to fast path. 
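The boundary handling set up here can be summarized by this scalar sketch (a hypothetical helper, assuming the column-major 8x4 scratch block that the asm writes to dst_tmp_buf):

// Copy the part of an 8x4 scratch block that fits into the destination,
// mirroring the dst_tmp_buf slow path of the asm above.
void CopyPartialBlock8x4(const float* block, float* dst, int dst_stride,
                         int rows_fit, int cols_fit) {
  for (int c = 0; c < cols_fit; ++c) {
    for (int r = 0; r < rows_fit; ++r) {
      dst[c * dst_stride + r] = block[c * 8 + r];
    }
  }
}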
+ "beq 30f\n" + // Not all of the 8x4 block fits. + // Set (r3 address, r4 stride) to write to dst_tmp_buf + "mov r3, %[dst_tmp_buf]\n" + "mov r4, #32\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x4 block fits. + // Set (r3 address, r4 stride) to write directly to destination matrix. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r5\n" + "31:\n" + + // Write our float values to the destination described by + // (r3 address, r4 stride) + "vst1.32 {d6, d7, d8, d9}, [r3]\n" + "add r3, r3, r4\n" + RUY_MAKE_ZERO(q3) + RUY_MAKE_ZERO(q4) + "vst1.32 {d10, d11, d12, d13}, [r3]\n" + "add r3, r3, r4\n" + RUY_MAKE_ZERO(q5) + RUY_MAKE_ZERO(q6) + "vst1.32 {d14, d15, d16, d17}, [r3]\n" + "add r3, r3, r4\n" + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + "vst1.32 {d18, d19, d20, d21}, [r3]\n" + "add r3, r3, r4\n" + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + + // If all of the 8x4 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "mov r3, %[dst_tmp_buf]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r6, #0\n" + "50:\n" + "mov r5, #0\n" + "51:\n" + "ldr r10, [r3, r5, lsl #2]\n" + "str r10, [r4, r5, lsl #2]\n" + "add r5, r5, #1\n" + "cmp r5, r1\n" + "blt 51b\n" + "add r6, r6, #1\n" + "add r3, r3, #32\n" + "add r4, r4, r8\n" + // r2 = how many cols of the 8x4 block fit + "cmp r6, r2\n" + "blt 50b\n" + "41:\n" + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #32\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + // Reload some params --- we had used r3, r5, r10 for a few other things + // since the last time we had loaded them. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r8, r3\n" + + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add r8, r8, #8\n" + // Store new value of row + "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "b 21f\n" + "20:\n" + // Was already at end row. + // Move back to first row. + "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + // Move to the next column. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "add r4, r4, #4\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + // Increment dst_col_ptr by 4 * dst_stride (i.e. 
4 columns) + "add r1, r1, r8, lsl #2\n" + // Store dst_col_ptr + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + // Store dst_ptr + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" + + // r1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 1. + "mov r1, #1\n" + + "ble 1b\n" + + // Restore stack pointer. + "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + // clang-format on + : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) + : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) + // Clobber list must specify q registers (and not their constituent + // d registers). There is a (currently unexplained) slowdown if + // d registers are listed in the clobbers list. + : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", + "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q12", "q13"); +} + +#undef RUY_MAKE_ZERO +#undef RUY_STACK_OFFSET_SIZE +#undef RUY_STACK_OFFSET_DST_COL_PTR +#undef RUY_STACK_OFFSET_DST_PTR +#undef RUY_STACK_OFFSET_ROW +#undef RUY_STACK_OFFSET_COL +#undef RUY_STACK_OFFSET_LHS_COL_PTR +#undef RUY_STACK_OFFSET_RHS_COL_PTR + +#undef RUY_OFFSET_LHS_BASE_PTR +#undef RUY_OFFSET_RHS_BASE_PTR +#undef RUY_OFFSET_DST_BASE_PTR +#undef RUY_OFFSET_BIAS +#undef RUY_OFFSET_START_ROW +#undef RUY_OFFSET_START_COL +#undef RUY_OFFSET_LAST_ROW +#undef RUY_OFFSET_LAST_COL +#undef RUY_OFFSET_DST_ROWS +#undef RUY_OFFSET_DST_COLS +#undef RUY_OFFSET_LHS_STRIDE +#undef RUY_OFFSET_RHS_STRIDE +#undef RUY_OFFSET_DST_STRIDE +#undef RUY_OFFSET_DEPTH +#undef RUY_OFFSET_CLAMP_MIN +#undef RUY_OFFSET_CLAMP_MAX +#undef RUY_OFFSET_FLAGS + +#define RUY_OFFSET_BIAS 0 +#define RUY_OFFSET_LHS_SUMS 4 +#define RUY_OFFSET_RHS_SUMS 8 +#define RUY_OFFSET_LHS_BASE_PTR 12 +#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 16 +#define RUY_OFFSET_MULTIPLIER_EXPONENT 20 +#define RUY_OFFSET_RHS_BASE_PTR 24 +#define RUY_OFFSET_DST_BASE_PTR 28 +#define RUY_OFFSET_LHS_ZERO_POINT 32 +#define RUY_OFFSET_RHS_ZERO_POINT 36 +#define RUY_OFFSET_DST_ZERO_POINT 40 +#define RUY_OFFSET_PROD_ZP_DEPTH 44 +#define RUY_OFFSET_START_ROW 48 +#define RUY_OFFSET_START_COL 52 +#define RUY_OFFSET_LAST_ROW 56 +#define RUY_OFFSET_LAST_COL 60 +#define RUY_OFFSET_DST_ROWS 64 +#define RUY_OFFSET_DST_COLS 68 +#define RUY_OFFSET_LHS_STRIDE 72 +#define RUY_OFFSET_RHS_STRIDE 76 +#define RUY_OFFSET_DST_STRIDE 80 +#define RUY_OFFSET_DEPTH 84 +#define RUY_OFFSET_CLAMP_MIN 88 +#define RUY_OFFSET_CLAMP_MAX 92 +#define RUY_OFFSET_FLAGS 96 +#define RUY_OFFSET_DST_TYPE_ID 97 + +#define RUY_STACK_OFFSET_SIZE 96 +#define RUY_STACK_OFFSET_DST_COL_PTR 0 +#define RUY_STACK_OFFSET_DST_PTR 16 +#define RUY_STACK_OFFSET_ROW 32 +#define RUY_STACK_OFFSET_COL 48 +#define RUY_STACK_OFFSET_LHS_COL_PTR 64 +#define RUY_STACK_OFFSET_RHS_COL_PTR 80 + +template +void CheckOffsetsInKernelParams8bit(const Params&) { + static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT, + ""); + static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT, + ""); + static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT, + ""); + static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH, + ""); + static_assert(offsetof(Params, multiplier_fixedpoint) 
== + RUY_OFFSET_MULTIPLIER_FIXEDPOINT, + ""); + static_assert( + offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT, + ""); + static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); + static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); + static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); + static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, ""); + static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, ""); + static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); + static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); + static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); + static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); + static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); + static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); + static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); + static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); + static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); +} + +// Fast-int8 kernel, ported from ARM 64 version. +// Relevant target CPUs for this kernel include Krait 400 and A9, +// since these are 32-bit, out-of-order CPUs. +void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 2>& params) { + profiler::ScopeLabel label( + "Kernel (kNeon, optimized for out-of-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + + // The asm kernel below has the following NEON register allocation: + // + // q6 - q13 are 128-bit (4x32b) accumulators. + // During accumulation, d0 -- d7 are used to load int8 data from LHS and + // d8 -- d11 from RHS: + // int8 RHS 16x2 block + // /-----------------------------\ + // |d8.b[0-7] ..... d10.b[0-7]| + // | ... ... | + // |d9.b[0-7] ..... d11.b[0-7]| + // \-----------------------------/ + // int8 LHS 4x16 block + // /------------------------\ /-----------------------------\ + // |d0.b[0-7] ... d1.b[0-7] | | q6 ..... q10 | + // |d2.b[0-7] ... d3.b[0-7] | | q7 ..... q11 | + // (Reload d0, d1, d2, d3) + // |d0.b[0-7] ... d1.b[0-7] | | q8 ..... q12 | + // |d2.b[0-7] ... d3.b[0-7] | | q9 ..... q13 | + // \------------------------/ \-----------------------------/ + // 128-bit accumulators 4x2 block + // + // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING + // optimization for this kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n" + + // clang-format off + + // Load the first 64 bytes of LHS and RHS data. + "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" + // Clear accumulators. 
+ RUY_MAKE_ZERO(q6) + "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" + RUY_MAKE_ZERO(q11) + "vld1.8 {d10, d11}, [%[rhs_ptr]]!\n" + + "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + RUY_MAKE_ZERO(q12) + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + RUY_MAKE_ZERO(q13) + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + RUY_MAKE_ZERO(q14) + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" + RUY_MAKE_ZERO(q15) + "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + + + // r1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov r1, #16\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // r1 is how many levels of depth we have already loaded + // data for, r10 is the total depth. + "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + "cmp r1, r10\n" + "beq 79f\n" + + "2:\n" + + // Mult, mult-acc in to q14, q15, q2, q3 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q2, d0, d10\n" + + "vmull.s8 q15, d2, d8\n" + "vmull.s8 q3, d2, d10\n" + + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q2, d1, d11\n" + "vmlal.s8 q15, d3, d9\n" + "vmlal.s8 q3, d3, d11\n" + "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS + + // Then pairwise accumulate in to q6, q7, q10, q11 + "vpadal.s16 q6, q14\n" + "vpadal.s16 q7, q15\n" + "vpadal.s16 q10, q2\n" + "vpadal.s16 q11, q3\n" + + // Mult, mult-acc in to q14, q15, q2, q3 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q2, d0, d10\n" + + "vmull.s8 q15, d2, d8\n" + "vmull.s8 q3, d2, d10\n" + + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q2, d1, d11\n" + "vmlal.s8 q15, d3, d9\n" + "vmlal.s8 q3, d3, d11\n" + "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS + + // Then pairwise accumulate in to q8, q9, q12, q13 + "vpadal.s16 q8, q14\n" + "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" + "vpadal.s16 q9, q15\n" + "vpadal.s16 q12, q2\n" + "vpadal.s16 q13, q3\n" + + // Prefetch the next 64 bytes of LHS and RHS data. + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + // Each iteration of this loop advances by 16 levels of depth. 
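The vmull.s8 / vmlal.s8 / vpadal.s16 sequence in the loop below is the NEON spelling of plain widening multiply-accumulate. A scalar sketch of one 16-deep step of the 4x2 block follows (the helper name and the per-block packed layout are assumptions based on the register comments above):

#include <cstdint>

// One 16-deep step: each int8*int8 product is widened and added into the
// int32 accumulators, which is what the multiply/pairwise-add chain does.
void RefInt8Step4x2(const std::int8_t* lhs_block, const std::int8_t* rhs_block,
                    std::int32_t acc[4][2]) {
  for (int r = 0; r < 4; ++r) {
    for (int c = 0; c < 2; ++c) {
      for (int d = 0; d < 16; ++d) {
        acc[r][c] += static_cast<std::int32_t>(lhs_block[r * 16 + d]) *
                     static_cast<std::int32_t>(rhs_block[c * 16 + d]);
      }
    }
  }
}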
+ "add r1, r1, #16\n" + + // Loop termination condition + "cmp r1, r10\n" + + "blt 2b\n" + + "79:\n" + + // Mult, mult-acc in to q14, q15, q2, q3 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q2, d0, d10\n" + + "vmull.s8 q15, d2, d8\n" + "vmull.s8 q3, d2, d10\n" + + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q2, d1, d11\n" + "vmlal.s8 q15, d3, d9\n" + "vmlal.s8 q3, d3, d11\n" + "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" // Reload LHS + + // Then pairwise accumulate in to q6, q7, q10, q11 + "vpadal.s16 q6, q14\n" + "vpadal.s16 q7, q15\n" + "vpadal.s16 q10, q2\n" + "vpadal.s16 q11, q3\n" + + // Mult, mult-acc in to q14, q15, q2, q3 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q2, d0, d10\n" + + "vmull.s8 q15, d2, d8\n" + "vmull.s8 q3, d2, d10\n" + + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q2, d1, d11\n" + "vmlal.s8 q15, d3, d9\n" + "vmlal.s8 q3, d3, d11\n" + + // Then pairwise accumulate in to q8, q9, q12, q13 + "vpadal.s16 q8, q14\n" + "vpadal.s16 q9, q15\n" + "vpadal.s16 q12, q2\n" + "vpadal.s16 q13, q3\n" + + + // All accumulation over depth done. q6 - q13 contain the 4x32b + // accumulators for the 4x2 final matrix. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 4x2 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // q6-q13 now contain 4 x 32b + "vpadd.i32 d0, d12, d13\n" + "vpadd.i32 d1, d14, d15\n" + "vpadd.i32 d2, d16, d17\n" + "vpadd.i32 d3, d18, d19\n" + "vpadd.i32 d4, d20, d21\n" + "vpadd.i32 d5, d22, d23\n" + "vpadd.i32 d6, d24, d25\n" + "vpadd.i32 d7, d26, d27\n" + + // d0-d7 each contain 2 x 32b accumulators. + // Need to add pairwise to get 1 x 32b for each of the 4x2 entries + // of destination, (Four 'd' registers total) + "vpadd.i32 d28, d0, d1\n" + "vpadd.i32 d29, d2, d3\n" + "vpadd.i32 d30, d4, d5\n" + "vpadd.i32 d31, d6, d7\n" + + //Now d28 - d31 have the 1 x 32b accumulators for the 4x2 entries + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r1, r3\n" // Have we finished the last row? + + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "add r4, r4, r1, lsl #2\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "b 5f\n" + "4:\n" // Finished last row... + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + // Go back to first row + "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. 
The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "add r10, r10, r1, lsl #1\n" + "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "mov %[lhs_ptr], r4\n" + "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "mov %[rhs_ptr], r5\n" + + // Now we load: bias data, LHS sums data, RHS sums data. + + // First, load the base pointers from the params. + "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Offset these base pointers as needed given the current row, col. + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "add r5, r1, r8, lsl #2\n" + + "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "it ne\n" + "movne r1, r5\n" + + // Load 4 bias values. + "vld1.32 {d24, d25}, [r1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "vld1.8 {d0, d1, d2, d3}, [%[lhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + "vld1.8 {d8, d9, d10, d11}, [%[rhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + // Add to the bias values the product + // (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in + // https://arxiv.org/pdf/1712.05877.pdf + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "vdup.32 q9, r3\n" + "vadd.i32 q12, q12, q9\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
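The bias and zero-point folding performed here and in the next few instructions corresponds to equation (7) of https://arxiv.org/pdf/1712.05877.pdf. In scalar form (a sketch; the helper name is illustrative):

#include <cstdint>

// Value of destination entry (r, c) before requantization:
//   acc + bias[r] + depth * lhs_zero_point * rhs_zero_point
//       - lhs_zero_point * rhs_sums[c] - rhs_zero_point * lhs_sums[r]
// prod_zp_depth is the precomputed depth * lhs_zero_point * rhs_zero_point.
std::int32_t ApplyZeroPointTerms(std::int32_t acc, std::int32_t bias,
                                 std::int32_t prod_zp_depth,
                                 std::int32_t lhs_zero_point,
                                 std::int32_t rhs_zero_point,
                                 std::int32_t lhs_sum_of_row,
                                 std::int32_t rhs_sum_of_col) {
  return acc + bias + prod_zp_depth - lhs_zero_point * rhs_sum_of_col -
         rhs_zero_point * lhs_sum_of_row;
}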
+ "vadd.i32 q14, q14, q12\n" + "vadd.i32 q15, q15, q12\n" + + // LHS/RHS zero points + // Has RHS sums + "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + // Offset by current col * number of bytes per value + "add r3, r3, r4, lsl #2\n" + "vld1.32 { d12 }, [r3]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "vdup.32 q10, r5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "vmls.i32 q14, q10, d12[0]\n" + "vmls.i32 q15, q10, d12[1]\n" + "401:\n" + + // Has LHS sums + "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + // Offset by current row * number of bytes per value + "add r2, r2, r4, lsl #2\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + + // Load 4 lhs_sums values. + "vld1.32 {d22, d23}, [r2]\n" + "vdup.32 d13, r5\n" // rhs_zero_point + + // Compute lhs_sums * rhs_zero_point. + "vmul.i32 q11, q11, d13[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "vsub.s32 q14, q14, q11\n" + "vsub.s32 q15, q15, q11\n" + + // If the destination is int32, it means the user asks for the raw + // accumulators, no need for us to downquantize the value. + "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "add r5, r1, r4, lsl #2\n" + "it ne\n" + "movne r1, r5\n" + + "vld1.32 {q10}, [r1]\n" + + RUY_MAKE_ZERO(q8) + "vmax.s32 q12, q10, q8\n" + + "vshl.s32 q14, q14, q12\n" + "vshl.s32 q15, q15, q12\n" + + "vmin.s32 q12, q10, q8\n" + + // Load fixed point part of the multiplier + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + // r6 has flags, r4 has row + "add r5, r1, r4, lsl #2\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "it ne\n" + "movne r1, r5\n" + "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint + + // Apply the fixed-point part of the multiplier. + "vqrdmulh.s32 q14, q14, q10\n" + "vqrdmulh.s32 q15, q15, q10\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. 
They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. + "vand q8, q14, q12\n" + "vand q9, q15, q12\n" + "vshr.s32 q8, q8, #31\n" + "vshr.s32 q9, q9, #31\n" + "vqadd.s32 q14, q14, q8\n" + "vqadd.s34 q15, q15, q9\n" + +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. + "vrshl.s32 q14, q14, q12\n" + "vrshl.s32 q15, q15, q12\n" + + "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + // Store uint8 values: + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in q14. + "vqmovn.s32 d28, q14\n" + "vqmovn.s32 d29, q15\n" + + // At this point, d12 -- d26, d30, d31 aren't used anymore for the + // current block, so we can start clearing these accumulators for the + // next block (next iteration of the main loop). + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q15) + + // Load the destination zero point into each of the 8 16-bit slots + // in a q register. + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.16 q13, r4\n" // dst_zero_point + + // Add the destination zero point + "vadd.i16 q14, q14, q13\n" + + // Cast-and-saturate from int16 to uint8 + // Now all 8 1-byte values are in d30. + "vqmovun.s16 d30, q14\n" + + // Load the clamp_min, clamp_max bounds + "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.8 d28, r2\n" // clamp_min + "vdup.8 d29, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.u8 d30, d30, d28\n" + // Apply the clamp_max bound + "vmin.u8 d30, d30, d29\n" + + // Compute how much of the 4x2 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x2 blocks along the boundaries that do + // not fit entirely. 
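For reference, the "round to nearest, break ties away from zero" behavior that the vand/vshr/vqadd fixup above arranges (when RUY_OPT_NATIVE_ROUNDING is disabled) is, in scalar form, the following; this is a sketch only, with exponent taken as the right-shift amount, assumed non-negative:

#include <cstdint>

// Rounding divide by 2^exponent with ties broken away from zero.
std::int32_t RoundingDivideByPOTAwayFromZero(std::int32_t x, int exponent) {
  const std::int32_t mask = (std::int32_t{1} << exponent) - 1;
  const std::int32_t remainder = x & mask;
  const std::int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}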
+ + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + "cmp r2, #2\n" + // Compute r2 = how many cols of the 4x2 block fit + "it gt\n" + "movgt r2, r5\n" + + // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. + "cmp r1, r3\n" + "it eq\n" + "cmpeq r2, r5\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x2 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x2 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.8 {d30}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "mov r6, #0\n" + "50:\n" + "mov r8, #0\n" + "51:\n" + "ldrb r10, [r3, r8]\n" + "strb r10, [r4, r8]\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "add r6, r6, #1\n" + "add r3, r3, #4\n" + "add r4, r4, r5\n" + "cmp r6, r2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x2 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #1\n" + + "vst1.32 {d30[0]}, [r3]\n" + "add r4, r4, r5\n" + "mov r3, r4\n" + "vst1.32 {d30[1]}, [r3]\n" + + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #4\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q14) + RUY_MAKE_ZERO(q15) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + // Store int8 values: + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in q14. + "vqmovn.s32 d28, q14\n" + "vqmovn.s32 d29, q15\n" + + // At this point, d12 -- d26, d30, d31 aren't used anymore for the + // current block, so we can start clearing these accumulators for the + // next block (next iteration of the main loop). + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q15) + + // Load the destination zero point into each of the 8 16-bit slots + // in a q register. + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.16 q13, r4\n" // dst_zero_point + + // Add the destination zero point + "vadd.i16 q14, q14, q13\n" + + // Cast-and-saturate from int16 to int8 + // Now all 8 1-byte values are in d30. + "vqmovn.s16 d30, q14\n" + + // Load the clamp_min, clamp_max bounds + "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.8 d28, r2\n" // clamp_min + "vdup.8 d29, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.s8 d30, d30, d28\n" + // Apply the clamp_max bound + "vmin.s8 d30, d30, d29\n" + + // Compute how much of the 4x2 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x2 blocks along the boundaries that do + // not fit entirely. 
+ + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + "cmp r2, #2\n" + // Compute r2 = how many cols of the 4x2 block fit + "it gt\n" + "movgt r2, r5\n" + + // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. + "cmp r1, r3\n" + "it eq\n" + "cmpeq r2, r5\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x2 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x2 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.8 {d30}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "mov r6, #0\n" + "50:\n" + "mov r8, #0\n" + "51:\n" + "ldrb r10, [r3, r8]\n" + "strb r10, [r4, r8]\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "add r6, r6, #1\n" + "add r3, r3, #4\n" + "add r4, r4, r5\n" + "cmp r6, r2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x2 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #1\n" + + "vst1.32 {d30[0]}, [r3]\n" + "add r4, r4, r5\n" + "mov r3, r4\n" + "vst1.32 {d30[1]}, [r3]\n" + + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #4\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q14) + RUY_MAKE_ZERO(q15) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Load the destination zero point into each of the 4 32-bit slots + // in a q register. + "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.32 q13, r4\n" // dst_zero_point + // Add the destination zero point + "vadd.s32 q14, q14, q13\n" + "vadd.s32 q15, q15, q13\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in q14. + "vqmovn.s32 d28, q14\n" + "vqmovn.s32 d29, q15\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q15) + + // Load the clamp_min, clamp_max bounds + "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.16 q12, r2\n" // clamp_min + "vdup.16 q13, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.s16 q14, q14, q12\n" + // Apply the clamp_max bound + "vmin.s16 q14, q14, q13\n" + + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + + // Compute how much of the 4x2 block of destination 16-bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x2 blocks along the boundaries that do + // not fit entirely. 
+ + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + "cmp r2, #2\n" + // Compute r2 = how many cols of the 4x2 block fit + "it gt\n" + "movgt r2, r5\n" + + // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. + "cmp r1, r3\n" + "it eq\n" + "cmpeq r2, r5\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x2 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x2 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.16 {q14}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "mov r6, #0\n" + "50:\n" + "mov r8, #0\n" + "51:\n" + // Shift of offset register for half-word loads not allowed in A32, + // so we shift, load/store, then shift back r8. + "lsl r8, r8, #1\n" + "ldrh r10, [r3, r8]\n" + "strh r10, [r4, r8]\n" + "lsr r8, r8, #1\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "add r6, r6, #1\n" + "add r3, r3, #8\n" + "add r4, r4, r5\n" + "cmp r6, r2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x2 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #2\n" + + "vst1.16 {d28[0]}, [r3], r6\n" + "add r4, r4, r5\n" + "vst1.16 {d28[1]}, [r3], r6\n" + "vst1.16 {d28[2]}, [r3], r6\n" + "vst1.16 {d28[3]}, [r3], r6\n" + "mov r3, r4\n" + "vst1.16 {d29[0]}, [r3], r6\n" + "vst1.16 {d29[1]}, [r3], r6\n" + "vst1.16 {d29[2]}, [r3], r6\n" + "vst1.16 {d29[3]}, [r3], r6\n" + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #8\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q14) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // At this point, v20 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + // Clear accumulators. + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + + // Compute how much of the 4x2 block of destination 32 bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x4 blocks along the boundaries that do + // not fit entirely. 
+ + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + "cmp r2, #2\n" + // Compute r2 = how many cols of the 4x2 block fit + "it gt\n" + "movgt r2, r5\n" + + // Test if r1==4 && r2 == 2, i.e. if all of the 4x2 block fits. + "cmp r1, r3\n" + "it eq\n" + "cmpeq r2, r5\n" + // Yes, all of the 4x2 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x2 block fits. + // Set (r3 address, r4 stride) to write to dst_tmp_buf + "mov r3, %[dst_tmp_buf]\n" + "mov r4, #16\n" + "b 31f\n" + + "30:\n" + // Yes, all of the 4x2 block fits. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // r3 address, r4 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r5\n" + + "31:\n" + + "vst1.32 {d28, d29}, [r3]\n" + "add r3, r3, r4\n" + "vst1.32 {d30, d31}, [r3]\n" + + // If all of the 4x2 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 4x2 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "mov r3, %[dst_tmp_buf]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r6, #0\n" + "50:\n" + "mov r5, #0\n" + "51:\n" + "ldr r10, [r3, r5, lsl #2]\n" + "str r10, [r4, r5, lsl #2]\n" + "add r5, r5, #1\n" + "cmp r5, r1\n" + "blt 51b\n" + "add r6, r6, #1\n" + "add r3, r3, #16\n" + "add r4, r4, r8\n" + // r2 = how many cols of the 8x4 block fit + "cmp r6, r2\n" + "blt 50b\n" + + "41:\n" + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #16\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r8, r3\n" + + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add r8, r8, #4\n" + // Store new value of row + "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "b 21f\n" + "20:\n" + // Was already at end row. + // Move back to first row. + "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + // Move to the next column. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "add r4, r4, #2\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + // Increment dst_col_ptr by 2 * dst_stride (i.e. 
2 columns) + "add r1, r1, r8, lsl #1\n" + // Store dst_col_ptr + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + // Store dst_ptr + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov r1, #16\n" + + "ble 1b\n" + + // Restore stack pointer. + "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + // clang-format on + + : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) + : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) + : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", + // Clobber list must specify q registers (and not their constituent + // d registers). There is a (currently unexplained) slowdown if + // d registers are listed in the clobbers list. + "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q12", "q13", "q14", "q15"); +} + +// Fast-int8 true "GEMV" kernel (RHS has 1 column). We assume the RHS +// is still packed as if it has two columns +void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 2>& params) { + profiler::ScopeLabel label( + "Kernel (kNeon, optimized for out-of-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + + // The asm kernel below has the following NEON register allocation: + // + // q6 - q13 are 128-bit (4x32b) accumulators. + // During accumulation, d0 -- d7 are used to load int8 data from LHS and + // d8 -- d11 from RHS: + // int8 RHS 16x1 block + // /------------\ + // | d8.b[0] | + // | ... | + // | d8.b[7] | + // | d9.b[0] | + // | ... | + // | d9.b[7] | + // \------------/ + // int8 LHS 4x16 block + // /-----------------------------------------\ /------------\ + // |d0.b[0] ... d0.b[7] d1.b[0] ... d1.b[7] | | q6 | + // |d2.b[0] ... d2.b[7] d3.b[0] ... d3.b[7] | | q7 | + // |d4.b[0] ... d4.b[7] d5.b[0] ... d5.b[7] | | q8 | + // |d6.b[0] ... d6.b[7] d7.b[0] ... d7.b[7] | | q9 | + // \-----------------------------------------/ \------------/ + // 128-bit accumulators 4x1 block + // + // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING + // optimization for this kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "vmov.i32 " #reg ", #0x00000000\n" + + // clang-format off + + // Load the first 64 bytes of LHS and RHS data. + "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" + "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" + "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" + "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" + "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" + // Skip the other column and advance the pointer. 
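The "skip the other column" pointer bumps in this GEMV kernel follow from the packing assumption stated above: each 16-deep block of the packed RHS stores two 16-byte columns even when only one is real. A scalar sketch of the corresponding 4x1 accumulation (layout and names are assumptions taken from the register comments):

#include <cstdint>

// 4x1 int8 GEMV accumulation over a packed LHS (4 rows x 16 bytes per 16-deep
// block) and a packed RHS that advances by 2 columns x 16 bytes per block.
void RefInt8Gemv4x1(const std::int8_t* lhs, const std::int8_t* rhs, int depth,
                    std::int32_t acc[4]) {
  for (int d0 = 0; d0 < depth; d0 += 16) {
    const std::int8_t* rhs_block = rhs + (d0 / 16) * 32;  // 16 used + 16 skipped
    for (int r = 0; r < 4; ++r) {
      const std::int8_t* lhs_block = lhs + (d0 / 16) * 64 + r * 16;
      for (int d = 0; d < 16; ++d) {
        acc[r] += static_cast<std::int32_t>(lhs_block[d]) *
                  static_cast<std::int32_t>(rhs_block[d]);
      }
    }
  }
}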
+ "add %[rhs_ptr], %[rhs_ptr], #16\n" + + "sub sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_BASE_PTR) "]\n" + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_START_COL) "]\n" + "str r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_RHS_BASE_PTR) "]\n" + "str r2, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + + // Clear accumulators. + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q14) + RUY_MAKE_ZERO(q15) + + // r1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov r1, #16\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // r1 is how many levels of depth we have already loaded + // data for, r10 is the total depth. + "ldr r10, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + "cmp r1, r10\n" + "beq 79f\n" + + "2:\n" + + // Mult, mult-acc in to q14, q15 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q15, d2, d8\n" + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q15, d3, d9\n" + + // Then pairwise accumulate in to q6, q7 + "vpadal.s16 q6, q14\n" + "vpadal.s16 q7, q15\n" + + // Mult, mult-acc in to q14, q15 + "vmull.s8 q14, d4, d8\n" + "vmull.s8 q15, d6, d8\n" + "vmlal.s8 q14, d5, d9\n" + "vmlal.s8 q15, d7, d9\n" + + // Then pairwise accumulate in to q8, q9 + "vpadal.s16 q8, q14\n" + "vpadal.s16 q9, q15\n" + + + // Load the next 64 bytes of LHS and RHS data. + "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" + "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" + "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" + "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" + // Skip the other column and advance the pointer. + "add %[rhs_ptr], %[rhs_ptr], #16\n" + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + // Each iteration of this loop advances by 16 levels of depth. + "add r1, r1, #16\n" + + // Loop termination condition + "cmp r1, r10\n" + + "blt 2b\n" + + "79:\n" + + // Mult, mult-acc in to q14, q15 + "vmull.s8 q14, d0, d8\n" + "vmull.s8 q15, d2, d8\n" + "vmlal.s8 q14, d1, d9\n" + "vmlal.s8 q15, d3, d9\n" + + // Then pairwise accumulate in to q6, q7 + "vpadal.s16 q6, q14\n" + "vpadal.s16 q7, q15\n" + + // Mult, mult-acc in to q14, q15 + "vmull.s8 q14, d4, d8\n" + "vmull.s8 q15, d6, d8\n" + "vmlal.s8 q14, d5, d9\n" + "vmlal.s8 q15, d7, d9\n" + + // Then pairwise accumulate in to q8, q9 + "vpadal.s16 q8, q14\n" + "vpadal.s16 q9, q15\n" + + // All accumulation over depth done. q6 - q9 contain the 4x32b + // accumulators for the 4x1 final matrix. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 4x2 block. 
We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // q6-q9 now contain 4 x 32b + "vpadd.i32 d0, d12, d13\n" + "vpadd.i32 d1, d14, d15\n" + "vpadd.i32 d2, d16, d17\n" + "vpadd.i32 d3, d18, d19\n" + + // d0-d4 each contain 2 x 32b accumulators. + // Need to add pairwise to get 1 x 32b for each of the 4x1 entries + // of destination, (Four 'd' registers total) + "vpadd.i32 d28, d0, d1\n" + "vpadd.i32 d29, d2, d3\n" + + // Now d28,d29 have the 1 x 32b accumulators for the 4x1 entries. + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r1, r3\n" // Have we finished the last row? + + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "add r4, r4, r1, lsl #2\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "b 5f\n" + "4:\n" // Finished last row... + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + // Go back to first row + "str r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + + "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "add r10, r10, r1, lsl #1\n" + "str r10, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_LHS_COL_PTR) "]\n" + "mov %[lhs_ptr], r4\n" + "ldr r5, [sp, #" RUY_STR(RUY_STACK_OFFSET_RHS_COL_PTR) "]\n" + "mov %[rhs_ptr], r5\n" + + // Now we load: bias data, LHS sums data, RHS sums data. + + // First, load the base pointers from the params. + "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Offset these base pointers as needed given the current row, col. + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "add r5, r1, r8, lsl #2\n" + + "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "it ne\n" + "movne r1, r5\n" + + // Load 4 bias values. 
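The tst/it ne/movne sequence just above selects the bias pointer for this block. As a hedged scalar paraphrase (the helper and parameter names below are illustrative, not ruy's API):

#include <cstdint>

// Hypothetical paraphrase of the bias-pointer selection above: when the
// HAS_BIAS flag is set, the bias base pointer is offset by the current row
// (4 bytes per int32 value, hence the "lsl #2" in the asm); otherwise the
// base pointer is used as-is.
const std::int32_t* SelectBiasPtr(const std::int32_t* bias_base, int row,
                                  std::uint32_t flags,
                                  std::uint32_t has_bias_flag) {
  return (flags & has_bias_flag) ? bias_base + row : bias_base;
}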
+ "vld1.32 {d24, d25}, [r1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "vld1.8 {d0, d1}, [%[lhs_ptr]]!\n" + "vld1.8 {d2, d3}, [%[lhs_ptr]]!\n" + "vld1.8 {d4, d5}, [%[lhs_ptr]]!\n" + "vld1.8 {d6, d7}, [%[lhs_ptr]]!\n" + RUY_PREFETCH_LOAD("pld [%[lhs_ptr]]\n") + "vld1.8 {d8, d9}, [%[rhs_ptr]]!\n" + // Skip the other column and advance the pointer. + "add %[rhs_ptr], %[rhs_ptr], #16\n" + RUY_PREFETCH_LOAD("pld [%[rhs_ptr]]\n") + + // Add to the bias values the product + // (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in + // https://arxiv.org/pdf/1712.05877.pdf + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "vdup.32 q9, r3\n" + "vadd.i32 q12, q12, q9\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "vadd.i32 q14, q14, q12\n" + + // LHS/RHS zero points + // Has RHS sums + "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + // Offset by current col * number of bytes per value + "add r3, r3, r4, lsl #2\n" + "vld1.32 { d12 }, [r3]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "vdup.32 q10, r5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "vmls.i32 q14, q10, d12[0]\n" + "401:\n" + + // Has LHS sums + "ldrb r6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + // Offset by current row * number of bytes per value + "add r2, r2, r4, lsl #2\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + + // Load 4 lhs_sums values. + "vld1.32 {d22, d23}, [r2]\n" + "vdup.32 d13, r5\n" // rhs_zero_point + + // Compute lhs_sums * rhs_zero_point. + "vmul.i32 q11, q11, d13[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "vsub.s32 q14, q14, q11\n" + + // If the destination is int32, it means the user asks for the raw + // accumulators, no need for us to downquantize the value. + "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. 
+ "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "add r5, r1, r4, lsl #2\n" + "it ne\n" + "movne r1, r5\n" + + "vld1.32 {q10}, [r1]\n" + + RUY_MAKE_ZERO(q8) + "vmax.s32 q12, q10, q8\n" + + "vshl.s32 q14, q14, q12\n" + + "vmin.s32 q12, q10, q8\n" + + // Load fixed point part of the multiplier + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + // r6 has flags, r4 has row + "add r5, r1, r4, lsl #2\n" + "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "it ne\n" + "movne r1, r5\n" + "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint + + // Apply the fixed-point part of the multiplier. + "vqrdmulh.s32 q14, q14, q10\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. + "vand q8, q14, q12\n" + "vshr.s32 q8, q8, #31\n" + "vqadd.s32 q14, q14, q8\n" + +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. + "vrshl.s32 q14, q14, q12\n" + + "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + // Store uint8 values: + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in d28. + "vqmovn.s32 d28, q14\n" + + // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the + // current block, so we can start clearing these accumulators for the + // next block (next iteration of the main loop). + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q15) + + // Load the destination zero point into each of the 8 16-bit slots + // in a q register. 
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.16 q13, r4\n" // dst_zero_point + + // Add the destination zero point + "vadd.i16 q14, q14, q13\n" + + // Cast-and-saturate from int16 to uint8 + "vqmovun.s16 d30, q14\n" + // At this point, we only need 4 8-bit values in the lower half + // of d30. + + + // Load the clamp_min, clamp_max bounds + "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.8 d28, r2\n" // clamp_min + "vdup.8 d29, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.u8 d30, d30, d28\n" + // Apply the clamp_max bound + "vmin.u8 d30, d30, d29\n" + + // Compute how much of the 4x1 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x1, there are some 4x1 blocks along the boundaries that do + // not fit entirely. + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x1 block fit + "it gt\n" + "movgt r1, r3\n" + + // Test if r1==4, i.e. if all of the 4x1 block fits. + "cmp r1, r3\n" + + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x1 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.8 {d30}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "50:\n" + "mov r8, #0\n" + "51:\n" + "ldrb r10, [r3, r8]\n" + "strb r10, [r4, r8]\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x1 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #1\n" + + "vst1.8 {d30[0]}, [r3], r6\n" + "vst1.8 {d30[1]}, [r3], r6\n" + "vst1.8 {d30[2]}, [r3], r6\n" + "vst1.8 {d30[3]}, [r3], r6\n" + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #4\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q14) + RUY_MAKE_ZERO(q15) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + // Store int8 values: + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in d28. + "vqmovn.s32 d28, q14\n" + + // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the + // current block, so we can start clearing these accumulators for the + // next block (next iteration of the main loop). + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q15) + + // Load the destination zero point into each of the 8 16-bit slots + // in a q register. 
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.16 q13, r4\n" // dst_zero_point + + // Add the destination zero point + "vadd.i16 q14, q14, q13\n" + + // Cast-and-saturate from int16 to int8 + "vqmovn.s16 d30, q14\n" + // At this point, we only need 4 8-bit values in the lower half + // of d30. + + // Load the clamp_min, clamp_max bounds + "ldrb r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.8 d28, r2\n" // clamp_min + "vdup.8 d29, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.s8 d30, d30, d28\n" + // Apply the clamp_max bound + "vmin.s8 d30, d30, d29\n" + + // Compute how much of the 4x1 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x2 blocks along the boundaries that do + // not fit entirely. + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + // Test if r1==4 i.e. if all of the 4x1 block fits. + "cmp r1, r3\n" + + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x2 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x2 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.8 {d30}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "50:\n" + "mov r8, #0\n" + "51:\n" + "ldrb r10, [r3, r8]\n" + "strb r10, [r4, r8]\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x1 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #1\n" + + "vst1.8 {d30[0]}, [r3], r6\n" + "vst1.8 {d30[1]}, [r3], r6\n" + "vst1.8 {d30[2]}, [r3], r6\n" + "vst1.8 {d30[3]}, [r3], r6\n" + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #4\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q13) + RUY_MAKE_ZERO(q14) + RUY_MAKE_ZERO(q15) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Load the destination zero point into each of the 4 32-bit slots + // in a q register. + "ldrsh r4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "vdup.32 q13, r4\n" // dst_zero_point + // Add the destination zero point + "vadd.s32 q14, q14, q13\n" + //"vadd.s32 q15, q15, q13\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in d28. + "vqmovn.s32 d28, q14\n" + + // At this point, d12 -- d26, d29, d30, d31 aren't used anymore for the + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q15) + + // Load the clamp_min, clamp_max bounds + "ldrh r2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrh r3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "vdup.16 d24, r2\n" // clamp_min + "vdup.16 d26, r3\n" // clamp_max + + // Apply the clamp_min bound + "vmax.s16 d28, d28, d24\n" + // Apply the clamp_max bound + "vmin.s16 d28, d28, d26\n" + + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + + // Compute how much of the 4x1 block of destination 16-bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x1, there are some 4x1 blocks along the boundaries that do + // not fit entirely. + + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x1 block fit + "it gt\n" + "movgt r1, r3\n" + + // Test if r1==4, i.e. if all of the 4x1 block fits. + "cmp r1, r3\n" + + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x1 block fits. + // Store to dst_tmp_buf + // Set r3 address to write to dst_tmp_buf. + "mov r3, %[dst_tmp_buf]\n" + "vst1.16 {d28}, [r3]\n" + + // Slow loop copying from dst_tmp_buf to dst. + "50:\n" + "mov r8, #0\n" + "51:\n" + // Shift of offset register for half-word loads not allowed in A32, + // so we shift, load/store, then shift back r8. + "lsl r8, r8, #1\n" + "ldrh r10, [r3, r8]\n" + "strh r10, [r4, r8]\n" + "lsr r8, r8, #1\n" + "add r8, r8, #1\n" + "cmp r8, r1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x1 block fits. + // r3 address, r5 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r3\n" + "mov r6, #2\n" + + "vst1.16 {d28[0]}, [r3], r6\n" + "vst1.16 {d28[1]}, [r3], r6\n" + "vst1.16 {d28[2]}, [r3], r6\n" + "vst1.16 {d28[3]}, [r3], r6\n" + "31:\n" + + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #8\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q14) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // At this point, v20 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + // Clear accumulators. + RUY_MAKE_ZERO(q6) + RUY_MAKE_ZERO(q7) + RUY_MAKE_ZERO(q8) + RUY_MAKE_ZERO(q9) + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + RUY_MAKE_ZERO(q12) + RUY_MAKE_ZERO(q13) + + // Compute how much of the 4x1 block of destination 32 bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x2, there are some 4x4 blocks along the boundaries that do + // not fit entirely. 
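The boundary handling that follows (and its analogues in the other store paths above) always computes a full block, then decides whether it can be stored directly or must go through the dst_tmp_buf slow path. A hedged scalar paraphrase for the int32 case (the helper name is illustrative, not ruy's):

#include <algorithm>
#include <cstdint>
#include <cstring>

// Hypothetical sketch of the partial-block store used below: if fewer than 4
// rows remain in the destination, the 4x1 block is first written to a scratch
// buffer and only the rows that actually fit are copied out.
void StoreInt32Block(const std::int32_t block[4], std::int32_t* dst,
                     int dst_rows, int row, std::int32_t* dst_tmp_buf) {
  const int rows_fit = std::min(dst_rows - row, 4);
  if (rows_fit == 4) {
    std::memcpy(dst, block, 4 * sizeof(std::int32_t));  // fast path
  } else {
    std::memcpy(dst_tmp_buf, block, 4 * sizeof(std::int32_t));
    for (int i = 0; i < rows_fit; ++i) {  // slow path: copy only what fits
      dst[i] = dst_tmp_buf[i];
    }
  }
}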
+ + "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_DST_ROWS) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "sub r1, r1, r8\n" + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_DST_COLS) "]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "sub r2, r2, r4\n" + "mov r3, #4\n" + "mov r5, #2\n" + "cmp r1, #4\n" + // Compute r1 = how many rows of the 4x2 block fit + "it gt\n" + "movgt r1, r3\n" + + // Test if r1==4, i.e. if all of the 4x1 block fits. + "cmp r1, r3\n" + + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x1 block fits. + // Set (r3 address, r4 stride) to write to dst_tmp_buf + "mov r3, %[dst_tmp_buf]\n" + "mov r4, #16\n" + "b 31f\n" + + "30:\n" + // Yes, all of the 4x1 block fits. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + // r3 address, r4 stride + "ldr r3, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "mov r4, r5\n" + + "31:\n" + + "vst1.32 {d28, d29}, [r3]\n" + + // If all of the 4x1 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 4x1 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "mov r3, %[dst_tmp_buf]\n" + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "50:\n" + "mov r5, #0\n" + "51:\n" + "ldr r10, [r3, r5, lsl #2]\n" + "str r10, [r4, r5, lsl #2]\n" + "add r5, r5, #1\n" + "cmp r5, r1\n" + "blt 51b\n" + + "41:\n" + // Load dst_ptr, increment, and write back. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "add r4, r4, #16\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + + RUY_MAKE_ZERO(q10) + RUY_MAKE_ZERO(q11) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr r5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr r6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + "cmp r8, r3\n" + + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add r8, r8, #4\n" + // Store new value of row + "str r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + + "b 21f\n" + "20:\n" + // Was already at end row. + // Move back to first row. + "str r6, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n" + // Move to the next column. + "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "add r4, r4, #2\n" + "str r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + + "ldr r8, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + // Increment dst_col_ptr by dst_stride (i.e. 1 column) + "add r1, r1, r8\n" + // Store dst_col_ptr + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_COL_PTR) "]\n" + // Store dst_ptr + "str r1, [sp, #" RUY_STR(RUY_STACK_OFFSET_DST_PTR) "]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? 
+ "ldr r4, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_COL) "]\n" + "cmp r8, r4\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov r1, #16\n" + + "ble 1b\n" + + // Restore stack pointer. + "add sp, sp, #" RUY_STR(RUY_STACK_OFFSET_SIZE) "\n" + + // clang-format on + + : [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr) + : [ params ] "r"(¶ms), [dst_tmp_buf] "r"(params.dst_tmp_buf) + : "r0", "r1", "r2", "r3", "r4", "r5", "r6", "r8", "r10", "cc", + // Clobber list must specify q registers (and not their constituent + // d registers). There is a (currently unexplained) slowdown if + // d registers are listed in the clobbers list. + "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", + "q9", "q10", "q12", "q13", "q14", "q15"); +} + +#undef RUY_OFFSET_BIAS +#undef RUY_OFFSET_LHS_SUMS +#undef RUY_OFFSET_RHS_SUMS +#undef RUY_OFFSET_LHS_BASE_PTR +#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT +#undef RUY_OFFSET_MULTIPLIER_EXPONENT +#undef RUY_OFFSET_RHS_BASE_PTR +#undef RUY_OFFSET_DST_BASE_PTR +#undef RUY_OFFSET_LHS_ZERO_POINT +#undef RUY_OFFSET_RHS_ZERO_POINT +#undef RUY_OFFSET_DST_ZERO_POINT +#undef RUY_OFFSET_PROD_ZP_DEPTH +#undef RUY_OFFSET_START_ROW +#undef RUY_OFFSET_START_COL +#undef RUY_OFFSET_LAST_ROW +#undef RUY_OFFSET_LAST_COL +#undef RUY_OFFSET_DST_ROWS +#undef RUY_OFFSET_DST_COLS +#undef RUY_OFFSET_LHS_STRIDE +#undef RUY_OFFSET_RHS_STRIDE +#undef RUY_OFFSET_DST_STRIDE +#undef RUY_OFFSET_DEPTH +#undef RUY_OFFSET_CLAMP_MIN +#undef RUY_OFFSET_CLAMP_MAX +#undef RUY_OFFSET_FLAGS +#undef RUY_OFFSET_DST_TYPE_ID + +#undef RUY_STACK_OFFSET_SIZE +#undef RUY_STACK_OFFSET_DST_COL_PTR +#undef RUY_STACK_OFFSET_DST_PTR +#undef RUY_STACK_OFFSET_ROW +#undef RUY_STACK_OFFSET_COL +#undef RUY_STACK_OFFSET_LHS_COL_PTR +#undef RUY_STACK_OFFSET_RHS_COL_PTR + +#endif // RUY_PLATFORM(NEON_32) && (RUY_OPT_ENABLED(RUY_OPT_ASM) +} // namespace ruy diff --git a/ruy/kernel_arm64.cc b/ruy/kernel_arm64.cc new file mode 100644 index 0000000..38af032 --- /dev/null +++ b/ruy/kernel_arm64.cc @@ -0,0 +1,7835 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "ruy/common.h" +#include "ruy/kernel.h" +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +namespace ruy { + +#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +#define RUY_ASM_LABEL_STORE_UINT8 91 +#define RUY_ASM_LABEL_STORE_INT8 92 +#define RUY_ASM_LABEL_STORE_INT16 93 +#define RUY_ASM_LABEL_STORE_INT32 94 +#define RUY_ASM_LABEL_AFTER_STORE 99 + +#define RUY_OFFSET_BIAS 0 +#define RUY_OFFSET_LHS_SUMS 8 +#define RUY_OFFSET_RHS_SUMS 16 +#define RUY_OFFSET_LHS_BASE_PTR 24 +#define RUY_OFFSET_MULTIPLIER_FIXEDPOINT 32 +#define RUY_OFFSET_MULTIPLIER_EXPONENT 40 +#define RUY_OFFSET_RHS_BASE_PTR 48 +#define RUY_OFFSET_DST_BASE_PTR 56 +#define RUY_OFFSET_LHS_ZERO_POINT 64 +#define RUY_OFFSET_RHS_ZERO_POINT 68 +#define RUY_OFFSET_DST_ZERO_POINT 72 +#define RUY_OFFSET_PROD_ZP_DEPTH 76 +#define RUY_OFFSET_START_ROW 80 +#define RUY_OFFSET_START_COL 84 +#define RUY_OFFSET_LAST_ROW 88 +#define RUY_OFFSET_LAST_COL 92 +#define RUY_OFFSET_DST_ROWS 96 +#define RUY_OFFSET_DST_COLS 100 +#define RUY_OFFSET_LHS_STRIDE 104 +#define RUY_OFFSET_RHS_STRIDE 108 +#define RUY_OFFSET_DST_STRIDE 112 +#define RUY_OFFSET_DEPTH 116 +#define RUY_OFFSET_CLAMP_MIN 120 +#define RUY_OFFSET_CLAMP_MAX 124 +#define RUY_OFFSET_FLAGS 128 + +template +void CheckOffsetsInKernelParams8bit(const Params&) { + static_assert(offsetof(Params, lhs_zero_point) == RUY_OFFSET_LHS_ZERO_POINT, + ""); + static_assert(offsetof(Params, rhs_zero_point) == RUY_OFFSET_RHS_ZERO_POINT, + ""); + static_assert(offsetof(Params, dst_zero_point) == RUY_OFFSET_DST_ZERO_POINT, + ""); + static_assert(offsetof(Params, prod_zp_depth) == RUY_OFFSET_PROD_ZP_DEPTH, + ""); + static_assert(offsetof(Params, multiplier_fixedpoint) == + RUY_OFFSET_MULTIPLIER_FIXEDPOINT, + ""); + static_assert( + offsetof(Params, multiplier_exponent) == RUY_OFFSET_MULTIPLIER_EXPONENT, + ""); + static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); + static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); + static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); + static_assert(offsetof(Params, lhs_sums) == RUY_OFFSET_LHS_SUMS, ""); + static_assert(offsetof(Params, rhs_sums) == RUY_OFFSET_RHS_SUMS, ""); + static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); + static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); + static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); + static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); + static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); + static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); + static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); + static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); + static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); +} + +// Fast-int8-trick kernel, similar to this production gemmlowp kernel: +// NEON_64bit_GEMM_Int8Operands_AccumTwoWithin16Bits +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L2296 +// +// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75, +// since these are 64-bit, out-of-order and without dotprod support. 
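The hard-coded RUY_OFFSET_* values above let the asm address KernelParams8bit fields directly (for example the "ldr w12, [%[params], #RUY_OFFSET_DEPTH]" loads below), and CheckOffsetsInKernelParams8bit pins the struct layout at compile time. A tiny, self-contained illustration of that pattern (the struct and offsets here are hypothetical, not ruy's):

#include <cstddef>
#include <cstdint>

// Hypothetical illustration only: the real kernel does this for every field
// of KernelParams8bit, so a layout change breaks the build rather than
// silently desynchronizing the asm's hard-coded byte offsets.
struct ExampleParams {
  std::int32_t start_row;  // expected at byte offset 0
  std::int32_t depth;      // expected at byte offset 4
};
static_assert(offsetof(ExampleParams, start_row) == 0, "");
static_assert(offsetof(ExampleParams, depth) == 4, "");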
+void Kernel8bitNeonOutOfOrder(const KernelParams8bit<4, 4>& params) { + profiler::ScopeLabel label( + "Kernel (kNeon, optimized for out-of-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are int32 accumulators. + // During accumulation, v0 -- v3 are used to load int8 data from LHS and + // v4 -- v7 from RHS: + // + // int8 RHS 16x4 block + // /-----------------------------------------\ + // |v4.b[0] ... v7.b[0] | + // | ... ... | + // |v4.b[15] ... v7.b[15] | + // \-----------------------------------------/ + // int8 LHS 4x16 block + // /---------------------\ /-----------------------------------------\ + // |v0.b[0] ... v0.b[15] | |v16.4s ... v28.4s | + // |v1.b[0] ... v1.b[15] | |v17.4s ... v29.4s | + // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s | + // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s | + // \---------------------/ \-----------------------------------------/ + // int32 accumulators 4x4 block + // + // No attempt had been made so far at implementing the RUY_OPT_MAX_STREAMING + // optimization for this kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 64 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov w1, #16\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + "smull v12.8h, v0.8b, v5.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "smull v14.8h, v2.8b, v5.8b\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. 
This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Some multiplications and 16-bit accumulation were already done above, + // so we start right away in the middle. + "sadalp v16.4s, v8.8h\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "smull v8.8h, v0.8b, v6.8b\n" + "sadalp v17.4s, v9.8h\n" + "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" + "smull v9.8h, v1.8b, v6.8b\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "smull v11.8h, v3.8b, v6.8b\n" + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v0.8b, v7.8b\n" + "sadalp v21.4s, v13.8h\n" + "smull v13.8h, v1.8b, v7.8b\n" + "sadalp v22.4s, v14.8h\n" + "smull v14.8h, v2.8b, v7.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + + "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" + + "smlal2 v12.8h, v0.16b, v7.16b\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "sadalp v24.4s, v8.8h\n" + "smull v8.8h, v0.8b, v4.8b\n" + "sadalp v25.4s, v9.8h\n" + "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" + "smull v9.8h, v1.8b, v4.8b\n" + "sadalp v26.4s, v10.8h\n" + "smull v10.8h, v2.8b, v4.8b\n" + "sadalp v27.4s, v11.8h\n" + "smull v11.8h, v3.8b, v4.8b\n" + "sadalp v28.4s, v12.8h\n" + "smull v12.8h, v0.8b, v5.8b\n" + "sadalp v29.4s, v13.8h\n" + "smull v13.8h, v1.8b, v5.8b\n" + "sadalp v30.4s, v14.8h\n" + "smull v14.8h, v2.8b, v5.8b\n" + "sadalp v31.4s, v15.8h\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + + + // Each iteration of this loop advances by 16 levels of depth. 
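Per destination entry, each pass of this loop consumes 16 levels of depth: smull/smlal2 fold two adjacent levels into 16-bit partial sums, and sadalp widens pairs of those into the int32 accumulator lanes (which the later addp reductions collapse into one value per entry). A hedged scalar model of that contribution (plain int arithmetic is used for clarity; the asm keeps the pair sums in 16-bit registers as the comments above describe):

#include <cstdint>

// Hypothetical scalar model of the smull/smlal2 + sadalp dataflow above:
// two adjacent depth levels are combined into one partial sum, and those
// partial sums are then accumulated into a 32-bit total.
std::int32_t DotProduct16Levels(const std::int8_t lhs[16],
                                const std::int8_t rhs[16]) {
  std::int32_t acc = 0;
  for (int i = 0; i < 16; i += 2) {
    const int pair = lhs[i] * rhs[i] + lhs[i + 1] * rhs[i + 1];
    acc += pair;  // pairwise widening accumulation, as sadalp does
  }
  return acc;
}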
+ "add w1, w1, #16\n" + + // Loop termination condition + "cmp w1, w12\n" + + "blt 2b\n" + + "79:\n" + + "sadalp v16.4s, v8.8h\n" + "smull v8.8h, v0.8b, v6.8b\n" + "sadalp v17.4s, v9.8h\n" + "smull v9.8h, v1.8b, v6.8b\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "smull v11.8h, v3.8b, v6.8b\n" + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v0.8b, v7.8b\n" + "sadalp v21.4s, v13.8h\n" + "smull v13.8h, v1.8b, v7.8b\n" + "sadalp v22.4s, v14.8h\n" + "smull v14.8h, v2.8b, v7.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + + "smlal2 v12.8h, v0.16b, v7.16b\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + + "sadalp v24.4s, v8.8h\n" + "sadalp v25.4s, v9.8h\n" + "sadalp v26.4s, v10.8h\n" + "sadalp v27.4s, v11.8h\n" + "sadalp v28.4s, v12.8h\n" + "sadalp v29.4s, v13.8h\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 4x4 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 4x4 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Reduce 32bit accumulators horizontally. + "addp v16.4s, v16.4s, v17.4s\n" + "addp v18.4s, v18.4s, v19.4s\n" + "addp v20.4s, v20.4s, v21.4s\n" + "addp v22.4s, v22.4s, v23.4s\n" + "addp v24.4s, v24.4s, v25.4s\n" + "addp v26.4s, v26.4s, v27.4s\n" + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise). + "addp v16.4s, v16.4s, v18.4s\n" + "addp v17.4s, v20.4s, v22.4s\n" + "addp v18.4s, v24.4s, v26.4s\n" + "addp v19.4s, v28.4s, v30.4s\n" + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. 
The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + RUY_MAKE_ZERO(v8) + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + "add x5, x4, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + + "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint + + // Now we load: bias data, LHS sums data, RHS sums data. + + // First, load the base pointers from the params. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + "add x5, x1, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 4 bias values. + "ld1 {v14.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "add v16.4s, v16.4s, v14.4s\n" + "add v17.4s, v17.4s, v14.4s\n" + "add v18.4s, v18.4s, v14.4s\n" + "add v19.4s, v19.4s, v14.4s\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[1]\n" + "mls v18.4s, v10.4s, v14.s[2]\n" + "mls v19.4s, v10.4s, v14.s[3]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. 
+ "mul v11.4s, v11.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v11.4s\n" + "sub v18.4s, v18.4s, v11.4s\n" + "sub v19.4s, v19.4s, v11.4s\n" + + // If the destination is int32, it means the user asks for the raw + // accumulators, no need for us to downquantize the value. + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, %x[row], lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ld1 {v14.4s}, [x1]\n" + + "smax v12.4s, v14.4s, v8.4s\n" + + "sshl v16.4s, v16.4s, v12.4s\n" + "sshl v17.4s, v17.4s, v12.4s\n" + "sshl v18.4s, v18.4s, v12.4s\n" + "sshl v19.4s, v19.4s, v12.4s\n" + + "smin v12.4s, v14.4s, v8.4s\n" + + // Apply the fixed-point part of the multiplier. + "sqrdmulh v16.4s, v16.4s, v15.4s\n" + "sqrdmulh v17.4s, v17.4s, v15.4s\n" + "sqrdmulh v18.4s, v18.4s, v15.4s\n" + "sqrdmulh v19.4s, v19.4s, v15.4s\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. + "and v8.16b, v16.16b, v12.16b\n" + "and v9.16b, v17.16b, v12.16b\n" + "and v14.16b, v18.16b, v12.16b\n" + "and v15.16b, v19.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v16.4s, v16.4s, v8.4s\n" + "sqadd v17.4s, v17.4s, v9.4s\n" + "sqadd v18.4s, v18.4s, v14.4s\n" + "sqadd v19.4s, v19.4s, v15.4s\n" +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. 
+ "srshl v16.4s, v16.4s, v12.4s\n" + "srshl v17.4s, v17.4s, v12.4s\n" + "srshl v18.4s, v18.4s, v12.4s\n" + "srshl v19.4s, v19.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + "sqxtun v16.8b, v16.8h\n" + "sqxtun2 v16.16b, v17.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #4\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. 
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[4], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[5], [x3], #1\n" + "st1 {v16.b}[6], [x3], #1\n" + "st1 {v16.b}[7], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[8], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[9], [x3], #1\n" + "st1 {v16.b}[10], [x3], #1\n" + "st1 {v16.b}[11], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[12], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[13], [x3], #1\n" + "st1 {v16.b}[14], [x3], #1\n" + "st1 {v16.b}[15], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + + // Cast-and-saturate from int16 to int8 + "sqxtn v16.8b, v16.8h\n" + "sqxtn2 v16.16b, v17.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. 
+ "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #4\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[4], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[5], [x3], #1\n" + "st1 {v16.b}[6], [x3], #1\n" + "st1 {v16.b}[7], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[8], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[9], [x3], #1\n" + "st1 {v16.b}[10], [x3], #1\n" + "st1 {v16.b}[11], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[12], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[13], [x3], #1\n" + "st1 {v16.b}[14], [x3], #1\n" + "st1 {v16.b}[15], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.4h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + "saddw v18.4s, v18.4s, v14.4h\n" + "saddw v19.4s, v19.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Load the clamp_min, clamp_max bounds + "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.8h, w2\n" // clamp_min + "dup v15.8h, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + "smax v17.8h, v17.8h, v14.8h\n" + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + "smin v17.8h, v17.8h, v15.8h\n" + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. 
+ // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + "str q17, [%[dst_tmp_buf], #16]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.h}[0], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v16.h}[1], [x3], #2\n" + "st1 {v16.h}[2], [x3], #2\n" + "st1 {v16.h}[3], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.h}[4], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v16.h}[5], [x3], #2\n" + "st1 {v16.h}[6], [x3], #2\n" + "st1 {v16.h}[7], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.h}[0], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v17.h}[1], [x3], #2\n" + "st1 {v17.h}[2], [x3], #2\n" + "st1 {v17.h}[3], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.h}[4], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v17.h}[5], [x3], #2\n" + "st1 {v17.h}[6], [x3], #2\n" + "st1 {v17.h}[7], [x3], #2\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #8\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // At this point, v20 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + "str q17, [%[dst_tmp_buf], #16]\n" + "str q18, [%[dst_tmp_buf], #32]\n" + "str q19, [%[dst_tmp_buf], #48]\n" + // Slow loop copying from dst_tmp_buf to dst. 
+ "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #16\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v16.s}[1], [x3], #4\n" + "st1 {v16.s}[2], [x3], #4\n" + "st1 {v16.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v17.s}[1], [x3], #4\n" + "st1 {v17.s}[2], [x3], #4\n" + "st1 {v17.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v18.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v18.s}[1], [x3], #4\n" + "st1 {v18.s}[2], [x3], #4\n" + "st1 {v18.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v19.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v19.s}[1], [x3], #4\n" + "st1 {v19.s}[2], [x3], #4\n" + "st1 {v19.s}[3], [x3], #4\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #16\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + "smull v12.8h, v0.8b, v5.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "smull v14.8h, v2.8b, v5.8b\n" + "smull v15.8h, v3.8b, v5.8b\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #4\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #4\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. 
+        "mov w1, #16\n"
+
+        "ble 1b\n"
+
+        // clang-format on
+
+        : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+          [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+          [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+        : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+          [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
+          [dst_type_id] "r"(params.dst_type_id)
+        : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+          "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+          "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
+          "v26", "v27", "v28", "v29", "v30", "v31");
+}
+
+// Similar to existing Kernel8bitNeonOutOfOrder but specialized for the case of
+// RHS cols == 1.
+// Relevant target CPUs for this kernel include ARM Cortex-A73 and Cortex-A75,
+// since these are 64-bit, out-of-order and without dotprod support.
+void Kernel8bitNeonOutOfOrder1Col(const KernelParams8bit<4, 4>& params) {
+  profiler::ScopeLabel label(
+      "Kernel (kNeon, optimized for out-of-order cores)");
+
+  CheckOffsetsInKernelParams8bit(params);
+
+  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* lhs_ptr = lhs_col_ptr;
+  const std::int8_t* rhs_ptr = rhs_col_ptr;
+  void* dst_col_ptr = params.dst_base_ptr;
+  void* dst_ptr = dst_col_ptr;
+  int row = params.start_row;
+  int col = params.start_col;
+
+  // The asm kernel below has the following NEON register allocation:
+  //
+  // v16 -- v19 are int32 accumulators.
+  // During accumulation, v0 -- v3 are used to load int8 data from LHS and
+  // v4 from RHS:
+  //
+  //                             int8 RHS 16x1 block
+  //                             /-----------\
+  //                             |v4.b[0]    |
+  //                             |  ...      |
+  //                             |v4.b[15]   |
+  //                             \-----------/
+  //    int8 LHS 4x16 block
+  //  /---------------------\    /-----------\
+  //  |v0.b[0] ... v0.b[15] |    |v16.4s     |
+  //  |v1.b[0] ... v1.b[15] |    |v17.4s     |
+  //  |v2.b[0] ... v2.b[15] |    |v18.4s     |
+  //  |v3.b[0] ... v3.b[15] |    |v19.4s     |
+  //  \---------------------/    \-----------/
+  //                             int32 accumulators 4x1 block
+  //
+  // No attempt has been made so far at implementing the RUY_OPT_MAX_STREAMING
+  // optimization for this kernel.
+  asm volatile(
+#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
+
+        // clang-format off
+
+        // Load some parameters into registers.
+        "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+        "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+        "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+        "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
+        "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
+        "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
+        "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
+        "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
+
+        // Load the first 64 bytes of LHS data and the first 16 bytes of RHS
+        // data (the packed RHS block is 64 bytes wide; only its first column
+        // is used here, hence the 48-byte skip).
+        "ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
+        "ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
+        "ld1 {v2.16b}, [%[lhs_ptr]], #16\n"
+        "ld1 {v3.16b}, [%[lhs_ptr]], #16\n"
+        "ld1 {v4.16b}, [%[rhs_ptr]], #16\n"
+        "add %[rhs_ptr], %[rhs_ptr], #48\n"
+
+        // Clear accumulators.
+        RUY_MAKE_ZERO(v16)
+        RUY_MAKE_ZERO(v17)
+        RUY_MAKE_ZERO(v18)
+        RUY_MAKE_ZERO(v19)
+
+        // w1 is the number of levels of depth that we have already loaded
+        // LHS and RHS data for. Corresponding to the initial ld1 instructions
+        // above, this is currently 16.
+ "mov w1, #16\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Some multiplications and 16-bit accumulation were already done above, + // so we start right away in the middle. + "sadalp v16.4s, v8.8h\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "add %[rhs_ptr], %[rhs_ptr], #48\n" + "sadalp v17.4s, v9.8h\n" + "sadalp v18.4s, v10.8h\n" + "sadalp v19.4s, v11.8h\n" + + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + // Each iteration of this loop advances by 16 levels of depth. + "add w1, w1, #16\n" + + // Loop termination condition + "cmp w1, w12\n" + + "blt 2b\n" + + "79:\n" + + "sadalp v16.4s, v8.8h\n" + "sadalp v17.4s, v9.8h\n" + "sadalp v18.4s, v10.8h\n" + "sadalp v19.4s, v11.8h\n" + + // End of accumulation. The registers v16 -- v19 contain the final + // int32 accumulator values of the current 4x1 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 4x1 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Reduce 32bit accumulators horizontally. + "addp v16.4s, v16.4s, v17.4s\n" + "addp v18.4s, v18.4s, v19.4s\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise). + "addp v16.4s, v16.4s, v18.4s\n" + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. 
+ "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + // (still multiply column stride by 4 due to packing) + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + RUY_MAKE_ZERO(v8) + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + "add x5, x4, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + + "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint + + // Now we load: bias data, LHS sums data, RHS sums data. + + // First, load the base pointers from the params. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + "add x5, x1, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 4 bias values. + "ld1 {v14.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + "add %[rhs_ptr], %[rhs_ptr], #48\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
+ // (all four 32-bit accumulators are in v16 at this point) + "add v16.4s, v16.4s, v14.4s\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. + "mul v11.4s, v11.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + + // If the destination is int32, it means the user asks for the raw + // accumulators, no need for us to downquantize the value. + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, %x[row], lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ld1 {v14.4s}, [x1]\n" + + "smax v12.4s, v14.4s, v8.4s\n" + + "sshl v16.4s, v16.4s, v12.4s\n" + + "smin v12.4s, v14.4s, v8.4s\n" + + // Apply the fixed-point part of the multiplier. + "sqrdmulh v16.4s, v16.4s, v15.4s\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. 
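+        // For reference, the following is an illustrative, non-compiled
+        // scalar sketch of the end-work applied to one int32 accumulator: the
+        // zero-point corrections of equation (7) in
+        // https://arxiv.org/pdf/1712.05877.pdf, the fixed-point multiplier,
+        // and the rounding shift whose tie-breaking is discussed above (the
+        // fixup just below adapts srshl to it). Names in the sketch are
+        // hypothetical, not part of the kernel.
+#if 0
+// Reference requantization of one int32 accumulator (assumes <cstdint>).
+// Intermediate overflow and the sqrdmulh saturation corner case are ignored.
+std::int32_t RequantizeOneAccumulatorSketch(
+    std::int32_t acc, std::int32_t bias, std::int32_t prod_zp_depth,
+    std::int32_t lhs_zero_point, std::int32_t rhs_zero_point,
+    std::int32_t lhs_sum, std::int32_t rhs_sum,
+    std::int32_t multiplier_fixedpoint, int multiplier_exponent) {
+  // Zero-point corrections: add bias and depth*lhs_zp*rhs_zp (prod_zp_depth),
+  // subtract rhs_sum*lhs_zp and lhs_sum*rhs_zp, per equation (7).
+  std::int32_t value = acc + bias + prod_zp_depth;
+  value -= rhs_sum * lhs_zero_point;
+  value -= lhs_sum * rhs_zero_point;
+  // A positive exponent is applied as a plain left shift before the
+  // fixed-point multiplication (the sshl above, by smax(exponent, 0)).
+  const int left_shift = multiplier_exponent > 0 ? multiplier_exponent : 0;
+  const int right_shift = multiplier_exponent < 0 ? -multiplier_exponent : 0;
+  value = value << left_shift;
+  // Fixed-point multiplication: rounding doubling high half, like sqrdmulh.
+  const std::int64_t ab =
+      static_cast<std::int64_t>(value) * multiplier_fixedpoint;
+  value = static_cast<std::int32_t>((ab + (std::int64_t{1} << 30)) >> 31);
+  // Rounding right shift by the negative exponent, breaking ties away from
+  // zero; this is what the fixup below makes srshl compute.
+  if (right_shift > 0) {
+    const std::int64_t rounding = std::int64_t{1} << (right_shift - 1);
+    const std::int64_t nudged =
+        value >= 0 ? value + rounding : value - rounding;
+    value =
+        static_cast<std::int32_t>(nudged / (std::int64_t{1} << right_shift));
+  }
+  // The store paths below then add dst_zero_point, saturate-narrow and clamp.
+  return value;
+}
+#endif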
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. + "and v8.16b, v16.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sqadd v16.4s, v16.4s, v8.4s\n" +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. + "srshl v16.4s, v16.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this instruction, all data is in lower half (64-bits) of v16 + "sqxtn v16.4h, v16.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + // Now all data is in the first 32-bits of v16 + "sqxtun v16.8b, v16.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + + // Compute how much of the 4x1 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x1, there are some 4x1 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x1 block fit + "csel w1, w1, w3, le\n" + + // Test if w1==4, i.e. if all of the 4x1 block fits. + "cmp w1, w3\n" + + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x1 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x1 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + // After this, all values for output are in the lower half (64 bits) of v16. 
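+        // (As an aside, before the narrowing below: an illustrative,
+        // non-compiled scalar equivalent of this int8 store path. The uint8
+        // path above is the same except that it narrows with sqxtun and
+        // clamps with umax/umin. Names are hypothetical; 16-bit wrap-around
+        // on the zero-point addition is ignored, as it does not occur for
+        // in-range data.)
+#if 0
+std::int8_t StoreOneInt8Sketch(std::int32_t x, std::int16_t dst_zero_point,
+                               std::int8_t clamp_min, std::int8_t clamp_max) {
+  // sqxtn: saturating narrow int32 -> int16.
+  std::int32_t v = x < -32768 ? -32768 : (x > 32767 ? 32767 : x);
+  // Add the destination zero point in 16-bit arithmetic.
+  v += dst_zero_point;
+  // sqxtn: saturating narrow int16 -> int8.
+  v = v < -128 ? -128 : (v > 127 ? 127 : v);
+  // Apply the clamp_min / clamp_max bounds from the params.
+  if (v < clamp_min) v = clamp_min;
+  if (v > clamp_max) v = clamp_max;
+  return static_cast<std::int8_t>(v);
+}
+#endif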
+ "sqxtn v16.4h, v16.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + + // Cast-and-saturate from int16 to int8 + "sqxtn v16.8b, v16.8h\n" + // At this point, we only need 4 lowest 8-bit values in v16. + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x1 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + + // Test if w1==4, i.e. if all of the 4x1 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.4h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + // After this instruction, all data is in lower half of v16. + "sqxtn v16.4h, v16.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + // Load the clamp_min, clamp_max bounds + "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.8h, w2\n" // clamp_min + "dup v15.8h, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. 
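+        // An illustrative, non-compiled scalar sketch of this boundary
+        // handling, which the sub/csel sequence below and the two store paths
+        // implement. The 4x4 kernels clamp both the row and the column
+        // counts; this 1-column kernel only needs the row count. Names are
+        // hypothetical; assumes <algorithm>.
+#if 0
+// Store a kRows x kCols tile, falling back to a staging buffer when the tile
+// sticks out past the destination bounds.
+template <typename Scalar>
+void StoreTileSketch(const Scalar* tile, Scalar* dst, int dst_stride_elems,
+                     Scalar* dst_tmp_buf, int dst_rows, int dst_cols, int row,
+                     int col, int kRows = 4, int kCols = 4) {
+  const int rows_fit = std::min(dst_rows - row, kRows);
+  const int cols_fit = std::min(dst_cols - col, kCols);
+  if (rows_fit == kRows && cols_fit == kCols) {
+    // Fast path: the whole tile fits; store it directly at the destination
+    // stride.
+    for (int c = 0; c < kCols; ++c) {
+      for (int r = 0; r < kRows; ++r) {
+        dst[r + c * dst_stride_elems] = tile[r + c * kRows];
+      }
+    }
+  } else {
+    // Slow path: spill the whole tile to dst_tmp_buf, then copy back only
+    // the rows_fit x cols_fit corner that lies inside the destination.
+    std::copy(tile, tile + kRows * kCols, dst_tmp_buf);
+    for (int c = 0; c < cols_fit; ++c) {
+      for (int r = 0; r < rows_fit; ++r) {
+        dst[r + c * dst_stride_elems] = dst_tmp_buf[r + c * kRows];
+      }
+    }
+  }
+}
+#endif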
+ "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.h}[0], [x3], #2\n" + "st1 {v16.h}[1], [x3], #2\n" + "st1 {v16.h}[2], [x3], #2\n" + "st1 {v16.h}[3], [x3], #2\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #8\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + + // Test if w1==4 i.e. if all of the 4x1 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.s}[0], [x3], #4\n" + "st1 {v16.s}[1], [x3], #4\n" + "st1 {v16.s}[2], [x3], #4\n" + "st1 {v16.s}[3], [x3], #4\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #16\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. 
+        "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
+        "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
+        "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
+
+        // Move to the next block of the destination matrix, for the next iter
+        // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
+        // been updated earlier.
+        // Have we reached the end row?
+        "cmp %w[row], w7\n"
+        "beq 20f\n"  // yes, end row.
+        // Not end row. Move to the next row.
+        "add %w[row], %w[row], #4\n"
+        "b 21f\n"
+        "20:\n"
+        // Was already at end row.
+        "mov %w[row], w6\n"  // Move back to first row.
+        "add %w[col], %w[col], #4\n"  // Move to the next column.
+        "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n"
+        "mov %[dst_ptr], %[dst_col_ptr]\n"
+        "21:\n"
+
+        // Main loop exit condition: have we hit the end column?
+        "cmp %w[col], w8\n"
+
+        // w1 is the number of levels of depth that we have already loaded
+        // LHS and RHS data for. Corresponding to the initial ld1 instructions
+        // above, this is currently 16.
+        "mov w1, #16\n"
+
+        "ble 1b\n"
+
+        // clang-format on
+
+        : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
+          [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
+          [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
+        : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows),
+          [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
+          [dst_type_id] "r"(params.dst_type_id)
+        : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
+          "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
+          "v13", "v14", "v15", "v16", "v17", "v18", "v19");
+}
+
+// Variant of the above Kernel8bitNeonOutOfOrder, tuned for in-order CPUs.
+// Specifically here, the relevant in-order CPUs are ARM Cortex-A53 and
+// the original Cortex-A55, since these are 64-bit and do not support dotprod.
+//
+// While this kernel does not have a direct equivalent in gemmlowp, it was
+// developed based on insights that David Mansell at ARM shared with their
+// contribution of gemmlowp kernels tuned for Cortex-A53, with very helpful
+// comments. Specifically, see this comment about tuning for Cortex-A53:
+// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215
+void Kernel8bitNeonInOrder(const KernelParams8bit<4, 4>& params) {
+  profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)");
+
+  CheckOffsetsInKernelParams8bit(params);
+
+  const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  const std::int8_t* lhs_ptr = lhs_col_ptr;
+  const std::int8_t* rhs_ptr = rhs_col_ptr;
+  void* dst_col_ptr = params.dst_base_ptr;
+  void* dst_ptr = dst_col_ptr;
+  int row = params.start_row;
+  int col = params.start_col;
+
+  // The asm kernel below has the following NEON register allocation:
+  //
+  // v16 -- v31 are int32 accumulators.
+  // During accumulation, v0 -- v3 are used to load int8 data from LHS and
+  // v4 -- v7 from RHS:
+  //
+  //                             int8 RHS 16x4 block
+  //                           /-----------------------------------------\
+  //                           |v4.b[0]          ...           v7.b[0]   |
+  //                           |  ...                            ...     |
+  //                           |v4.b[15]         ...           v7.b[15]  |
+  //                           \-----------------------------------------/
+  //    int8 LHS 4x16 block
+  //  /---------------------\  /-----------------------------------------\
+  //  |v0.b[0] ... v0.b[15] |  |v16.4s           ...
v29.4s | + // |v2.b[0] ... v2.b[15] | |v18.4s ... v30.4s | + // |v3.b[0] ... v3.b[15] | |v19.4s ... v31.4s | + // \---------------------/ \-----------------------------------------/ + // int32 accumulators 4x4 block + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + RUY_MAKE_ZERO(v16) + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + RUY_MAKE_ZERO(v17) + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + RUY_MAKE_ZERO(v18) + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + RUY_MAKE_ZERO(v19) + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + RUY_MAKE_ZERO(v20) + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + RUY_MAKE_ZERO(v21) + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + RUY_MAKE_ZERO(v22) + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + RUY_MAKE_ZERO(v23) + + // Load the first 64 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v24) + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v25) + "ld1 {v2.16b}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v26) + "ld1 {v3.16b}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v27) + "ld1 {v4.16b}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v28) + "ld1 {v5.16b}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v29) + "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v30) + "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v31) + + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov w1, #16\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + "smull v11.8h, v3.8b, v4.8b\n" + "smull v12.8h, v0.8b, v5.8b\n" + "smull v13.8h, v1.8b, v5.8b\n" + "smull v14.8h, v2.8b, v5.8b\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Some multiplications and 16-bit accumulation were already done above, + // so we start right away in the middle. + "sadalp v16.4s, v8.8h\n" + "ldr d4, [%[rhs_ptr], #0]\n" + "smull v8.8h, v0.8b, v6.8b\n" + "ldr x7, [%[rhs_ptr], #8]\n" + "sadalp v17.4s, v9.8h\n" + "ldr d5, [%[rhs_ptr], #16]\n" + "smull v9.8h, v1.8b, v6.8b\n" + "ldr x8, [%[rhs_ptr], #24]\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "add %[lhs_ptr], %[lhs_ptr], #64\n" + "smull v11.8h, v3.8b, v6.8b\n" + "add %[rhs_ptr], %[rhs_ptr], #64\n" + "sadalp v20.4s, v12.8h\n" + // Each iteration of this loop advances by 16 levels of depth. 
+ "add w1, w1, #16\n" + "smull v12.8h, v0.8b, v7.8b\n" + // Loop termination condition + "cmp w1, w12\n" + "sadalp v21.4s, v13.8h\n" + "ldr x3, [%[lhs_ptr], #-56]\n" + "smull v13.8h, v1.8b, v7.8b\n" + "ldr x4, [%[lhs_ptr], #-40]\n" + "sadalp v22.4s, v14.8h\n" + "ldr x5, [%[lhs_ptr], #-24]\n" + "smull v14.8h, v2.8b, v7.8b\n" + "ldr x6, [%[lhs_ptr], #-8]\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "ldr x9, [%[rhs_ptr], #-24]\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + "ldr d6, [%[rhs_ptr], #-32]\n" + "smlal2 v12.8h, v0.16b, v7.16b\n" + "ldr d0, [%[lhs_ptr], #-64]\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "ldr d1, [%[lhs_ptr], #-48]\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "ins v4.d[1], x7\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + "ins v5.d[1], x8\n" + + "ldr d2, [%[lhs_ptr], #-32]\n" + "ins v0.d[1], x3\n" + "sadalp v24.4s, v8.8h\n" + "ldr d3, [%[lhs_ptr], #-16]\n" + "ins v1.d[1], x4\n" + "smull v8.8h, v0.8b, v4.8b\n" + "ins v2.d[1], x5\n" + "sadalp v25.4s, v9.8h\n" + "ins v3.d[1], x6\n" + "smull v9.8h, v1.8b, v4.8b\n" + "ldr d7, [%[rhs_ptr], #-16]\n" + "sadalp v26.4s, v10.8h\n" + "ldr x10, [%[rhs_ptr], #-8]\n" + "smull v10.8h, v2.8b, v4.8b\n" + "sadalp v27.4s, v11.8h\n" + "smull v11.8h, v3.8b, v4.8b\n" + "sadalp v28.4s, v12.8h\n" + "smull v12.8h, v0.8b, v5.8b\n" + "sadalp v29.4s, v13.8h\n" + "smull v13.8h, v1.8b, v5.8b\n" + "sadalp v30.4s, v14.8h\n" + "smull v14.8h, v2.8b, v5.8b\n" + "sadalp v31.4s, v15.8h\n" + "smull v15.8h, v3.8b, v5.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "ins v6.d[1], x9\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "ins v7.d[1], x10\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + "blt 2b\n" + + "79:\n" + + "sadalp v16.4s, v8.8h\n" + "smull v8.8h, v0.8b, v6.8b\n" + "sadalp v17.4s, v9.8h\n" + "smull v9.8h, v1.8b, v6.8b\n" + "sadalp v18.4s, v10.8h\n" + "smull v10.8h, v2.8b, v6.8b\n" + "sadalp v19.4s, v11.8h\n" + "smull v11.8h, v3.8b, v6.8b\n" + "sadalp v20.4s, v12.8h\n" + "smull v12.8h, v0.8b, v7.8b\n" + "sadalp v21.4s, v13.8h\n" + "smull v13.8h, v1.8b, v7.8b\n" + "sadalp v22.4s, v14.8h\n" + "smull v14.8h, v2.8b, v7.8b\n" + "sadalp v23.4s, v15.8h\n" + "smull v15.8h, v3.8b, v7.8b\n" + + // Multiply-accumulate second-half, again into the same + // 16bit local accumulator registers. This is where we + // take advantage of having int8 instead of uint8 and therefore + // being able to accumulate two products into int16. 
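+        // As an illustrative, non-compiled scalar view of what one
+        // smull/smlal2/sadalp sequence (such as the smlal2 instructions just
+        // below) contributes to a single int32 accumulator, for 16 levels of
+        // depth of one packed LHS row and one packed RHS column (hypothetical
+        // names; 16-bit wrap-around corner cases are ignored):
+#if 0
+std::int32_t AccumulateSixteenDepthLevelsSketch(const std::int8_t* lhs,
+                                                const std::int8_t* rhs,
+                                                std::int32_t acc) {
+  for (int i = 0; i < 8; ++i) {
+    // smull + smlal2: two int8*int8 products are first summed in an int16
+    // lane (lane i pairs depth i with depth i + 8).
+    const std::int16_t pair = static_cast<std::int16_t>(
+        std::int16_t{lhs[i]} * rhs[i] + std::int16_t{lhs[i + 8]} * rhs[i + 8]);
+    // sadalp: the int16 lanes are then folded pairwise into the int32
+    // accumulator, so never more than two products sit in 16 bits.
+    acc += pair;
+  }
+  return acc;
+}
+#endif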
+ "smlal2 v8.8h, v0.16b, v6.16b\n" + "smlal2 v9.8h, v1.16b, v6.16b\n" + "smlal2 v10.8h, v2.16b, v6.16b\n" + "smlal2 v11.8h, v3.16b, v6.16b\n" + + "smlal2 v12.8h, v0.16b, v7.16b\n" + "smlal2 v13.8h, v1.16b, v7.16b\n" + "smlal2 v14.8h, v2.16b, v7.16b\n" + "smlal2 v15.8h, v3.16b, v7.16b\n" + + "sadalp v24.4s, v8.8h\n" + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "sadalp v25.4s, v9.8h\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "sadalp v26.4s, v10.8h\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "sadalp v27.4s, v11.8h\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "sadalp v28.4s, v12.8h\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "sadalp v29.4s, v13.8h\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "sadalp v30.4s, v14.8h\n" + "sadalp v31.4s, v15.8h\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 4x4 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 4x4 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Reduce 32bit accumulators horizontally. + "addp v16.4s, v16.4s, v17.4s\n" + "addp v18.4s, v18.4s, v19.4s\n" + "addp v20.4s, v20.4s, v21.4s\n" + "addp v22.4s, v22.4s, v23.4s\n" + "addp v24.4s, v24.4s, v25.4s\n" + "addp v26.4s, v26.4s, v27.4s\n" + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + + // Reduce 32bit accumulators horizontally, second pass + // (each pass adds pairwise. we need to add 4-wise). + "addp v16.4s, v16.4s, v18.4s\n" + "addp v17.4s, v20.4s, v22.4s\n" + "addp v18.4s, v24.4s, v26.4s\n" + "addp v19.4s, v28.4s, v30.4s\n" + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #2\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #2\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. 
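+        // The mov instructions just below do exactly that. As an aside, here
+        // is an illustrative, non-compiled scalar sketch of the overall tile
+        // traversal that this pointer bookkeeping, together with the row/col
+        // updates after the store, implements (hypothetical names; strides as
+        // loaded from the params; assumes <cstddef>):
+#if 0
+void TileTraversalSketch(int start_row, int last_row, int start_col,
+                         int last_col, int lhs_stride, int rhs_stride) {
+  std::ptrdiff_t lhs_col_offset = 0;  // current block-row in the packed LHS
+  std::ptrdiff_t rhs_col_offset = 0;  // current block-col in the packed RHS
+  // The destination is walked in 4-row steps down each group of 4 columns,
+  // then on to the next group of 4 columns.
+  for (int col = start_col; col <= last_col; col += 4) {
+    for (int row = start_row; row <= last_row; row += 4) {
+      // ... accumulate and store the tile at (row, col) ...
+      if (row != last_row) {
+        lhs_col_offset += 4 * lhs_stride;  // advance 4 rows of packed LHS
+      } else {
+        lhs_col_offset = 0;  // rewind the LHS for the next group of columns
+        if (col != last_col) {
+          rhs_col_offset += 4 * rhs_stride;  // advance 4 packed RHS columns
+        }
+      }
+    }
+  }
+}
+#endif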
+ "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + RUY_MAKE_ZERO(v8) + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + "add x5, x4, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + + "ld1 {v15.4s}, [x4]\n" // multiplier_fixedpoint + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + "add x5, x1, %x[row], lsl #2\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 4 bias values. + "ld1 {v14.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + "ldr d0, [%[lhs_ptr], #0]\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "add v16.4s, v16.4s, v14.4s\n" + "ldr d1, [%[lhs_ptr], #16]\n" + "add v17.4s, v17.4s, v14.4s\n" + "ldr d2, [%[lhs_ptr], #32]\n" + "add v18.4s, v18.4s, v14.4s\n" + "ldr d3, [%[lhs_ptr], #48]\n" + "add v19.4s, v19.4s, v14.4s\n" + "ldr d4, [%[rhs_ptr], #0]\n" + "ldr d5, [%[rhs_ptr], #16]\n" + "ldr d6, [%[rhs_ptr], #32]\n" + "ldr d7, [%[rhs_ptr], #48]\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[1]\n" + "mls v18.4s, v10.4s, v14.s[2]\n" + "mls v19.4s, v10.4s, v14.s[3]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. + "mul v11.4s, v11.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v11.4s\n" + "sub v18.4s, v18.4s, v11.4s\n" + "sub v19.4s, v19.4s, v11.4s\n" + + // If the destination is int32, it means the user asks for the raw + // accumulators, no need for us to downquantize the value. + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. 
+ + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, %x[row], lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ld1 {v14.4s}, [x1]\n" + + "smax v12.4s, v14.4s, v8.4s\n" + "ldr x1, [%[lhs_ptr], #8]\n" + + "sshl v16.4s, v16.4s, v12.4s\n" + "ldr x2, [%[lhs_ptr], #24]\n" + "sshl v17.4s, v17.4s, v12.4s\n" + "ldr x3, [%[lhs_ptr], #40]\n" + "sshl v18.4s, v18.4s, v12.4s\n" + "ldr x4, [%[lhs_ptr], #56]\n" + "sshl v19.4s, v19.4s, v12.4s\n" + + "smin v12.4s, v14.4s, v8.4s\n" + + // Apply the fixed-point part of the multiplier. + "ins v0.d[1], x1\n" + "ldr x1, [%[rhs_ptr], #8]\n" + "sqrdmulh v16.4s, v16.4s, v15.4s\n" + "ins v1.d[1], x2\n" + "ldr x2, [%[rhs_ptr], #24]\n" + "sqrdmulh v17.4s, v17.4s, v15.4s\n" + "ins v2.d[1], x3\n" + "ldr x3, [%[rhs_ptr], #40]\n" + "sqrdmulh v18.4s, v18.4s, v15.4s\n" + "ins v3.d[1], x4\n" + "ldr x4, [%[rhs_ptr], #56]\n" + "sqrdmulh v19.4s, v19.4s, v15.4s\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. + "and v8.16b, v16.16b, v12.16b\n" + "and v9.16b, v17.16b, v12.16b\n" + "and v14.16b, v18.16b, v12.16b\n" + "and v15.16b, v19.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v16.4s, v16.4s, v8.4s\n" + "sqadd v17.4s, v17.4s, v9.4s\n" + "sqadd v18.4s, v18.4s, v14.4s\n" + "sqadd v19.4s, v19.4s, v15.4s\n" +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. 
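+        // For reference, an illustrative, non-compiled scalar view of how the
+        // single per-channel exponent drives both the earlier sshl (left
+        // shift) and the srshl just below (rounding right shift) without any
+        // branching: each instruction gets the half of the exponent range it
+        // handles and is a no-op for the other half. In the kernel the left
+        // shift happens before the sqrdmulh and the right shift after it;
+        // this sketch only isolates the shift-amount selection (hypothetical
+        // names).
+#if 0
+std::int32_t ApplyExponentSketch(std::int32_t x, int exponent) {
+  const int left_amount = exponent > 0 ? exponent : 0;    // smax(exponent, 0)
+  const int right_amount = exponent < 0 ? -exponent : 0;  // -smin(exponent, 0)
+  x = x << left_amount;  // sshl: no-op when exponent <= 0
+  if (right_amount > 0) {
+    // Models srshl with a negative shift amount: a rounding right shift that
+    // rounds halves upward (the fixup above nudges negative values first so
+    // that the kernel's overall rounding breaks ties away from zero).
+    const std::int32_t rounding = 1 << (right_amount - 1);
+    x = static_cast<std::int32_t>(
+        (static_cast<std::int64_t>(x) + rounding) >> right_amount);
+  }
+  return x;
+}
+#endif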
+ "srshl v16.4s, v16.4s, v12.4s\n" + "srshl v17.4s, v17.4s, v12.4s\n" + "srshl v18.4s, v18.4s, v12.4s\n" + "srshl v19.4s, v19.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + "ins v4.d[1], x1\n" + "sqxtn v16.4h, v16.4s\n" + "ins v5.d[1], x2\n" + "sqxtn2 v16.8h, v17.4s\n" + "ins v6.d[1], x3\n" + "sqxtn v17.4h, v18.4s\n" + "ins v7.d[1], x4\n" + RUY_MAKE_ZERO(v18) + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v19) + + // Add the destination zero point + "add %[lhs_ptr], %[lhs_ptr], #64\n" + "dup v14.8h, v13.h[4]\n" + RUY_MAKE_ZERO(v20) + "add %[rhs_ptr], %[rhs_ptr], #64\n" + "add v16.8h, v16.8h, v14.8h\n" + RUY_MAKE_ZERO(v21) + "add v17.8h, v17.8h, v14.8h\n" + RUY_MAKE_ZERO(v22) + + // Cast-and-saturate from int16 to uint8 + "sqxtun v16.8b, v16.8h\n" + RUY_MAKE_ZERO(v23) + "sqxtun2 v16.16b, v17.8h\n" + RUY_MAKE_ZERO(v24) + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + RUY_MAKE_ZERO(v25) + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + RUY_MAKE_ZERO(v26) + "dup v14.16b, w2\n" // clamp_min + RUY_MAKE_ZERO(v27) + "dup v15.16b, w3\n" // clamp_max + RUY_MAKE_ZERO(v28) + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + RUY_MAKE_ZERO(v29) + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + RUY_MAKE_ZERO(v30) + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + RUY_MAKE_ZERO(v31) + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #4\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. 
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[4], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[5], [x3], #1\n" + "st1 {v16.b}[6], [x3], #1\n" + "st1 {v16.b}[7], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[8], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[9], [x3], #1\n" + "st1 {v16.b}[10], [x3], #1\n" + "st1 {v16.b}[11], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[12], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[13], [x3], #1\n" + "st1 {v16.b}[14], [x3], #1\n" + "st1 {v16.b}[15], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + "ins v4.d[1], x1\n" + "sqxtn v16.4h, v16.4s\n" + "ins v5.d[1], x2\n" + "sqxtn2 v16.8h, v17.4s\n" + "ins v6.d[1], x3\n" + "sqxtn v17.4h, v18.4s\n" + "ins v7.d[1], x4\n" + RUY_MAKE_ZERO(v18) + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v19) + + // Add the destination zero point + "add %[lhs_ptr], %[lhs_ptr], #64\n" + "dup v14.8h, v13.h[4]\n" + RUY_MAKE_ZERO(v20) + "add %[rhs_ptr], %[rhs_ptr], #64\n" + "add v16.8h, v16.8h, v14.8h\n" + RUY_MAKE_ZERO(v21) + "add v17.8h, v17.8h, v14.8h\n" + RUY_MAKE_ZERO(v22) + + // Cast-and-saturate from int16 to uint8 + "sqxtn v16.8b, v16.8h\n" + RUY_MAKE_ZERO(v23) + "sqxtn2 v16.16b, v17.8h\n" + RUY_MAKE_ZERO(v24) + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + RUY_MAKE_ZERO(v25) + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + RUY_MAKE_ZERO(v26) + "dup v14.16b, w2\n" // clamp_min + RUY_MAKE_ZERO(v27) + "dup v15.16b, w3\n" // clamp_max + RUY_MAKE_ZERO(v28) + + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + RUY_MAKE_ZERO(v29) + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + RUY_MAKE_ZERO(v30) + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + RUY_MAKE_ZERO(v31) + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "st1 {v16.16b}, [%[dst_tmp_buf]]\n" + // Slow loop copying from dst_tmp_buf to dst. 
+ "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #4\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[0], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[1], [x3], #1\n" + "st1 {v16.b}[2], [x3], #1\n" + "st1 {v16.b}[3], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[4], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[5], [x3], #1\n" + "st1 {v16.b}[6], [x3], #1\n" + "st1 {v16.b}[7], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[8], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[9], [x3], #1\n" + "st1 {v16.b}[10], [x3], #1\n" + "st1 {v16.b}[11], [x3], #1\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.b}[12], [x3], #1\n" + "add x4, x4, x11\n" + "st1 {v16.b}[13], [x3], #1\n" + "st1 {v16.b}[14], [x3], #1\n" + "st1 {v16.b}[15], [x3], #1\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #4\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.4h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + "saddw v18.4s, v18.4s, v14.4h\n" + "saddw v19.4s, v19.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "ins v4.d[1], x1\n" + "sqxtn v16.4h, v16.4s\n" + "ins v5.d[1], x2\n" + "sqxtn2 v16.8h, v17.4s\n" + "ins v6.d[1], x3\n" + "sqxtn v17.4h, v18.4s\n" + "ins v7.d[1], x4\n" + RUY_MAKE_ZERO(v18) + "sqxtn2 v17.8h, v19.4s\n" + + // At this point, v18 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + RUY_MAKE_ZERO(v19) + + "add %[lhs_ptr], %[lhs_ptr], #64\n" + RUY_MAKE_ZERO(v20) + "add %[rhs_ptr], %[rhs_ptr], #64\n" + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + + // Load the clamp_min, clamp_max bounds + "ldrh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + RUY_MAKE_ZERO(v25) + "ldrh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + RUY_MAKE_ZERO(v26) + "dup v14.8h, w2\n" // clamp_min + RUY_MAKE_ZERO(v27) + "dup v15.8h, w3\n" // clamp_max + RUY_MAKE_ZERO(v28) + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + "smax v17.8h, v17.8h, v14.8h\n" + RUY_MAKE_ZERO(v29) + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + "smin v17.8h, v17.8h, v15.8h\n" + RUY_MAKE_ZERO(v30) + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + RUY_MAKE_ZERO(v31) + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 4x4 block fits. 
+ "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + "str q17, [%[dst_tmp_buf], #16]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.h}[0], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v16.h}[1], [x3], #2\n" + "st1 {v16.h}[2], [x3], #2\n" + "st1 {v16.h}[3], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.h}[4], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v16.h}[5], [x3], #2\n" + "st1 {v16.h}[6], [x3], #2\n" + "st1 {v16.h}[7], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.h}[0], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v17.h}[1], [x3], #2\n" + "st1 {v17.h}[2], [x3], #2\n" + "st1 {v17.h}[3], [x3], #2\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.h}[4], [x3], #2\n" + "add x4, x4, x11\n" + "st1 {v17.h}[5], [x3], #2\n" + "st1 {v17.h}[6], [x3], #2\n" + "st1 {v17.h}[7], [x3], #2\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #8\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + "ldr x1, [%[lhs_ptr], #8]\n" + "ldr x2, [%[lhs_ptr], #24]\n" + "ldr x3, [%[lhs_ptr], #40]\n" + "ldr x4, [%[lhs_ptr], #56]\n" + + "ins v0.d[1], x1\n" + "ldr x1, [%[rhs_ptr], #8]\n" + "ins v1.d[1], x2\n" + "ldr x2, [%[rhs_ptr], #24]\n" + "ins v2.d[1], x3\n" + "ldr x3, [%[rhs_ptr], #40]\n" + "ins v3.d[1], x4\n" + "ldr x4, [%[rhs_ptr], #56]\n" + "ins v4.d[1], x1\n" + "ins v5.d[1], x2\n" + "ins v6.d[1], x3\n" + "ins v7.d[1], x4\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // At this point, v20 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). + + RUY_MAKE_ZERO(v20) + "add %[lhs_ptr], %[lhs_ptr], #64\n" + RUY_MAKE_ZERO(v21) + "add %[rhs_ptr], %[rhs_ptr], #64\n" + RUY_MAKE_ZERO(v22) + + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + + // Compute how much of the 4x4 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 4x4, there are some 4x4 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + RUY_MAKE_ZERO(v31) + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #4\n" + "cmp w1, #4\n" + // Compute w1 = how many rows of the 4x4 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #4\n" + // Compute w2 = how many cols of the 4x4 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==4 && w2 == 4, i.e. if all of the 8x8 block fits. 
+ "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + "mov x4, %[dst_ptr]\n" + // Yes, all of the 4x4 block fits, go to fast path. + "beq 30f\n" + // Not all of the 4x4 block fits. + // Store to dst_tmp_buf + "str q16, [%[dst_tmp_buf], #0]\n" + "str q17, [%[dst_tmp_buf], #16]\n" + "str q18, [%[dst_tmp_buf], #32]\n" + "str q19, [%[dst_tmp_buf], #48]\n" + // Slow loop copying from dst_tmp_buf to dst. + "mov x3, %[dst_tmp_buf]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #16\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "b 31f\n" + "30:\n" + // Yes, all of the 4x4 block fits. + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v16.s}[1], [x3], #4\n" + "st1 {v16.s}[2], [x3], #4\n" + "st1 {v16.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v17.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v17.s}[1], [x3], #4\n" + "st1 {v17.s}[2], [x3], #4\n" + "st1 {v17.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v18.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v18.s}[1], [x3], #4\n" + "st1 {v18.s}[2], [x3], #4\n" + "st1 {v18.s}[3], [x3], #4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v19.s}[0], [x3], #4\n" + "add x4, x4, x11\n" + "st1 {v19.s}[1], [x3], #4\n" + "st1 {v19.s}[2], [x3], #4\n" + "st1 {v19.s}[3], [x3], #4\n" + "31:\n" + + "add %[dst_ptr], %[dst_ptr], #16\n" + + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + "smull v8.8h, v0.8b, v4.8b\n" + "smull v9.8h, v1.8b, v4.8b\n" + "smull v10.8h, v2.8b, v4.8b\n" + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "smull v11.8h, v3.8b, v4.8b\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "smull v12.8h, v0.8b, v5.8b\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "smull v13.8h, v1.8b, v5.8b\n" + "smull v14.8h, v2.8b, v5.8b\n" + "smull v15.8h, v3.8b, v5.8b\n" + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "smlal2 v8.8h, v0.16b, v4.16b\n" + "smlal2 v9.8h, v1.16b, v4.16b\n" + "smlal2 v10.8h, v2.16b, v4.16b\n" + "smlal2 v11.8h, v3.16b, v4.16b\n" + "smlal2 v12.8h, v0.16b, v5.16b\n" + "smlal2 v13.8h, v1.16b, v5.16b\n" + "smlal2 v14.8h, v2.16b, v5.16b\n" + "smlal2 v15.8h, v3.16b, v5.16b\n" + + + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #4\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #4\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #2\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. 
Corresponding to the initial ld1 instructions + // above, this is currently 16. + "mov w1, #16\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(&params), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} + +// Kernel taking advantage of the optional dotprod instruction. +// This is very similar to (and directly inspired by) this gemmlowp kernel +// which was contributed by David Mansell at ARM: +// NEON_64bit_GEMM_Uint8Operands_Uint32Accumulators_dotproduct +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3391 +// +// Besides the ruy-ification, the main difference here is that we use an 8x8 +// instead of 12x8 width, so as to stick to power-of-two widths. This slightly +// narrower kernel layout is still wide enough to achieve high performance +// although we haven't actually performed a real comparison to know exactly +// how this compares to ARM's aforementioned kernel. +// +// Relevant target CPUs for this kernel include ARM Cortex-A76, +// since these are 64-bit, out-of-order and with dotprod support. +void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label( + "Kernel (kNeonDotprod, optimized for out-of-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are int32 accumulators. + // During accumulation, v0 -- v15 are used to load int8 data from LHS and + // RHS. At least v0 and v1 are used to load an 8x4 block of LHS, and v2 and + // v3 are used to load a 4x8 block of RHS, like this: + // + // int8 RHS 4x8 block + // /-----------------------------------------\ + // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| + // | ... ... | + // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| + // \-----------------------------------------/ + // int8 LHS 8x4 block + // /---------------------\ /-----------------------------------------\ + // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| + // | ... ... | | ... ... | + // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]| + // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| + // | ... ... | | ... ... | + // |v1.b[12] ... v1.b[15]| |v17.s[3] ...
v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // int32 accumulators 8x8 block + // + // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step + // is repeated 4 times, using 4x more registers for LHS and RHS, so that + // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. + // + // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are + // unused, and v8 -- v15 are used for loading parameters used for the + // post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. + "mov w1, #4\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Optional, maximally-streaming, partial-unrolling (4x unrolled) + // optimization of the kernel inner loop (over depth). For more + // comments, see the non-unrolled loop below after the #endif. 
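+ // For reference, a scalar sketch of what one such by-element sdot
+ // instruction computes (illustrative C++ only; acc, lhs and rhs stand for
+ // the int32 lanes of v16 and the int8 lanes of v0 and v2; the .word
+ // constants below are simply the encodings of the sdot instructions
+ // spelled out in the adjacent comments):
+ //
+ //   // sdot v16.4s, v0.16b, v2.4b[k]
+ //   for (int i = 0; i < 4; ++i) {      // the 4 int32 lanes of v16
+ //     for (int j = 0; j < 4; ++j) {    // 4 consecutive depth levels
+ //       acc[i] += std::int32_t(lhs[4 * i + j]) * std::int32_t(rhs[4 * k + j]);
+ //     }
+ //   }
+ //
+ // That is, one sdot accumulates 4 depth levels of 4 LHS rows against one
+ // RHS column: 16 int8*int8 products added into 4 int32 accumulator lanes.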
+#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) + "cmp w12, #32\n" + "blt 78f\n" + + "ld1 {v4.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v5.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v6.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v7.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v8.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v9.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v10.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v11.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v12.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v13.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v14.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v15.16b}, [%[rhs_ptr]], #16\n" + "mov w1, #16\n" + + "and w3, w12, #-16\n" + "81:\n" + "add w1, w1, #16\n" + + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + "ldr q0, [%[lhs_ptr], #0]\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + "ldr q2, [%[rhs_ptr], #0]\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + "ldr q1, [%[lhs_ptr], #16]\n" + + ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n" + ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n" + "ldr q3, [%[rhs_ptr], #16]\n" + ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n" + ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n" + ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n" + ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n" + ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n" + ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n" + ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n" + ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n" + ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n" + ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n" + "ldr q5, [%[lhs_ptr], #48]\n" + ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n" + ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n" + "ldr q7, [%[rhs_ptr], #48]\n" + ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n" + ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n" + "ldr q4, [%[lhs_ptr], #32]\n" + + ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n" + ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n" + "ldr q6, [%[rhs_ptr], #32]\n" + ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n" + ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n" + ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n" + ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n" + ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n" + ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n" + ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n" + ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n" + ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n" + ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n" + "ldr q9, [%[lhs_ptr], #80]\n" + ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n" + ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n" + "ldr q11, [%[rhs_ptr], #80]\n" + ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n" + ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n" + "ldr q8, [%[lhs_ptr], #64]\n" + + ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n" + ".word 0x4fafe19a // sdot v26.4s, 
v12.16b, v15.4b[1]\n" + "ldr q10, [%[rhs_ptr], #64]\n" + ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n" + ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n" + "add %[lhs_ptr], %[lhs_ptr], #128\n" + ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n" + ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n" + "add %[rhs_ptr], %[rhs_ptr], #128\n" + ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n" + ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n" + ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n" + ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n" + "cmp w1, w3\n" + ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n" + ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n" + "ldr q13, [%[lhs_ptr], #-16]\n" + ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n" + ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n" + "ldr q15, [%[rhs_ptr], #-16]\n" + ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n" + ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n" + "ldr q12, [%[lhs_ptr], #-32]\n" + + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + "ldr q14, [%[rhs_ptr], #-32]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + "blt 81b\n" + + ".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n" + ".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n" + ".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n" + ".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n" + ".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n" + ".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n" + ".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n" + ".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n" + ".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n" + ".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n" + ".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n" + ".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n" + ".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n" + ".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n" + ".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n" + ".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n" + + ".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n" + ".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n" + ".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n" + ".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n" + ".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n" + ".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n" + ".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n" + ".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n" + ".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n" + ".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n" + ".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n" + ".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n" + ".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n" + ".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n" + ".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n" + ".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n" + + ".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n" + ".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n" + ".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n" + ".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n" + ".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n" + ".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n" + ".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, 
v14.4b[2]\n" + ".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n" + ".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n" + ".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n" + ".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n" + ".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n" + ".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n" + ".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n" + ".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n" + ".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n" + + "78:\n" + +#endif // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) + + // Ordinary kernel inner loop (over depth), the simpler loop that the + // above was an equivalent 4x-partially-unrolled version of. + + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Because of the data that we have already loaded, we can start the + // loop body right away with some multiply-adds. + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + // Each iteration of this loop advances by 4 levels of depth. + "add w1, w1, #4\n" + ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + // Loop termination condition. + "cmp w1, w12\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + + "blt 2b\n" + + "79:\n" + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last 4 levels of depth, for which the LHS + // and RHS data is already loaded. + + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. 
We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + RUY_MAKE_ZERO(v8) + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + "add x5, x4, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + "add x5, x1, %x[row], lsl #2\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + "add v15.4s, v15.4s, v9.4s\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
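+ // For reference, the complete per-entry correction that this block and the
+ // optional RUY_ASM_FLAG_HAS_RHS_SUMS / _HAS_LHS_SUMS blocks below implement,
+ // in scalar form (illustrative C++ only; raw_acc, bias, lhs_sums and
+ // rhs_sums stand for the corresponding packed-matrix data):
+ //
+ //   // Equation (7) in https://arxiv.org/pdf/1712.05877.pdf
+ //   std::int32_t acc = raw_acc[row][col]            // plain sum of int8 products
+ //       + bias[row]
+ //       + depth * lhs_zero_point * rhs_zero_point   // prod_zp_depth, folded
+ //                                                   // into the bias above
+ //       - lhs_zero_point * rhs_sums[col]
+ //       - rhs_zero_point * lhs_sums[row];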
+ "add v16.4s, v16.4s, v14.4s\n" + "add v17.4s, v17.4s, v15.4s\n" + "add v18.4s, v18.4s, v14.4s\n" + "add v19.4s, v19.4s, v15.4s\n" + "add v20.4s, v20.4s, v14.4s\n" + "add v21.4s, v21.4s, v15.4s\n" + "add v22.4s, v22.4s, v14.4s\n" + "add v23.4s, v23.4s, v15.4s\n" + "add v24.4s, v24.4s, v14.4s\n" + "add v25.4s, v25.4s, v15.4s\n" + "add v26.4s, v26.4s, v14.4s\n" + "add v27.4s, v27.4s, v15.4s\n" + "add v28.4s, v28.4s, v14.4s\n" + "add v29.4s, v29.4s, v15.4s\n" + "add v30.4s, v30.4s, v14.4s\n" + "add v31.4s, v31.4s, v15.4s\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3], #16\n" + "ld1 {v15.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[0]\n" + "mls v18.4s, v10.4s, v14.s[1]\n" + "mls v19.4s, v10.4s, v14.s[1]\n" + "mls v20.4s, v10.4s, v14.s[2]\n" + "mls v21.4s, v10.4s, v14.s[2]\n" + "mls v22.4s, v10.4s, v14.s[3]\n" + "mls v23.4s, v10.4s, v14.s[3]\n" + "mls v24.4s, v10.4s, v15.s[0]\n" + "mls v25.4s, v10.4s, v15.s[0]\n" + "mls v26.4s, v10.4s, v15.s[1]\n" + "mls v27.4s, v10.4s, v15.s[1]\n" + "mls v28.4s, v10.4s, v15.s[2]\n" + "mls v29.4s, v10.4s, v15.s[2]\n" + "mls v30.4s, v10.4s, v15.s[3]\n" + "mls v31.4s, v10.4s, v15.s[3]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2], #16\n" + "ld1 {v12.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. + "mul v11.4s, v11.4s, v13.s[1]\n" + "mul v12.4s, v12.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v12.4s\n" + "sub v18.4s, v18.4s, v11.4s\n" + "sub v19.4s, v19.4s, v12.4s\n" + "sub v20.4s, v20.4s, v11.4s\n" + "sub v21.4s, v21.4s, v12.4s\n" + "sub v22.4s, v22.4s, v11.4s\n" + "sub v23.4s, v23.4s, v12.4s\n" + "sub v24.4s, v24.4s, v11.4s\n" + "sub v25.4s, v25.4s, v12.4s\n" + "sub v26.4s, v26.4s, v11.4s\n" + "sub v27.4s, v27.4s, v12.4s\n" + "sub v28.4s, v28.4s, v11.4s\n" + "sub v29.4s, v29.4s, v12.4s\n" + "sub v30.4s, v30.4s, v11.4s\n" + "sub v31.4s, v31.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. 
+ "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, %x[row], lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ldr q9, [x1]\n" + "ldr q10, [x1, #16]\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n" + "beq 403f\n" + "smax v11.4s, v9.4s, v8.4s\n" + "smax v12.4s, v10.4s, v8.4s\n" + "sshl v16.4s, v16.4s, v11.4s\n" + "sshl v17.4s, v17.4s, v12.4s\n" + "sshl v18.4s, v18.4s, v11.4s\n" + "sshl v19.4s, v19.4s, v12.4s\n" + "sshl v20.4s, v20.4s, v11.4s\n" + "sshl v21.4s, v21.4s, v12.4s\n" + "sshl v22.4s, v22.4s, v11.4s\n" + "sshl v23.4s, v23.4s, v12.4s\n" + "sshl v24.4s, v24.4s, v11.4s\n" + "sshl v25.4s, v25.4s, v12.4s\n" + "sshl v26.4s, v26.4s, v11.4s\n" + "sshl v27.4s, v27.4s, v12.4s\n" + "sshl v28.4s, v28.4s, v11.4s\n" + "sshl v29.4s, v29.4s, v12.4s\n" + "sshl v30.4s, v30.4s, v11.4s\n" + "sshl v31.4s, v31.4s, v12.4s\n" + "403:\n" + + "ldr q14, [x4]\n" // multiplier_fixedpoint + "ldr q15, [x4, #16]\n" // multiplier_fixedpoint + + "smin v11.4s, v9.4s, v8.4s\n" + "smin v12.4s, v10.4s, v8.4s\n" + + // Apply the fixed-point part of the multiplier. + "sqrdmulh v16.4s, v16.4s, v14.4s\n" + "sqrdmulh v17.4s, v17.4s, v15.4s\n" + "sqrdmulh v18.4s, v18.4s, v14.4s\n" + "sqrdmulh v19.4s, v19.4s, v15.4s\n" + "sqrdmulh v20.4s, v20.4s, v14.4s\n" + "sqrdmulh v21.4s, v21.4s, v15.4s\n" + "sqrdmulh v22.4s, v22.4s, v14.4s\n" + "sqrdmulh v23.4s, v23.4s, v15.4s\n" + "sqrdmulh v24.4s, v24.4s, v14.4s\n" + "sqrdmulh v25.4s, v25.4s, v15.4s\n" + "sqrdmulh v26.4s, v26.4s, v14.4s\n" + "sqrdmulh v27.4s, v27.4s, v15.4s\n" + "sqrdmulh v28.4s, v28.4s, v14.4s\n" + "sqrdmulh v29.4s, v29.4s, v15.4s\n" + "sqrdmulh v30.4s, v30.4s, v14.4s\n" + "sqrdmulh v31.4s, v31.4s, v15.4s\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. 
+ "and v8.16b, v16.16b, v11.16b\n" + "and v9.16b, v17.16b, v12.16b\n" + "and v14.16b, v18.16b, v11.16b\n" + "and v15.16b, v19.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v16.4s, v16.4s, v8.4s\n" + "sqadd v17.4s, v17.4s, v9.4s\n" + "sqadd v18.4s, v18.4s, v14.4s\n" + "sqadd v19.4s, v19.4s, v15.4s\n" + "and v8.16b, v20.16b, v11.16b\n" + "and v9.16b, v21.16b, v12.16b\n" + "and v14.16b, v22.16b, v11.16b\n" + "and v15.16b, v23.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v20.4s, v20.4s, v8.4s\n" + "sqadd v21.4s, v21.4s, v9.4s\n" + "sqadd v22.4s, v22.4s, v14.4s\n" + "sqadd v23.4s, v23.4s, v15.4s\n" + "and v8.16b, v24.16b, v11.16b\n" + "and v9.16b, v25.16b, v12.16b\n" + "and v14.16b, v26.16b, v11.16b\n" + "and v15.16b, v27.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v24.4s, v24.4s, v8.4s\n" + "sqadd v25.4s, v25.4s, v9.4s\n" + "sqadd v26.4s, v26.4s, v14.4s\n" + "sqadd v27.4s, v27.4s, v15.4s\n" + "and v8.16b, v28.16b, v11.16b\n" + "and v9.16b, v29.16b, v12.16b\n" + "and v14.16b, v30.16b, v11.16b\n" + "and v15.16b, v31.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v28.4s, v28.4s, v8.4s\n" + "sqadd v29.4s, v29.4s, v9.4s\n" + "sqadd v30.4s, v30.4s, v14.4s\n" + "sqadd v31.4s, v31.4s, v15.4s\n" +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. + "srshl v16.4s, v16.4s, v11.4s\n" + "srshl v17.4s, v17.4s, v12.4s\n" + "srshl v18.4s, v18.4s, v11.4s\n" + "srshl v19.4s, v19.4s, v12.4s\n" + "srshl v20.4s, v20.4s, v11.4s\n" + "srshl v21.4s, v21.4s, v12.4s\n" + "srshl v22.4s, v22.4s, v11.4s\n" + "srshl v23.4s, v23.4s, v12.4s\n" + "srshl v24.4s, v24.4s, v11.4s\n" + "srshl v25.4s, v25.4s, v12.4s\n" + "srshl v26.4s, v26.4s, v11.4s\n" + "srshl v27.4s, v27.4s, v12.4s\n" + "srshl v28.4s, v28.4s, v11.4s\n" + "srshl v29.4s, v29.4s, v12.4s\n" + "srshl v30.4s, v30.4s, v11.4s\n" + "srshl v31.4s, v31.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + "add v18.8h, v18.8h, v14.8h\n" + "add v19.8h, v19.8h, v14.8h\n" + "add v20.8h, v20.8h, v14.8h\n" + "add v21.8h, v21.8h, v14.8h\n" + "add v22.8h, v22.8h, v14.8h\n" + "add v23.8h, v23.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + "sqxtun v16.8b, v16.8h\n" + "sqxtun2 v16.16b, v17.8h\n" + "sqxtun v17.8b, v18.8h\n" + "sqxtun2 v17.16b, v19.8h\n" + "sqxtun v18.8b, v20.8h\n" + "sqxtun2 v18.16b, v21.8h\n" + "sqxtun v19.8b, v22.8h\n" + "sqxtun2 v19.16b, v23.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + "umax v17.16b, v17.16b, v14.16b\n" + "umax v18.16b, v18.16b, v14.16b\n" + "umax v19.16b, v19.16b, v14.16b\n" + + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + "umin v17.16b, v17.16b, v15.16b\n" + "umin v18.16b, v18.16b, v15.16b\n" + "umin v19.16b, v19.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + "dup d21, v17.d[1]\n" + "dup d22, v18.d[1]\n" + "dup d23, v19.d[1]\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). 
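+ // For reference, the overall store strategy in scalar form (illustrative
+ // C++ only; rows_fit and cols_fit are the values computed into w1 and w2
+ // above):
+ //
+ //   int rows_fit = std::min(dst_rows - row, 8);
+ //   int cols_fit = std::min(dst_cols - col, 8);
+ //   if (rows_fit == 8 && cols_fit == 8) {
+ //     // Fast path: x3 points into dst and x4 is the dst column stride, so
+ //     // each 8-byte st1 below writes one full destination column.
+ //   } else {
+ //     // Slow path: x3 points to dst_tmp_buf with stride x4 = 8; the whole
+ //     // 8x8 block is stored there, and the 50:/51: loop below then copies
+ //     // only the rows_fit x cols_fit part into dst, one byte at a time.
+ //   }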
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + "add v18.8h, v18.8h, v14.8h\n" + "add v19.8h, v19.8h, v14.8h\n" + "add v20.8h, v20.8h, v14.8h\n" + "add v21.8h, v21.8h, v14.8h\n" + "add v22.8h, v22.8h, v14.8h\n" + "add v23.8h, v23.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + "sqxtn v16.8b, v16.8h\n" + "sqxtn2 v16.16b, v17.8h\n" + "sqxtn v17.8b, v18.8h\n" + "sqxtn2 v17.16b, v19.8h\n" + "sqxtn v18.8b, v20.8h\n" + "sqxtn2 v18.16b, v21.8h\n" + "sqxtn v19.8b, v22.8h\n" + "sqxtn2 v19.16b, v23.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + "smax v17.16b, v17.16b, v14.16b\n" + "smax v18.16b, v18.16b, v14.16b\n" + "smax v19.16b, v19.16b, v14.16b\n" + + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + "smin v17.16b, v17.16b, v15.16b\n" + "smin v18.16b, v18.16b, v15.16b\n" + "smin v19.16b, v19.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + "dup d21, v17.d[1]\n" + "dup d22, v18.d[1]\n" + "dup d23, v19.d[1]\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 130f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 131f\n" + "130:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "131:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). 
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 141f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "150:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "151:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 151b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 150b\n" + "141:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + "saddw v18.4s, v18.4s, v14.4h\n" + "saddw v19.4s, v19.4s, v14.4h\n" + "saddw v20.4s, v20.4s, v14.4h\n" + "saddw v21.4s, v21.4s, v14.4h\n" + "saddw v22.4s, v22.4s, v14.4h\n" + "saddw v23.4s, v23.4s, v14.4h\n" + "saddw v24.4s, v24.4s, v14.4h\n" + "saddw v25.4s, v25.4s, v14.4h\n" + "saddw v26.4s, v26.4s, v14.4h\n" + "saddw v27.4s, v27.4s, v14.4h\n" + "saddw v28.4s, v28.4s, v14.4h\n" + "saddw v29.4s, v29.4s, v14.4h\n" + "saddw v30.4s, v30.4s, v14.4h\n" + "saddw v31.4s, v31.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Load the clamp_min, clamp_max bounds + "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.8h, w2\n" // clamp_min + "dup v15.8h, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + "smax v17.8h, v17.8h, v14.8h\n" + "smax v18.8h, v18.8h, v14.8h\n" + "smax v19.8h, v19.8h, v14.8h\n" + "smax v20.8h, v20.8h, v14.8h\n" + "smax v21.8h, v21.8h, v14.8h\n" + "smax v22.8h, v22.8h, v14.8h\n" + "smax v23.8h, v23.8h, v14.8h\n" + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + "smin v17.8h, v17.8h, v15.8h\n" + "smin v18.8h, v18.8h, v15.8h\n" + "smin v19.8h, v19.8h, v15.8h\n" + "smin v20.8h, v20.8h, v15.8h\n" + "smin v21.8h, v21.8h, v15.8h\n" + "smin v22.8h, v22.8h, v15.8h\n" + "smin v23.8h, v23.8h, v15.8h\n" + + // Compute how much of the 8x8 block of destination 16bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 230f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #16\n" + "b 231f\n" + "230:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "231:\n" + + // Write our 16bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 241f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. 
Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "250:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "251:\n" + "ldrsh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 251b\n" + "add w6, w6, #1\n" + "add x3, x3, #16\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 250b\n" + "241:\n" + "add %[dst_ptr], %[dst_ptr], #16\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // Compute how much of the 8x8 block of destination 32bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 330f\n" + // Not all of the 8x8 block fits. + // Write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "st1 {v16.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v16) + "st1 {v17.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v17) + "st1 {v18.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v18) + "st1 {v19.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v19) + "st1 {v20.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v20) + "st1 {v21.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v21) + "st1 {v22.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v22) + "st1 {v23.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v23) + "st1 {v24.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v24) + "st1 {v25.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v25) + "st1 {v26.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v26) + "st1 {v27.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v27) + "st1 {v28.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v28) + "st1 {v29.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v29) + "st1 {v30.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v30) + "st1 {v31.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v31) + + "b 331f\n" + + "330:\n" + // Yes, all of the 8x8 block fits.
+ "mov x4, %[dst_ptr]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v16.4s, v17.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v18.4s, v19.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v20.4s, v21.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v22.4s, v23.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v24.4s, v25.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v26.4s, v27.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v28.4s, v29.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + "add x4, x4, x11\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov x3, x4\n" + "st1 {v30.4s, v31.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + "331:\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 341f\n" + + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "350:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "351:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 351b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 350b\n" + "341:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? 
+ "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. + "mov w1, #4\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} + +// Similar to the above 8-bit dotprod kernel, but specialized for the case of +// RHS cols == 1. +// Relevant target CPUs for this kernel include ARM Cortex-A76, +// since these are 64-bit, out-of-order and with dotprod support. +void Kernel8bitNeonDotprodOutOfOrder1Col(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label( + "Kernel (kNeonDotprod, optimized for out-of-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are int32 accumulators. + // During accumulation, v0 -- v15 are used to load int8 data from LHS and + // RHS. At least v0 and v1 are used to load a 8x4 block of LHS, and v2 and + // v3 are used to load a 4x8 block of RHS, like this: + // + // int8 RHS 4x1 block + // /-------\ + // |v2.b[0]| + // | ... | + // |v2.b[3]| + // \-------/ + // int8 LHS 8x4 block + // /---------------------\ /--------\ + // |v0.b[0] ... v0.b[3] | |v16.s[0]| + // | ... ... | | ... | + // |v0.b[12] ... v0.b[15]| |v16.s[3]| + // |v1.b[0] ... v1.b[3] | |v17.s[0]| + // | ... ... | | ... | + // |v1.b[12] ... v1.b[15]| |v17.s[3]| + // \---------------------/ \--------/ + // int32 accumulators 8x1 block + // + // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step + // is repeated 4 times, using 4x more registers for LHS and RHS, so that + // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. + // + // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are + // unused, and v8 -- v15 are used for loading parameters used for the + // post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. 
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.8b}, [%[rhs_ptr]]\n" + "add %[rhs_ptr], %[rhs_ptr], #32\n" + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. + "mov w1, #4\n" + + // Perform the first few multiply-adds on the data that we have already + // loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + // Ordinary kernel inner loop (over depth), the simpler loop that the + // above was an equivalent 4x-partially-unrolled version of. + + // Reminder - w1 is how many levels of depth we have already loaded + // data for, w12 is the total depth. + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + + // Because of the data that we have already loaded, we can start the + // loop body right away with some multiply-adds. + // Each iteration of this loop advances by 4 levels of depth. + "add w1, w1, #4\n" + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + // Loop termination condition. + "cmp w1, w12\n" + "ld1 {v2.8b}, [%[rhs_ptr]]\n" + "add %[rhs_ptr], %[rhs_ptr], #32\n" + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + + "blt 2b\n" + + "79:\n" + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last 4 levels of depth, for which the LHS + // and RHS data is already loaded. + + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. 
If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + RUY_MAKE_ZERO(v8) + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + "add x5, x4, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + "add x5, x1, %x[row], lsl #2\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.8b}, [%[rhs_ptr]]\n" + "add %[rhs_ptr], %[rhs_ptr], #32\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + "add v15.4s, v15.4s, v9.4s\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "add v16.4s, v16.4s, v14.4s\n" + "add v17.4s, v17.4s, v15.4s\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ld1 {v14.4s}, [x3], #16\n" + "ld1 {v15.4s}, [x3]\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[0]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + // Load 4 lhs_sums values. + "ld1 {v11.4s}, [x2], #16\n" + "ld1 {v12.4s}, [x2]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Compute lhs_sums * rhs_zero_point. 
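+        // For reference, the complete zero-point correction being assembled
+        // across these steps is, per destination entry (r, c) and with d the
+        // depth (an illustrative scalar restatement of equation (7), not an
+        // exact transcription of the code):
+        //   acc[r][c] = dot(lhs_row_r, rhs_col_c)
+        //               - lhs_zero_point * rhs_sums[c]
+        //               - rhs_zero_point * lhs_sums[r]
+        //               + d * lhs_zero_point * rhs_zero_point
+        //               + bias[r]
+        // The last two terms were already folded into the bias addition above.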
+ "mul v11.4s, v11.4s, v13.s[1]\n" + "mul v12.4s, v12.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, %x[row], lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ldr q9, [x1]\n" + "ldr q10, [x1, #16]\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n" + "beq 403f\n" + "smax v11.4s, v9.4s, v8.4s\n" + "smax v12.4s, v10.4s, v8.4s\n" + "sshl v16.4s, v16.4s, v11.4s\n" + "sshl v17.4s, v17.4s, v12.4s\n" + "403:\n" + + "ldr q14, [x4]\n" // multiplier_fixedpoint + "ldr q15, [x4, #16]\n" // multiplier_fixedpoint + + "smin v11.4s, v9.4s, v8.4s\n" + "smin v12.4s, v10.4s, v8.4s\n" + + // Apply the fixed-point part of the multiplier. + "sqrdmulh v16.4s, v16.4s, v14.4s\n" + "sqrdmulh v17.4s, v17.4s, v15.4s\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. + "and v8.16b, v16.16b, v11.16b\n" + "and v9.16b, v17.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sqadd v16.4s, v16.4s, v8.4s\n" + "sqadd v17.4s, v17.4s, v9.4s\n" + +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. 
+ "srshl v16.4s, v16.4s, v11.4s\n" + "srshl v17.4s, v17.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + // All data in v16 at this point. + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8, leaving all data in the + // lower half of v16. + "sqxtun v16.8b, v16.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + + // Compute how much of the 8x1 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x1, there are some 8x1 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x1 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + + // Test if w1==8, i.e. if all of the 8x1 block fits. + "cmp w1, w3\n" + // Yes, all of the 8x1 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v16.8b}, [x3]\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. 
+ + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "add v16.8h, v16.8h, v14.8h\n" + + // Cast-and-saturate from int16 to uint8 + "sqxtn v16.8b, v16.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + + // Compute how much of the 8x1 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x1 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + + // Test if w1==8, i.e. if all of the 8x1 block fits. + "cmp w1, w3\n" + // Yes, all of the 8x1 block fits, go to fast path. + "beq 130f\n" + // Not all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 131f\n" + "130:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "131:\n" + + // Write our 8bit values to the destination + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v16.8b}, [x3]\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 141f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "150:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "151:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 151b\n" + "141:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. 
+ + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + + // Load the clamp_min, clamp_max bounds + "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.8h, w2\n" // clamp_min + "dup v15.8h, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + + // Compute how much of the 8x1 block of destination 16bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x1 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x1 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + + // Test if w1==8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + // Yes, all of the 8x1 block fits, go to fast path. + "beq 230f\n" + // Not all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #16\n" + "b 231f\n" + "230:\n" + // Yes, all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "231:\n" + + // Write our 16bit values to the destination + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v16.8h}, [x3]\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + + // If all of the 8x1 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 241f\n" + // Not all of the 8x1 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "250:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "251:\n" + "ldrsh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 251b\n" + "241:\n" + "add %[dst_ptr], %[dst_ptr], #16\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // Compute how much of the 8x1 block of destination 32 bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x1, there are some 8x1 blocks along the boundaries that do + // not fit entirely. 
+ "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x1 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + // Yes, all of the 8x1 block fits, go to fast path. + "beq 330f\n" + // Not all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #16\n" + + // Write our 32bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.4s}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.4s}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + + "b 331f\n" + + "330:\n" + // Yes, all of the 8x1 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x4, %[dst_ptr]\n" + "mov x3, x4\n" + + // Write our 32bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.4s, v17.4s}, [x3], #32\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + + "331:\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 341f\n" + + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "350:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "mov w5, #0\n" + "351:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 351b\n" + "341:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 4. 
+ "mov w1, #4\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17"); +} + +// Variant of the above Kernel8bitNeonDotprodOutOfOrder, tuned for in-order +// CPUs. Specifically here, the relevant in-order CPUs are ARM Cortex-A55r1, +// since these are 64-bit and support dotprod. +// +// While this kernel does not have a direct equivalent in gemmlowp, it was +// developed based on insights that David Mansell at ARM shared with their +// contribution of gemmlowp kernels tuned for Cortex-A55r1, with very helpful +// comments. Specifically, see this comment about tuning for Cortex-A55r1: +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412 +void Kernel8bitNeonDotprodInOrder(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label( + "Kernel (kNeonDotprod, optimized for in-order cores)"); + + CheckOffsetsInKernelParams8bit(params); + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + void* dst_col_ptr = params.dst_base_ptr; + void* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are int32 accumulators. + // During accumulation, v0 -- v3 are used to load int8 data from LHS and + // RHS. + // + // int8 RHS 4x8 block + // /-----------------------------------------\ + // |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]| + // | ... ... | + // |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]| + // \-----------------------------------------/ + // int8 LHS 8x4 block + // /---------------------\ /-----------------------------------------\ + // |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]| + // | ... ... | | ... ... | + // |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]| + // |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]| + // | ... ... | | ... ... | + // |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // int32 accumulators 8x8 block + // + // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because + // we did not observe a benefit of such partial unrolling on in-order CPUs. + // + // v4 -- v7 are unused, and v8 -- v15 are used for loading parameters used for + // the post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. 
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + RUY_MAKE_ZERO(v16) + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + RUY_MAKE_ZERO(v17) + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + RUY_MAKE_ZERO(v18) + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + RUY_MAKE_ZERO(v19) + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + RUY_MAKE_ZERO(v20) + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + RUY_MAKE_ZERO(v21) + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + RUY_MAKE_ZERO(v22) + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v1.16b}, [%[lhs_ptr]], #16\n" + "ld1 {v2.16b}, [%[rhs_ptr]], #16\n" + "ld1 {v3.16b}, [%[rhs_ptr]], #16\n" + + // Clear accumulators. + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + // Perform the first few multiply-adds on the data that we have already + // loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + RUY_MAKE_ZERO(v28) + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + RUY_MAKE_ZERO(v29) + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + RUY_MAKE_ZERO(v30) + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + RUY_MAKE_ZERO(v31) + + + "1:\n" + + "add x5, %[lhs_ptr], x12, lsl #3\n" + "sub x5, x5, #32\n" + "cmp %[lhs_ptr], x5\n" + + "beq 79f\n" + + // Main accumulation loop + "2:\n" + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + "ldr x1, [%[lhs_ptr], #8]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + "ldr x3, [%[rhs_ptr], #8]\n" + ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + "ldr x4, [%[rhs_ptr], #24]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + "ldr d0, [%[lhs_ptr], #0]\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + "ins v0.d[1], x1\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + "ldr x2, [%[lhs_ptr], #24]\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + "add %[lhs_ptr], %[lhs_ptr], #32\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + "ldr d2, [%[rhs_ptr], #0]\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + "ins v2.d[1], x3\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + "cmp %[lhs_ptr], x5\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + "add %[rhs_ptr], %[rhs_ptr], #32\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + "ldr d3, [%[rhs_ptr], #-16]\n" + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + "ldr d1, [%[lhs_ptr], #-16]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + "ins v3.d[1], x4\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + "ins v1.d[1], x2\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + "blt 2b\n" + + // Last accumulation steps, nothing left to load. + "79:\n" + ".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n" + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + ".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n" + "cmp %w[row], w7\n" // Have we finished the last row? 
+ ".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n" + ".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n" + ".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n" + ".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n" + ".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n" + ".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n" + ".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n" + ".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n" + ".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n" + ".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + // Load some parameters needed for the end work on current block. + RUY_MAKE_ZERO(v8) + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + "ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n" + "ins v13.h[4], w4\n" // dst_zero_point + "ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "dup v9.4s, w3\n" // create prod_zp_depth_vec + "add x5, x4, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "csel x4, x4, x5, eq\n" + + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + "add x5, x1, %x[row], lsl #2\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. 
+ "ld1 {v14.2s}, [x1], #8\n" + "ldr x5, [x1], #8\n" + "ins v14.d[1], x5\n" + "ld1 {v15.2s}, [x1], #8\n" + "ldr x5, [x1], #8\n" + "ins v15.d[1], x5\n" + + // Add to the bias values the product (depth * lhs_zero_point * rhs_zero_point), + // See the term NZ1Z2 in equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "add v14.4s, v14.4s, v9.4s\n" + "add v15.4s, v15.4s, v9.4s\n" + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "add v16.4s, v16.4s, v14.4s\n" + "add v17.4s, v17.4s, v15.4s\n" + "add v18.4s, v18.4s, v14.4s\n" + "add v19.4s, v19.4s, v15.4s\n" + "add v20.4s, v20.4s, v14.4s\n" + "add v21.4s, v21.4s, v15.4s\n" + "add v22.4s, v22.4s, v14.4s\n" + "add v23.4s, v23.4s, v15.4s\n" + "add v24.4s, v24.4s, v14.4s\n" + "add v25.4s, v25.4s, v15.4s\n" + "add v26.4s, v26.4s, v14.4s\n" + "add v27.4s, v27.4s, v15.4s\n" + "add v28.4s, v28.4s, v14.4s\n" + "add v29.4s, v29.4s, v15.4s\n" + "add v30.4s, v30.4s, v14.4s\n" + "add v31.4s, v31.4s, v15.4s\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n" + "beq 401f\n" + "ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n" + "add x3, x3, %x[col], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n" + "dup v10.4s, w5\n" // create lhs_zero_point_vec + // Load 8 rhs_sums values. + "ld1 {v14.2s}, [x3], #8\n" + "ldr x7, [x3], #8\n" + "ld1 {v15.2s}, [x3], #8\n" + "ins v14.d[1], x7\n" + "ldr x7, [x3], #8\n" + "ins v15.d[1], x7\n" + // Subtract rhs_sums * lhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "mls v16.4s, v10.4s, v14.s[0]\n" + "mls v17.4s, v10.4s, v14.s[0]\n" + "mls v18.4s, v10.4s, v14.s[1]\n" + "mls v19.4s, v10.4s, v14.s[1]\n" + "mls v20.4s, v10.4s, v14.s[2]\n" + "mls v21.4s, v10.4s, v14.s[2]\n" + "mls v22.4s, v10.4s, v14.s[3]\n" + "mls v23.4s, v10.4s, v14.s[3]\n" + "mls v24.4s, v10.4s, v15.s[0]\n" + "mls v25.4s, v10.4s, v15.s[0]\n" + "mls v26.4s, v10.4s, v15.s[1]\n" + "mls v27.4s, v10.4s, v15.s[1]\n" + "mls v28.4s, v10.4s, v15.s[2]\n" + "mls v29.4s, v10.4s, v15.s[2]\n" + "mls v30.4s, v10.4s, v15.s[3]\n" + "mls v31.4s, v10.4s, v15.s[3]\n" + "401:\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n" + "beq 402f\n" + "ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n" + "add x2, x2, %x[row], lsl #2\n" + "ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n" + "ins v13.s[1], w5\n" // rhs_zero_point + // Load 8 lhs_sums values. + "ld1 {v11.2s}, [x2], #8\n" + "ldr x6, [x2], #8\n" + "ins v11.d[1], x6\n" + "ld1 {v12.2s}, [x2], #8\n" + "ldr x6, [x2], #8\n" + "ins v12.d[1], x6\n" + // Compute lhs_sums * rhs_zero_point. 
+ "mul v11.4s, v11.4s, v13.s[1]\n" + "mul v12.4s, v12.4s, v13.s[1]\n" + // Subtract lhs_sums * rhs_zero_point, per + // equation (7) in https://arxiv.org/pdf/1712.05877.pdf + "sub v16.4s, v16.4s, v11.4s\n" + "sub v17.4s, v17.4s, v12.4s\n" + "sub v18.4s, v18.4s, v11.4s\n" + "sub v19.4s, v19.4s, v12.4s\n" + "sub v20.4s, v20.4s, v11.4s\n" + "sub v21.4s, v21.4s, v12.4s\n" + "sub v22.4s, v22.4s, v11.4s\n" + "sub v23.4s, v23.4s, v12.4s\n" + "sub v24.4s, v24.4s, v11.4s\n" + "sub v25.4s, v25.4s, v12.4s\n" + "sub v26.4s, v26.4s, v11.4s\n" + "sub v27.4s, v27.4s, v12.4s\n" + "sub v28.4s, v28.4s, v11.4s\n" + "sub v29.4s, v29.4s, v12.4s\n" + "sub v30.4s, v30.4s, v11.4s\n" + "sub v31.4s, v31.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n" + + "402:\n" + + // At this point we have computed the final int32 values. Now we + // start down-quantizing them to obtain the final 8bit values from them. + + // As part of this down-quantization, our int32 values will be + // multiplied by a multiplier that has a fixed-point component and an + // exponent component. + + //Load the exponent part of the multiplier. + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n" + "ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n" + "add x5, x1, %x[row], lsl #2\n" + "csel x1, x1, x5, eq\n" + + "ldr q9, [x1]\n" + "ldr q10, [x1, #16]\n" + + "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n" + "beq 403f\n" + "smax v11.4s, v9.4s, v8.4s\n" + "smax v12.4s, v10.4s, v8.4s\n" + "sshl v16.4s, v16.4s, v11.4s\n" + "sshl v17.4s, v17.4s, v12.4s\n" + "sshl v18.4s, v18.4s, v11.4s\n" + "sshl v19.4s, v19.4s, v12.4s\n" + "sshl v20.4s, v20.4s, v11.4s\n" + "sshl v21.4s, v21.4s, v12.4s\n" + "sshl v22.4s, v22.4s, v11.4s\n" + "sshl v23.4s, v23.4s, v12.4s\n" + "sshl v24.4s, v24.4s, v11.4s\n" + "sshl v25.4s, v25.4s, v12.4s\n" + "sshl v26.4s, v26.4s, v11.4s\n" + "sshl v27.4s, v27.4s, v12.4s\n" + "sshl v28.4s, v28.4s, v11.4s\n" + "sshl v29.4s, v29.4s, v12.4s\n" + "sshl v30.4s, v30.4s, v11.4s\n" + "sshl v31.4s, v31.4s, v12.4s\n" + "403:\n" + + "ldr q14, [x4]\n" // multiplier_fixedpoint + "ldr q15, [x4, #16]\n" // multiplier_fixedpoint + + "smin v11.4s, v9.4s, v8.4s\n" + "smin v12.4s, v10.4s, v8.4s\n" + + // Apply the fixed-point part of the multiplier. + // + // ... and, interleaved into that: + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. 
+ "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" + "sqrdmulh v16.4s, v16.4s, v14.4s\n" + "ldr x1, [%[lhs_ptr]], #8\n" + "sqrdmulh v17.4s, v17.4s, v15.4s\n" + "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" + "sqrdmulh v18.4s, v18.4s, v14.4s\n" + "ldr x2, [%[lhs_ptr]], #8\n" + "sqrdmulh v19.4s, v19.4s, v15.4s\n" + "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" + "sqrdmulh v20.4s, v20.4s, v14.4s\n" + "ldr x5, [%[rhs_ptr]], #8\n" + "sqrdmulh v21.4s, v21.4s, v15.4s\n" + "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" + "sqrdmulh v22.4s, v22.4s, v14.4s\n" + "ldr x6, [%[rhs_ptr]], #8\n" + "sqrdmulh v23.4s, v23.4s, v15.4s\n" + "sqrdmulh v24.4s, v24.4s, v14.4s\n" + "sqrdmulh v25.4s, v25.4s, v15.4s\n" + "sqrdmulh v26.4s, v26.4s, v14.4s\n" + "sqrdmulh v27.4s, v27.4s, v15.4s\n" + "sqrdmulh v28.4s, v28.4s, v14.4s\n" + "sqrdmulh v29.4s, v29.4s, v15.4s\n" + "sqrdmulh v30.4s, v30.4s, v14.4s\n" + "sqrdmulh v31.4s, v31.4s, v15.4s\n" + + // We have some rounding division-by-power-of-two to do. This should + // always use "round to nearest". We allow for some + // freedom in how ties are broken, to strike a good compromise of + // performance on given hardware vs. perfect agreement of results + // across hardware. + // + // When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation + // defined tie-breaks to help performance. On NEON, this means that we + // can just use the NEON rounding instructions, such as srshl. They + // happen to be breaking ties upward. + // + // When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict + // break-ties-away-from zero, as described in Appendix B of + // https://arxiv.org/pdf/1712.05877.pdf + // When we wrote that, we thought that that would be better unbiased + // than the NEON upwards tie-breaks, and we had observed some + // improvement on some model. However, that is only more unbiased for + // data centered at zero, which was likely the case in that model, + // but is not always the case. If we wanted something more consistently + // unbiased then we should try breaking ties toward-nearest-even. +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + // Fix up values to be right-shifted, so that the (round to nearest, + // break ties upward) behavior of srshl applied to these fixed-up + // values, produces the same result as the desired (round to nearest, + // break ties away from zero) behavior on the original values. 
+ "and v8.16b, v16.16b, v11.16b\n" + "and v9.16b, v17.16b, v12.16b\n" + "and v14.16b, v18.16b, v11.16b\n" + "and v15.16b, v19.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v16.4s, v16.4s, v8.4s\n" + "sqadd v17.4s, v17.4s, v9.4s\n" + "sqadd v18.4s, v18.4s, v14.4s\n" + "sqadd v19.4s, v19.4s, v15.4s\n" + "and v8.16b, v20.16b, v11.16b\n" + "and v9.16b, v21.16b, v12.16b\n" + "and v14.16b, v22.16b, v11.16b\n" + "and v15.16b, v23.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v20.4s, v20.4s, v8.4s\n" + "sqadd v21.4s, v21.4s, v9.4s\n" + "sqadd v22.4s, v22.4s, v14.4s\n" + "sqadd v23.4s, v23.4s, v15.4s\n" + "and v8.16b, v24.16b, v11.16b\n" + "and v9.16b, v25.16b, v12.16b\n" + "and v14.16b, v26.16b, v11.16b\n" + "and v15.16b, v27.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v24.4s, v24.4s, v8.4s\n" + "sqadd v25.4s, v25.4s, v9.4s\n" + "sqadd v26.4s, v26.4s, v14.4s\n" + "sqadd v27.4s, v27.4s, v15.4s\n" + "and v8.16b, v28.16b, v11.16b\n" + "and v9.16b, v29.16b, v12.16b\n" + "and v14.16b, v30.16b, v11.16b\n" + "and v15.16b, v31.16b, v12.16b\n" + "sshr v8.4s, v8.4s, #31\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v14.4s, v14.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sqadd v28.4s, v28.4s, v8.4s\n" + "sqadd v29.4s, v29.4s, v9.4s\n" + "sqadd v30.4s, v30.4s, v14.4s\n" + "sqadd v31.4s, v31.4s, v15.4s\n" +#endif + // At this point we have reduced the problem of correctly implementing + // rounding divide-by-power-of-two, to what the SRSHL instruction can + // do. + "srshl v16.4s, v16.4s, v11.4s\n" + "srshl v17.4s, v17.4s, v12.4s\n" + "srshl v18.4s, v18.4s, v11.4s\n" + "srshl v19.4s, v19.4s, v12.4s\n" + "srshl v20.4s, v20.4s, v11.4s\n" + "srshl v21.4s, v21.4s, v12.4s\n" + "srshl v22.4s, v22.4s, v11.4s\n" + "srshl v23.4s, v23.4s, v12.4s\n" + "srshl v24.4s, v24.4s, v11.4s\n" + "srshl v25.4s, v25.4s, v12.4s\n" + "srshl v26.4s, v26.4s, v11.4s\n" + "srshl v27.4s, v27.4s, v12.4s\n" + "ins v0.d[1], x1\n" + "srshl v28.4s, v28.4s, v11.4s\n" + "ins v1.d[1], x2\n" + "srshl v29.4s, v29.4s, v12.4s\n" + "ins v2.d[1], x5\n" + "srshl v30.4s, v30.4s, v11.4s\n" + "ins v3.d[1], x6\n" + "srshl v31.4s, v31.4s, v12.4s\n" + + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n" + "cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n" + "beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // Destination zero_point + "dup v14.8h, v13.h[4]\n" + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + "add v18.8h, v18.8h, v14.8h\n" + "add v19.8h, v19.8h, v14.8h\n" + "add v20.8h, v20.8h, v14.8h\n" + "add v21.8h, v21.8h, v14.8h\n" + "add v22.8h, v22.8h, v14.8h\n" + "add v23.8h, v23.8h, v14.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + // Cast-and-saturate from int16 to uint8 + "sqxtun v16.8b, v16.8h\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "sqxtun2 v16.16b, v17.8h\n" + "sqxtun v17.8b, v18.8h\n" + "sqxtun2 v17.16b, v19.8h\n" + "sqxtun v18.8b, v20.8h\n" + "sqxtun2 v18.16b, v21.8h\n" + "sqxtun v19.8b, v22.8h\n" + "sqxtun2 v19.16b, v23.8h\n" + + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + // Apply the clamp_min bound + "umax v16.16b, v16.16b, v14.16b\n" + "sub w2, %w[dst_cols], %w[col]\n" + "umax v17.16b, v17.16b, v14.16b\n" + "mov w3, #8\n" + "umax v18.16b, v18.16b, v14.16b\n" + "cmp w1, #8\n" + "umax v19.16b, v19.16b, v14.16b\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + // Apply the clamp_max bound + "umin v16.16b, v16.16b, v15.16b\n" + "cmp w2, #8\n" + "umin v17.16b, v17.16b, v15.16b\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + "umin v18.16b, v18.16b, v15.16b\n" + "umin v19.16b, v19.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + "dup d21, v17.d[1]\n" + "dup d22, v18.d[1]\n" + "dup d23, v19.d[1]\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. 
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // Destination zero_point + "dup v14.8h, v13.h[4]\n" + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Add the destination zero point + "add v16.8h, v16.8h, v14.8h\n" + "add v17.8h, v17.8h, v14.8h\n" + "add v18.8h, v18.8h, v14.8h\n" + "add v19.8h, v19.8h, v14.8h\n" + "add v20.8h, v20.8h, v14.8h\n" + "add v21.8h, v21.8h, v14.8h\n" + "add v22.8h, v22.8h, v14.8h\n" + "add v23.8h, v23.8h, v14.8h\n" + + // Load the clamp_min, clamp_max bounds + "ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + // Cast-and-saturate from int16 to uint8 + "sqxtn v16.8b, v16.8h\n" + "ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "sqxtn2 v16.16b, v17.8h\n" + "sqxtn v17.8b, v18.8h\n" + "sqxtn2 v17.16b, v19.8h\n" + "sqxtn v18.8b, v20.8h\n" + "sqxtn2 v18.16b, v21.8h\n" + "sqxtn v19.8b, v22.8h\n" + "sqxtn2 v19.16b, v23.8h\n" + + "dup v14.16b, w2\n" // clamp_min + "dup v15.16b, w3\n" // clamp_max + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + // Apply the clamp_min bound + "smax v16.16b, v16.16b, v14.16b\n" + "sub w2, %w[dst_cols], %w[col]\n" + "smax v17.16b, v17.16b, v14.16b\n" + "mov w3, #8\n" + "smax v18.16b, v18.16b, v14.16b\n" + "cmp w1, #8\n" + "smax v19.16b, v19.16b, v14.16b\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + // Apply the clamp_max bound + "smin v16.16b, v16.16b, v15.16b\n" + "cmp w2, #8\n" + "smin v17.16b, v17.16b, v15.16b\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + "smin v18.16b, v18.16b, v15.16b\n" + "smin v19.16b, v19.16b, v15.16b\n" + + // Make it so that all of the final 8bit values are stored in the + // first 64bits of 128bit NEON registers, so they can be stored + // by 64bit st1 store instructions with byte alignment. + "dup d20, v16.d[1]\n" + "dup d21, v17.d[1]\n" + "dup d22, v18.d[1]\n" + "dup d23, v19.d[1]\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 130f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #8\n" + "b 131f\n" + "130:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "131:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. 
+ ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8b}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 141f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "150:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "151:\n" + "ldrb w7, [x3, w5, uxtw]\n" + "strb w7, [x4, w5, uxtw]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 151b\n" + "add w6, w6, #1\n" + "add x3, x3, #8\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 150b\n" + "141:\n" + "add %[dst_ptr], %[dst_ptr], #8\n" + + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n" + + // Add the destination zero point + "dup v14.8h, v13.h[4]\n" + "saddw v16.4s, v16.4s, v14.4h\n" + "saddw v17.4s, v17.4s, v14.4h\n" + "saddw v18.4s, v18.4s, v14.4h\n" + "saddw v19.4s, v19.4s, v14.4h\n" + "saddw v20.4s, v20.4s, v14.4h\n" + "saddw v21.4s, v21.4s, v14.4h\n" + "saddw v22.4s, v22.4s, v14.4h\n" + "saddw v23.4s, v23.4s, v14.4h\n" + "saddw v24.4s, v24.4s, v14.4h\n" + "saddw v25.4s, v25.4s, v14.4h\n" + "saddw v26.4s, v26.4s, v14.4h\n" + "saddw v27.4s, v27.4s, v14.4h\n" + "saddw v28.4s, v28.4s, v14.4h\n" + "saddw v29.4s, v29.4s, v14.4h\n" + "saddw v30.4s, v30.4s, v14.4h\n" + "saddw v31.4s, v31.4s, v14.4h\n" + + // Cast-and-saturate from int32 to int16 + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtn v17.4h, v18.4s\n" + "sqxtn2 v17.8h, v19.4s\n" + "sqxtn v18.4h, v20.4s\n" + "sqxtn2 v18.8h, v21.4s\n" + "sqxtn v19.4h, v22.4s\n" + "sqxtn2 v19.8h, v23.4s\n" + "sqxtn v20.4h, v24.4s\n" + "sqxtn2 v20.8h, v25.4s\n" + "sqxtn v21.4h, v26.4s\n" + "sqxtn2 v21.8h, v27.4s\n" + "sqxtn v22.4h, v28.4s\n" + "sqxtn2 v22.8h, v29.4s\n" + "sqxtn v23.4h, v30.4s\n" + "sqxtn2 v23.8h, v31.4s\n" + + // At this point, v24 -- v31 aren't used anymore for the current block, + // so we can start clearing these accumulators for the next block + // (next iteration of the main loop). 
+ RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // Load the clamp_min, clamp_max bounds + "ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.8h, w2\n" // clamp_min + "dup v15.8h, w3\n" // clamp_max + + // Apply the clamp_min bound + "smax v16.8h, v16.8h, v14.8h\n" + "smax v17.8h, v17.8h, v14.8h\n" + "smax v18.8h, v18.8h, v14.8h\n" + "smax v19.8h, v19.8h, v14.8h\n" + "smax v20.8h, v20.8h, v14.8h\n" + "smax v21.8h, v21.8h, v14.8h\n" + "smax v22.8h, v22.8h, v14.8h\n" + "smax v23.8h, v23.8h, v14.8h\n" + // Apply the clamp_max bound + "smin v16.8h, v16.8h, v15.8h\n" + "smin v17.8h, v17.8h, v15.8h\n" + "smin v18.8h, v18.8h, v15.8h\n" + "smin v19.8h, v19.8h, v15.8h\n" + "smin v20.8h, v20.8h, v15.8h\n" + "smin v21.8h, v21.8h, v15.8h\n" + "smin v22.8h, v22.8h, v15.8h\n" + "smin v23.8h, v23.8h, v15.8h\n" + + // Compute how much of the 8x8 block of destination 16bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 230f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #16\n" + "b 231f\n" + "230:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "231:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v16.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v16) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v17.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v18.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v18) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v19.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v20.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v21.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v22.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "st1 {v23.8h}, [x3], x4\n" + RUY_MAKE_ZERO(v23) + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 241f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. 
Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "250:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "251:\n" + "ldrsh w7, [x3, x5, lsl #1]\n" + "strh w7, [x4, x5, lsl #1]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 251b\n" + "add w6, w6, #1\n" + "add x3, x3, #16\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 250b\n" + "241:\n" + "add %[dst_ptr], %[dst_ptr], #16\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + "b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n" + + RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n" + + "ld1 {v0.8b}, [%[lhs_ptr]], #8\n" + "ldr x1, [%[lhs_ptr]], #8\n" + "ld1 {v1.8b}, [%[lhs_ptr]], #8\n" + "ldr x2, [%[lhs_ptr]], #8\n" + "ld1 {v2.8b}, [%[rhs_ptr]], #8\n" + "ldr x5, [%[rhs_ptr]], #8\n" + "ld1 {v3.8b}, [%[rhs_ptr]], #8\n" + "ldr x6, [%[rhs_ptr]], #8\n" + "ins v0.d[1], x1\n" + "ins v1.d[1], x2\n" + "ins v2.d[1], x5\n" + "ins v3.d[1], x6\n" + + // Since the store type is the same as the accum type, no need for + // downcast. There's also no need for clamp by min/max. + + // Compute how much of the 8x8 block of destination 32it values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 330f\n" + // Not all of the 8x8 block fits. + // Write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "st1 {v16.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v16) + "st1 {v17.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v17) + "st1 {v18.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v18) + "st1 {v19.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v19) + "st1 {v20.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v20) + "st1 {v21.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v21) + "st1 {v22.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v22) + "st1 {v23.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v23) + "st1 {v24.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v24) + "st1 {v25.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v25) + "st1 {v26.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v26) + "st1 {v27.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v27) + "st1 {v28.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v28) + "st1 {v29.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v29) + "st1 {v30.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v30) + "st1 {v31.4s}, [x3], #16\n" + RUY_MAKE_ZERO(v31) + + "b 331f\n" + + "330:\n" + // Yes, all of the 8x8 block fits. 
+ "mov x4, %[dst_ptr]\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v16.4s, v17.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v18.4s, v19.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v20.4s, v21.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v22.4s, v23.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v24.4s, v25.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v26.4s, v27.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v28.4s, v29.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "st1 {v30.4s, v31.4s}, [x4], x11\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + "331:\n" + + // For the next block: perform the first few multiply-adds on the data + // that we have already loaded. + ".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n" + ".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n" + ".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n" + ".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n" + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 341f\n" + + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "350:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "351:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 351b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 350b\n" + "341:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n" + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? 
+ "cmp %w[col], w8\n" + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf), + [dst_type_id] "r"(params.dst_type_id) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} +#undef RUY_OFFSET_BIAS +#undef RUY_OFFSET_LHS_SUMS +#undef RUY_OFFSET_RHS_SUMS +#undef RUY_OFFSET_LHS_BASE_PTR +#undef RUY_OFFSET_MULTIPLIER_FIXEDPOINT +#undef RUY_OFFSET_MULTIPLIER_EXPONENT +#undef RUY_OFFSET_RHS_BASE_PTR +#undef RUY_OFFSET_DST_BASE_PTR +#undef RUY_OFFSET_LHS_ZERO_POINT +#undef RUY_OFFSET_RHS_ZERO_POINT +#undef RUY_OFFSET_DST_ZERO_POINT +#undef RUY_OFFSET_PROD_ZP_DEPTH +#undef RUY_OFFSET_START_ROW +#undef RUY_OFFSET_START_COL +#undef RUY_OFFSET_LAST_ROW +#undef RUY_OFFSET_LAST_COL +#undef RUY_OFFSET_DST_ROWS +#undef RUY_OFFSET_DST_COLS +#undef RUY_OFFSET_LHS_STRIDE +#undef RUY_OFFSET_RHS_STRIDE +#undef RUY_OFFSET_DST_STRIDE +#undef RUY_OFFSET_DEPTH +#undef RUY_OFFSET_CLAMP_MIN +#undef RUY_OFFSET_CLAMP_MAX +#undef RUY_OFFSET_FLAGS + +#define RUY_OFFSET_LHS_BASE_PTR 0 +#define RUY_OFFSET_RHS_BASE_PTR 8 +#define RUY_OFFSET_DST_BASE_PTR 16 +#define RUY_OFFSET_BIAS 24 +#define RUY_OFFSET_START_ROW 32 +#define RUY_OFFSET_START_COL 36 +#define RUY_OFFSET_LAST_ROW 40 +#define RUY_OFFSET_LAST_COL 44 +#define RUY_OFFSET_LHS_STRIDE 56 +#define RUY_OFFSET_RHS_STRIDE 60 +#define RUY_OFFSET_DST_STRIDE 64 +#define RUY_OFFSET_DEPTH 68 +#define RUY_OFFSET_CLAMP_MIN 72 +#define RUY_OFFSET_CLAMP_MAX 76 +#define RUY_OFFSET_FLAGS 80 + +template +void CheckOffsetsInKernelParamsFloat(const Params&) { + static_assert(offsetof(Params, lhs_base_ptr) == RUY_OFFSET_LHS_BASE_PTR, ""); + static_assert(offsetof(Params, rhs_base_ptr) == RUY_OFFSET_RHS_BASE_PTR, ""); + static_assert(offsetof(Params, dst_base_ptr) == RUY_OFFSET_DST_BASE_PTR, ""); + static_assert(offsetof(Params, bias) == RUY_OFFSET_BIAS, ""); + static_assert(offsetof(Params, start_row) == RUY_OFFSET_START_ROW, ""); + static_assert(offsetof(Params, start_col) == RUY_OFFSET_START_COL, ""); + static_assert(offsetof(Params, last_row) == RUY_OFFSET_LAST_ROW, ""); + static_assert(offsetof(Params, last_col) == RUY_OFFSET_LAST_COL, ""); + static_assert(offsetof(Params, lhs_stride) == RUY_OFFSET_LHS_STRIDE, ""); + static_assert(offsetof(Params, rhs_stride) == RUY_OFFSET_RHS_STRIDE, ""); + static_assert(offsetof(Params, dst_stride) == RUY_OFFSET_DST_STRIDE, ""); + static_assert(offsetof(Params, depth) == RUY_OFFSET_DEPTH, ""); + static_assert(offsetof(Params, clamp_min) == RUY_OFFSET_CLAMP_MIN, ""); + static_assert(offsetof(Params, clamp_max) == RUY_OFFSET_CLAMP_MAX, ""); + static_assert(offsetof(Params, flags) == RUY_OFFSET_FLAGS, ""); +} + +// Just a plain float kernel; good enough for out-of-order cores. 
+// The closest to it in the gemmlowp collection would be +// NEON_64bit_GEMM_Float32_WithScalar, +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L3925 +// +// Besides ruy-ification, the main nuance here is that we stick to a 8x8 +// width instead of the wider 12x8 that the register space permits and that +// the aforementioned gemmlowp kernel uses. Ruy likes powers of two for now +// and we don't have evidence that going beyond 8x8 is needed. +void KernelFloatNeonOutOfOrder(const KernelParamsFloat<8, 8>& params) { + CheckOffsetsInKernelParamsFloat(params); + profiler::ScopeLabel label( + "Kernel (kNeon, optimized for out-of-order cores)"); + + const float* lhs_col_ptr = params.lhs_base_ptr; + const float* rhs_col_ptr = params.rhs_base_ptr; + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + float* dst_col_ptr = params.dst_base_ptr; + float* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are accumulators. + // During accumulation, v0 -- v15 are used to load data from LHS and RHS. + // At least v0 and v1 are used to load a 8x1 block of LHS, and v2 and + // v3 are used to load a 1x8 block of RHS, like this: + // + // RHS 1x8 block + // /-----------------------------------------\ + // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| + // \-----------------------------------------/ + // LHS 8x1 block + // /---------------------\ /-----------------------------------------\ + // | v0.s[0] | |v16.s[0] ... v30.s[0]| + // | ... | | ... ... | + // | v0.s[3] | |v16.s[3] ... v30.s[3]| + // | v1.s[0] | |v17.s[0] ... v31.s[0]| + // | ... | | ... ... | + // | v1.s[3] | |v17.s[3] ... v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // accumulators 8x8 block + // + // In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step + // is repeated 4 times, using 4x more registers for LHS and RHS, so that + // is where instead of using v0 -- v3 for LHS and RHS, we use v0 -- v15. + // + // Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are + // unused, and v8 -- v15 are used for floading parameters used for the + // post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + + // Clear accumulators. 
+ RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 1. + "mov w1, #1\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + "fmla v16.4s, v0.4s, v2.s[0]\n" + "fmla v18.4s, v0.4s, v2.s[1]\n" + "fmla v20.4s, v0.4s, v2.s[2]\n" + "fmla v22.4s, v0.4s, v2.s[3]\n" + +#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) + "cmp w12, #8\n" + "blt 78f\n" + "and w2, w12, #-4\n" + + "ld1 {v4.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v5.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v6.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v7.4s}, [%[rhs_ptr]], #16\n" + + "ld1 {v8.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v9.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v10.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v11.4s}, [%[rhs_ptr]], #16\n" + + "ld1 {v12.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v13.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v14.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v15.4s}, [%[rhs_ptr]], #16\n" + "mov w1, #4\n" + + "80:\n" + + "add %[lhs_ptr], %[lhs_ptr], #128\n" + "add %[rhs_ptr], %[rhs_ptr], #128\n" + + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "ldr q0, [%[lhs_ptr], #-128]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ldr q3, [%[rhs_ptr], #-112]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "ldr q1, [%[lhs_ptr], #-112]\n" + "fmla v16.4s, v4.4s, v6.s[0]\n" + "fmla v18.4s, v4.4s, v6.s[1]\n" + "ldr q2, [%[rhs_ptr], #-128]\n" + "fmla v20.4s, v4.4s, v6.s[2]\n" + "fmla v22.4s, v4.4s, v6.s[3]\n" + + "fmla v24.4s, v4.4s, v7.s[0]\n" + "fmla v26.4s, v4.4s, v7.s[1]\n" + "fmla v28.4s, v4.4s, v7.s[2]\n" + "fmla v30.4s, v4.4s, v7.s[3]\n" + "ldr q4, [%[lhs_ptr], #-96]\n" + "fmla v25.4s, v5.4s, v7.s[0]\n" + "fmla v27.4s, v5.4s, v7.s[1]\n" + "fmla v29.4s, v5.4s, v7.s[2]\n" + "fmla v31.4s, v5.4s, v7.s[3]\n" + "ldr q7, [%[rhs_ptr], #-80]\n" + "fmla v17.4s, v5.4s, v6.s[0]\n" + "fmla v19.4s, v5.4s, v6.s[1]\n" + "fmla v21.4s, v5.4s, v6.s[2]\n" + "fmla v23.4s, v5.4s, v6.s[3]\n" + "ldr q5, [%[lhs_ptr], #-80]\n" + "fmla v16.4s, v8.4s, v10.s[0]\n" + "fmla v18.4s, v8.4s, v10.s[1]\n" + "ldr q6, [%[rhs_ptr], #-96]\n" + "fmla v20.4s, v8.4s, v10.s[2]\n" + "fmla v22.4s, v8.4s, v10.s[3]\n" + + "fmla v24.4s, v8.4s, v11.s[0]\n" + "fmla v26.4s, v8.4s, v11.s[1]\n" + "fmla v28.4s, v8.4s, v11.s[2]\n" + "fmla v30.4s, v8.4s, v11.s[3]\n" + "ldr q8, [%[lhs_ptr], #-64]\n" + "fmla v25.4s, v9.4s, v11.s[0]\n" + "fmla v27.4s, v9.4s, v11.s[1]\n" + "fmla v29.4s, v9.4s, v11.s[2]\n" + "fmla v31.4s, v9.4s, v11.s[3]\n" + "ldr q11, [%[rhs_ptr], #-48]\n" + "fmla v17.4s, v9.4s, v10.s[0]\n" + "fmla v19.4s, v9.4s, v10.s[1]\n" + "fmla v21.4s, v9.4s, v10.s[2]\n" + "fmla v23.4s, v9.4s, v10.s[3]\n" + "ldr q9, [%[lhs_ptr], #-48]\n" + "fmla v16.4s, v12.4s, v14.s[0]\n" + "fmla v18.4s, v12.4s, v14.s[1]\n" + "ldr q10, [%[rhs_ptr], #-64]\n" + "fmla v20.4s, v12.4s, v14.s[2]\n" + "fmla v22.4s, v12.4s, v14.s[3]\n" + + "fmla v24.4s, v12.4s, v15.s[0]\n" + "fmla v26.4s, 
v12.4s, v15.s[1]\n" + "fmla v28.4s, v12.4s, v15.s[2]\n" + "fmla v30.4s, v12.4s, v15.s[3]\n" + "ldr q12, [%[lhs_ptr], #-32]\n" + "fmla v25.4s, v13.4s, v15.s[0]\n" + "fmla v27.4s, v13.4s, v15.s[1]\n" + "fmla v29.4s, v13.4s, v15.s[2]\n" + "fmla v31.4s, v13.4s, v15.s[3]\n" + "ldr q15, [%[rhs_ptr], #-16]\n" + "fmla v17.4s, v13.4s, v14.s[0]\n" + "fmla v19.4s, v13.4s, v14.s[1]\n" + "fmla v21.4s, v13.4s, v14.s[2]\n" + "fmla v23.4s, v13.4s, v14.s[3]\n" + "ldr q13, [%[lhs_ptr], #-16]\n" + "fmla v16.4s, v0.4s, v2.s[0]\n" + "fmla v18.4s, v0.4s, v2.s[1]\n" + "ldr q14, [%[rhs_ptr], #-32]\n" + "fmla v20.4s, v0.4s, v2.s[2]\n" + "fmla v22.4s, v0.4s, v2.s[3]\n" + + "add w1, w1, #4\n" + "cmp w1, w2\n" + "blt 80b\n" + + "fmla v16.4s, v4.4s, v6.s[0]\n" + "fmla v18.4s, v4.4s, v6.s[1]\n" + "fmla v20.4s, v4.4s, v6.s[2]\n" + "fmla v22.4s, v4.4s, v6.s[3]\n" + "fmla v24.4s, v4.4s, v7.s[0]\n" + "fmla v26.4s, v4.4s, v7.s[1]\n" + "fmla v28.4s, v4.4s, v7.s[2]\n" + "fmla v30.4s, v4.4s, v7.s[3]\n" + "fmla v25.4s, v5.4s, v7.s[0]\n" + "fmla v27.4s, v5.4s, v7.s[1]\n" + "fmla v29.4s, v5.4s, v7.s[2]\n" + "fmla v31.4s, v5.4s, v7.s[3]\n" + "fmla v17.4s, v5.4s, v6.s[0]\n" + "fmla v19.4s, v5.4s, v6.s[1]\n" + "fmla v21.4s, v5.4s, v6.s[2]\n" + "fmla v23.4s, v5.4s, v6.s[3]\n" + + "fmla v16.4s, v8.4s, v10.s[0]\n" + "fmla v18.4s, v8.4s, v10.s[1]\n" + "fmla v20.4s, v8.4s, v10.s[2]\n" + "fmla v22.4s, v8.4s, v10.s[3]\n" + "fmla v24.4s, v8.4s, v11.s[0]\n" + "fmla v26.4s, v8.4s, v11.s[1]\n" + "fmla v28.4s, v8.4s, v11.s[2]\n" + "fmla v30.4s, v8.4s, v11.s[3]\n" + "fmla v25.4s, v9.4s, v11.s[0]\n" + "fmla v27.4s, v9.4s, v11.s[1]\n" + "fmla v29.4s, v9.4s, v11.s[2]\n" + "fmla v31.4s, v9.4s, v11.s[3]\n" + "fmla v17.4s, v9.4s, v10.s[0]\n" + "fmla v19.4s, v9.4s, v10.s[1]\n" + "fmla v21.4s, v9.4s, v10.s[2]\n" + "fmla v23.4s, v9.4s, v10.s[3]\n" + + "fmla v16.4s, v12.4s, v14.s[0]\n" + "fmla v18.4s, v12.4s, v14.s[1]\n" + "fmla v20.4s, v12.4s, v14.s[2]\n" + "fmla v22.4s, v12.4s, v14.s[3]\n" + "fmla v24.4s, v12.4s, v15.s[0]\n" + "fmla v26.4s, v12.4s, v15.s[1]\n" + "fmla v28.4s, v12.4s, v15.s[2]\n" + "fmla v30.4s, v12.4s, v15.s[3]\n" + "fmla v25.4s, v13.4s, v15.s[0]\n" + "fmla v27.4s, v13.4s, v15.s[1]\n" + "fmla v29.4s, v13.4s, v15.s[2]\n" + "fmla v31.4s, v13.4s, v15.s[3]\n" + "fmla v17.4s, v13.4s, v14.s[0]\n" + "fmla v19.4s, v13.4s, v14.s[1]\n" + "fmla v21.4s, v13.4s, v14.s[2]\n" + "fmla v23.4s, v13.4s, v14.s[3]\n" + + "78:\n" +#endif + + // Accumulation loop + "cmp w1, w12\n" + "beq 79f\n" + + "2:\n" + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "ld1 {v4.4s}, [%[rhs_ptr]], #16\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "add w1, w1, #1\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "cmp w1, w12\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "fmla v18.4s, v0.4s, v4.s[1]\n" + "mov v2.16b, v4.16b\n" + "fmla v20.4s, v0.4s, v4.s[2]\n" + "fmla v22.4s, v0.4s, v4.s[3]\n" + "blt 2b\n" + + "79:\n" + + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last level of depth, for which the LHS + // and RHS data is already loaded. 
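The fmla lane pattern in this kernel performs one rank-1 update of the 8x8 accumulator block per level of depth; the RUY_OPT_MAX_STREAMING section earlier is the same computation with four depth levels interleaved for deeper pipelining. A plain C++ reference of the accumulation, assuming the packed layout of 8 consecutive LHS values and 8 consecutive RHS values per depth level (a sketch, not ruy code):

  // acc[c][r] corresponds to lanes of v16..v31: even-numbered registers hold
  // rows 0..3 of a column, odd-numbered registers hold rows 4..7.
  void AccumulateFloat8x8(const float* lhs, const float* rhs, int depth,
                          float acc[8][8]) {
    for (int d = 0; d < depth; ++d) {
      for (int c = 0; c < 8; ++c) {
        for (int r = 0; r < 8; ++r) {
          // One scalar lane of "fmla vN.4s, v0/v1.4s, v2/v3.s[c]".
          acc[c][r] += lhs[d * 8 + r] * rhs[d * 8 + c];
        }
      }
    }
  }
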
+ + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Offset these base pointers as needed given the current row, col. + "add x5, x1, %x[row], lsl #2\n" + + "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) 
+ "fadd v16.4s, v16.4s, v14.4s\n" + "fadd v17.4s, v17.4s, v15.4s\n" + "fadd v18.4s, v18.4s, v14.4s\n" + "fadd v19.4s, v19.4s, v15.4s\n" + "fadd v20.4s, v20.4s, v14.4s\n" + "fadd v21.4s, v21.4s, v15.4s\n" + "fadd v22.4s, v22.4s, v14.4s\n" + "fadd v23.4s, v23.4s, v15.4s\n" + "fadd v24.4s, v24.4s, v14.4s\n" + "fadd v25.4s, v25.4s, v15.4s\n" + "fadd v26.4s, v26.4s, v14.4s\n" + "fadd v27.4s, v27.4s, v15.4s\n" + "fadd v28.4s, v28.4s, v14.4s\n" + "fadd v29.4s, v29.4s, v15.4s\n" + "fadd v30.4s, v30.4s, v14.4s\n" + "fadd v31.4s, v31.4s, v15.4s\n" + + // Load the clamp_min, clamp_max bounds + "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.4s, w2\n" // clamp_min + "dup v15.4s, w3\n" // clamp_max + + // Apply the clamp_min bound + "fmax v16.4s, v16.4s, v14.4s\n" + "fmax v17.4s, v17.4s, v14.4s\n" + "fmax v18.4s, v18.4s, v14.4s\n" + "fmax v19.4s, v19.4s, v14.4s\n" + "fmax v20.4s, v20.4s, v14.4s\n" + "fmax v21.4s, v21.4s, v14.4s\n" + "fmax v22.4s, v22.4s, v14.4s\n" + "fmax v23.4s, v23.4s, v14.4s\n" + "fmax v24.4s, v24.4s, v14.4s\n" + "fmax v25.4s, v25.4s, v14.4s\n" + "fmax v26.4s, v26.4s, v14.4s\n" + "fmax v27.4s, v27.4s, v14.4s\n" + "fmax v28.4s, v28.4s, v14.4s\n" + "fmax v29.4s, v29.4s, v14.4s\n" + "fmax v30.4s, v30.4s, v14.4s\n" + "fmax v31.4s, v31.4s, v14.4s\n" + + // Apply the clamp_max bound + "fmin v16.4s, v16.4s, v15.4s\n" + "fmin v17.4s, v17.4s, v15.4s\n" + "fmin v18.4s, v18.4s, v15.4s\n" + "fmin v19.4s, v19.4s, v15.4s\n" + "fmin v20.4s, v20.4s, v15.4s\n" + "fmin v21.4s, v21.4s, v15.4s\n" + "fmin v22.4s, v22.4s, v15.4s\n" + "fmin v23.4s, v23.4s, v15.4s\n" + "fmin v24.4s, v24.4s, v15.4s\n" + "fmin v25.4s, v25.4s, v15.4s\n" + "fmin v26.4s, v26.4s, v15.4s\n" + "fmin v27.4s, v27.4s, v15.4s\n" + "fmin v28.4s, v28.4s, v15.4s\n" + "fmin v29.4s, v29.4s, v15.4s\n" + "fmin v30.4s, v30.4s, v15.4s\n" + "fmin v31.4s, v31.4s, v15.4s\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #32\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). 
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "str q16, [x3, #0]\n" + "str q17, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + "str q18, [x3, #0]\n" + "str q19, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + "str q20, [x3, #0]\n" + "str q21, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + "str q22, [x3, #0]\n" + "str q23, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + "str q24, [x3, #0]\n" + "str q25, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + "str q26, [x3, #0]\n" + "str q27, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + "str q28, [x3, #0]\n" + "str q29, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + "str q30, [x3, #0]\n" + "str q31, [x3, #16]\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that we have already loaded + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently 1. 
+ "mov w1, #1\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} + +// Variant of KernelFloatNeonOutOfOrder tuned for in-order CPUs that do not +// support dotprod (while dotprod by itself is not relevant to floating-point, +// this additional bit of information that we have about the target happens to +// be useful here). +// +// So a typical target CPU here would be ARM Cortex-A53 or the original +// Cortex-A55. +// +// This kernel is similar to and inspired by gemmlowp's +// NEON_64bit_GEMM_Float32_WithScalar_A53. +// which was contributed by David Mansell with very helpful +// comments. Specifically, see this comment about tuning for Cortex-A53: +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4215 +void KernelFloatNeonInOrder(const KernelParamsFloat<8, 8>& params) { + profiler::ScopeLabel label("Kernel (kNeon, optimized for in-order cores)"); + + CheckOffsetsInKernelParamsFloat(params); + + const float* lhs_col_ptr = params.lhs_base_ptr; + const float* rhs_col_ptr = params.rhs_base_ptr; + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + float* dst_col_ptr = params.dst_base_ptr; + float* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are accumulators. + // During accumulation, v0 -- v3 are used to load data from LHS and RHS. + // + // RHS 1x8 block + // /-----------------------------------------\ + // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| + // \-----------------------------------------/ + // LHS 8x1 block + // /---------------------\ /-----------------------------------------\ + // | v0.s[0] | |v16.s[0] ... v30.s[0]| + // | ... | | ... ... | + // | v0.s[3] | |v16.s[3] ... v30.s[3]| + // | v1.s[0] | |v17.s[0] ... v31.s[0]| + // | ... | | ... ... | + // | v1.s[3] | |v17.s[3] ... v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // accumulators 8x8 block + // + // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because + // we did not observe a benefit of such partial unrolling on in-order CPUs. + // + // v4 -- v7 are unused, and v8 -- v15 are used for floading parameters used + // for the post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. 
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v17) + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v18) + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v19) + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n") + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n") + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n") + RUY_MAKE_ZERO(v23) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n") + RUY_MAKE_ZERO(v24) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n") + RUY_MAKE_ZERO(v25) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n") + RUY_MAKE_ZERO(v26) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") + RUY_MAKE_ZERO(v27) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that remain to load + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently depth - 1. + "sub w1, w12, #1\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + "cmp w1, #0\n" + "fmla v16.4s, v0.4s, v2.s[0]\n" + "fmla v18.4s, v0.4s, v2.s[1]\n" + "fmla v20.4s, v0.4s, v2.s[2]\n" + "fmla v22.4s, v0.4s, v2.s[3]\n" + + // Accumulation loop + "beq 79f\n" + + "2:\n" + + "fmla v24.4s, v0.4s, v3.s[0]\n" + "ldr x2, [%[lhs_ptr], #8]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "ldr x3, [%[lhs_ptr], #24]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "ldr x5, [%[rhs_ptr], #24]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "ldr x4, [%[rhs_ptr], #8]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "subs w1, w1, #1\n" + "ldr d0, [%[lhs_ptr]], #32\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ins v0.d[1], x2\n" + "ldr d3, [%[rhs_ptr], #16]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "ins v3.d[1], x5\n" + "ldr d4, [%[rhs_ptr]], #32\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + "fmla v16.4s, v0.4s, v4.s[0]\n" + "ins v4.d[1], x4\n" + "ldr d1, [%[lhs_ptr], #-16]\n" + "fmla v18.4s, v0.4s, v4.s[1]\n" + "fmla v20.4s, v0.4s, v4.s[2]\n" + "ins v1.d[1], x3\n" + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") + "mov v2.16b, v4.16b\n" + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") + "fmla v22.4s, v0.4s, v4.s[3]\n" + "bne 2b\n" + + "79:\n" + + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last level of depth, for which the LHS + // and RHS data is already loaded. 
+ + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Offset these base pointers as needed given the current row, col. + "add x5, x1, %x[row], lsl #2\n" + + "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. 
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "fadd v16.4s, v16.4s, v14.4s\n" + "fadd v17.4s, v17.4s, v15.4s\n" + "fadd v18.4s, v18.4s, v14.4s\n" + "fadd v19.4s, v19.4s, v15.4s\n" + "fadd v20.4s, v20.4s, v14.4s\n" + "fadd v21.4s, v21.4s, v15.4s\n" + "fadd v22.4s, v22.4s, v14.4s\n" + "fadd v23.4s, v23.4s, v15.4s\n" + "fadd v24.4s, v24.4s, v14.4s\n" + "fadd v25.4s, v25.4s, v15.4s\n" + "fadd v26.4s, v26.4s, v14.4s\n" + "fadd v27.4s, v27.4s, v15.4s\n" + "fadd v28.4s, v28.4s, v14.4s\n" + "fadd v29.4s, v29.4s, v15.4s\n" + "fadd v30.4s, v30.4s, v14.4s\n" + "fadd v31.4s, v31.4s, v15.4s\n" + + // Load the clamp_min, clamp_max bounds + "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.4s, w2\n" // clamp_min + "dup v15.4s, w3\n" // clamp_max + + // Apply the clamp_min bound + "fmax v16.4s, v16.4s, v14.4s\n" + "fmax v17.4s, v17.4s, v14.4s\n" + "fmax v18.4s, v18.4s, v14.4s\n" + "fmax v19.4s, v19.4s, v14.4s\n" + "fmax v20.4s, v20.4s, v14.4s\n" + "fmax v21.4s, v21.4s, v14.4s\n" + "fmax v22.4s, v22.4s, v14.4s\n" + "fmax v23.4s, v23.4s, v14.4s\n" + "fmax v24.4s, v24.4s, v14.4s\n" + "fmax v25.4s, v25.4s, v14.4s\n" + "fmax v26.4s, v26.4s, v14.4s\n" + "fmax v27.4s, v27.4s, v14.4s\n" + "fmax v28.4s, v28.4s, v14.4s\n" + "fmax v29.4s, v29.4s, v14.4s\n" + "fmax v30.4s, v30.4s, v14.4s\n" + "fmax v31.4s, v31.4s, v14.4s\n" + + // Apply the clamp_max bound + "fmin v16.4s, v16.4s, v15.4s\n" + "fmin v17.4s, v17.4s, v15.4s\n" + "fmin v18.4s, v18.4s, v15.4s\n" + "fmin v19.4s, v19.4s, v15.4s\n" + "fmin v20.4s, v20.4s, v15.4s\n" + "fmin v21.4s, v21.4s, v15.4s\n" + "fmin v22.4s, v22.4s, v15.4s\n" + "fmin v23.4s, v23.4s, v15.4s\n" + "fmin v24.4s, v24.4s, v15.4s\n" + "fmin v25.4s, v25.4s, v15.4s\n" + "fmin v26.4s, v26.4s, v15.4s\n" + "fmin v27.4s, v27.4s, v15.4s\n" + "fmin v28.4s, v28.4s, v15.4s\n" + "fmin v29.4s, v29.4s, v15.4s\n" + "fmin v30.4s, v30.4s, v15.4s\n" + "fmin v31.4s, v31.4s, v15.4s\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #32\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). 
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "str q16, [x3, #0]\n" + "str q17, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + "str q18, [x3, #0]\n" + "str q19, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + "str q20, [x3, #0]\n" + "str q21, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + "str q22, [x3, #0]\n" + "str q23, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + "str q24, [x3, #0]\n" + "str q25, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + "str q26, [x3, #0]\n" + "str q27, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + "str q28, [x3, #0]\n" + "str q29, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + "str q30, [x3, #0]\n" + "str q31, [x3, #16]\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that remain to load + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently depth - 1. 
+ "sub w1, w12, #1\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} + +// Variant of KernelFloatNeonInOrder tuned for in-order CPUs that do +// support dotprod (while dotprod by itself is not relevant to floating-point, +// this additional bit of information that we have about the target happens to +// be useful here). +// +// So a typical target CPU here would be ARM Cortex-A55r1. +// +// This kernel is similar to and inspired by gemmlowp's +// NEON_64bit_GEMM_Float32_WithScalar_A55r1. +// which was contributed by David Mansell with very helpful +// comments. Specifically, see this comment about tuning for Cortex-A55r1: +// https://github.com/google/gemmlowp/blob/36212ad3651871bc3e9a599f1a6d5324778aea25/standalone/neon-gemm-kernel-benchmark.cc#L4412 +void KernelFloatNeonDotprodInOrder(const KernelParamsFloat<8, 8>& params) { + profiler::ScopeLabel label( + "Kernel (kNeonDotprod, optimized for in-order cores)"); + + CheckOffsetsInKernelParamsFloat(params); + + const float* lhs_col_ptr = params.lhs_base_ptr; + const float* rhs_col_ptr = params.rhs_base_ptr; + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + float* dst_col_ptr = params.dst_base_ptr; + float* dst_ptr = dst_col_ptr; + int row = params.start_row; + int col = params.start_col; + + // The asm kernel below has the following NEON register allocation: + // + // v16 -- v31 are accumulators. + // During accumulation, v0 -- v3 are used to load data from LHS and RHS. + // + // RHS 1x8 block + // /-----------------------------------------\ + // |v2.s[0] ... v2.s[3] v3.s[0] ... v3.s[3]| + // \-----------------------------------------/ + // LHS 8x1 block + // /---------------------\ /-----------------------------------------\ + // | v0.s[0] | |v16.s[0] ... v30.s[0]| + // | ... | | ... ... | + // | v0.s[3] | |v16.s[3] ... v30.s[3]| + // | v1.s[0] | |v17.s[0] ... v31.s[0]| + // | ... | | ... ... | + // | v1.s[3] | |v17.s[3] ... v31.s[3]| + // \---------------------/ \-----------------------------------------/ + // accumulators 8x8 block + // + // There is no RUY_OPT_MAX_STREAMING 4x-unrolled part in this kernel because + // we did not observe a benefit of such partial unrolling on in-order CPUs. + // + // v4 -- v7 are unused, and v8 -- v15 are used for floading parameters used + // for the post-accumulation part of the kernel. + asm volatile( +#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n" + + // clang-format off + + // Load some parameters into registers. 
+ "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + "ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n" + "ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n" + "ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n" + "ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n" + "ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n" + + + // Clear accumulators. + RUY_MAKE_ZERO(v16) + // Load the first 32 bytes of LHS and RHS data. + "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v17) + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + RUY_MAKE_ZERO(v18) + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v19) + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + RUY_MAKE_ZERO(v20) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #64]\n") + RUY_MAKE_ZERO(v21) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #64]\n") + RUY_MAKE_ZERO(v22) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #128]\n") + RUY_MAKE_ZERO(v23) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #128]\n") + RUY_MAKE_ZERO(v24) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #192]\n") + RUY_MAKE_ZERO(v25) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #192]\n") + RUY_MAKE_ZERO(v26) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") + RUY_MAKE_ZERO(v27) + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // w1 is the number of levels of depth that remain to load + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently depth - 1. + "sub w1, w12, #1\n" + + // Main loop of the whole GEMM, over rows and columns of the + // destination matrix. + "1:\n" + + "cmp w1, #0\n" + "fmla v16.4s, v0.4s, v2.s[0]\n" + "fmla v18.4s, v0.4s, v2.s[1]\n" + "fmla v20.4s, v0.4s, v2.s[2]\n" + "fmla v22.4s, v0.4s, v2.s[3]\n" + + // Accumulation loop + "beq 79f\n" + + "2:\n" + + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[lhs_ptr], #256]\n") + "fmla v24.4s, v0.4s, v3.s[0]\n" + "ldr x2, [%[lhs_ptr], #8]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "ldr x3, [%[lhs_ptr], #24]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "ldr x5, [%[rhs_ptr], #24]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "ldr d0, [%[lhs_ptr]], #32\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "ldr x4, [%[rhs_ptr], #8]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "subs w1, w1, #1\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "ins v0.d[1], x2\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ldr d3, [%[rhs_ptr], #16]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "ins v3.d[1], x5\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "ldr d4, [%[rhs_ptr]], #32\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "ins v4.d[1], x4\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + RUY_PREFETCH_LOAD("prfm pldl1keep, [%[rhs_ptr], #256]\n") + "fmla v16.4s, v0.4s, v4.s[0]\n" + "ldr d1, [%[lhs_ptr], #-16]\n" + "fmla v18.4s, v0.4s, v4.s[1]\n" + "ins v1.d[1], x3\n" + "fmla v20.4s, v0.4s, v4.s[2]\n" + "mov v2.16b, v4.16b\n" + "fmla v22.4s, v0.4s, v4.s[3]\n" + "bne 2b\n" + + "79:\n" + + // End of the inner loop on depth. Now perform the remaining + // multiply-adds of the last level of depth, for which the LHS + // and RHS data is already loaded. 
+ + "fmla v24.4s, v0.4s, v3.s[0]\n" + "fmla v26.4s, v0.4s, v3.s[1]\n" + "fmla v28.4s, v0.4s, v3.s[2]\n" + "fmla v30.4s, v0.4s, v3.s[3]\n" + "fmla v25.4s, v1.4s, v3.s[0]\n" + "fmla v27.4s, v1.4s, v3.s[1]\n" + "fmla v29.4s, v1.4s, v3.s[2]\n" + "fmla v31.4s, v1.4s, v3.s[3]\n" + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "fmla v17.4s, v1.4s, v2.s[0]\n" + "fmla v19.4s, v1.4s, v2.s[1]\n" + "fmla v21.4s, v1.4s, v2.s[2]\n" + "fmla v23.4s, v1.4s, v2.s[3]\n" + + // End of accumulation. The registers v16 -- v31 contain the final + // int32 accumulator values of the current 8x8 destination block. + // We now have to compute the final 8-bit values from these int32 + // accumulators, and advance to the next 8x8 block. We intertwine + // these two aspects whenever possible for optimal pipelining, both + // at the data flow level (prefetch data for next block as early as + // possible) and instruction pipelining level (some of the next-block + // work can dual-issue with some of the final work on the current + // block). + + // Logic to advance to the next block in preparation for the next + // iteration of the main loop. For now, we only want to compute + // the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are + // not yet ready to update the values of row and col, as we still need + // the current values for the rest of the work on the current block. + + "cmp %w[row], w7\n" // Have we finished the last row? + "bge 4f\n" // If finished last row, go to 4 + // Not finished last row: then advance to next row. + "add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n" + "b 5f\n" + "4:\n" // Finished last row... + "mov %[lhs_col_ptr], x5\n" // Go back to first row + // Now we need to advance to the next column. If we already + // finished the last column, then in principle we are done, however + // we can't just return here, as we need to allow the end work of the + // current block to complete. The good news is that at this point it + // doesn't matter what data we load for the next column, since + // we will exit from the main loop below before actually storing + // anything computed from that data. + "cmp %w[col], w8\n" // Have we finished the last column? + "bge 5f\n" // If yes, just carry on without updating the column pointer. + // Not finished last column: then advance to next column. + "add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n" + "5:\n" + + // Set the LHS and RHS data pointers to the start of the columns just + // computed. + "mov %[lhs_ptr], %[lhs_col_ptr]\n" + "mov %[rhs_ptr], %[rhs_col_ptr]\n" + + // Load some parameters needed for the end work on current block. + "ldrb w4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n" + "ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n" + + // Offset these base pointers as needed given the current row, col. + "add x5, x1, %x[row], lsl #2\n" + + "tst w4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n" + "csel x1, x1, x5, eq\n" + + // Load 8 bias values. + "ld1 {v14.4s}, [x1], #16\n" + "ld1 {v15.4s}, [x1]\n" + + // Now that we know what LHS and RHS data the next iteration of the + // main loop will need to load, we start loading the first 32 bytes of + // each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore + // in the rest of the work on the current block. 
+ "ld1 {v0.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v1.4s}, [%[lhs_ptr]], #16\n" + "ld1 {v2.4s}, [%[rhs_ptr]], #16\n" + "ld1 {v3.4s}, [%[rhs_ptr]], #16\n" + + // Perform the bias-addition (per the above, we have just folded into + // the bias the (depth * lhs_zero_point * rhs_zero_point) term.) + "fadd v16.4s, v16.4s, v14.4s\n" + "fadd v17.4s, v17.4s, v15.4s\n" + "fadd v18.4s, v18.4s, v14.4s\n" + "fadd v19.4s, v19.4s, v15.4s\n" + "fadd v20.4s, v20.4s, v14.4s\n" + "fadd v21.4s, v21.4s, v15.4s\n" + "fadd v22.4s, v22.4s, v14.4s\n" + "fadd v23.4s, v23.4s, v15.4s\n" + "fadd v24.4s, v24.4s, v14.4s\n" + "fadd v25.4s, v25.4s, v15.4s\n" + "fadd v26.4s, v26.4s, v14.4s\n" + "fadd v27.4s, v27.4s, v15.4s\n" + "fadd v28.4s, v28.4s, v14.4s\n" + "fadd v29.4s, v29.4s, v15.4s\n" + "fadd v30.4s, v30.4s, v14.4s\n" + "fadd v31.4s, v31.4s, v15.4s\n" + + // Load the clamp_min, clamp_max bounds + "ldr w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n" + "ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n" + "dup v14.4s, w2\n" // clamp_min + "dup v15.4s, w3\n" // clamp_max + + // Apply the clamp_min bound + "fmax v16.4s, v16.4s, v14.4s\n" + "fmax v17.4s, v17.4s, v14.4s\n" + "fmax v18.4s, v18.4s, v14.4s\n" + "fmax v19.4s, v19.4s, v14.4s\n" + "fmax v20.4s, v20.4s, v14.4s\n" + "fmax v21.4s, v21.4s, v14.4s\n" + "fmax v22.4s, v22.4s, v14.4s\n" + "fmax v23.4s, v23.4s, v14.4s\n" + "fmax v24.4s, v24.4s, v14.4s\n" + "fmax v25.4s, v25.4s, v14.4s\n" + "fmax v26.4s, v26.4s, v14.4s\n" + "fmax v27.4s, v27.4s, v14.4s\n" + "fmax v28.4s, v28.4s, v14.4s\n" + "fmax v29.4s, v29.4s, v14.4s\n" + "fmax v30.4s, v30.4s, v14.4s\n" + "fmax v31.4s, v31.4s, v14.4s\n" + + // Apply the clamp_max bound + "fmin v16.4s, v16.4s, v15.4s\n" + "fmin v17.4s, v17.4s, v15.4s\n" + "fmin v18.4s, v18.4s, v15.4s\n" + "fmin v19.4s, v19.4s, v15.4s\n" + "fmin v20.4s, v20.4s, v15.4s\n" + "fmin v21.4s, v21.4s, v15.4s\n" + "fmin v22.4s, v22.4s, v15.4s\n" + "fmin v23.4s, v23.4s, v15.4s\n" + "fmin v24.4s, v24.4s, v15.4s\n" + "fmin v25.4s, v25.4s, v15.4s\n" + "fmin v26.4s, v26.4s, v15.4s\n" + "fmin v27.4s, v27.4s, v15.4s\n" + "fmin v28.4s, v28.4s, v15.4s\n" + "fmin v29.4s, v29.4s, v15.4s\n" + "fmin v30.4s, v30.4s, v15.4s\n" + "fmin v31.4s, v31.4s, v15.4s\n" + + // Compute how much of the 8x8 block of destination 8bit values that + // we have computed, fit in the destination matrix. Typically, all of + // it fits, but when the destination matrix shape is not a multiple + // of 8x8, there are some 8x8 blocks along the boundaries that do + // not fit entirely. + "sub w1, %w[dst_rows], %w[row]\n" + "sub w2, %w[dst_cols], %w[col]\n" + "mov w3, #8\n" + "cmp w1, #8\n" + // Compute w1 = how many rows of the 8x8 block fit + "csel w1, w1, w3, le\n" + "cmp w2, #8\n" + // Compute w2 = how many cols of the 8x8 block fit + "csel w2, w2, w3, le\n" + + // Test if w1==8 && w2 == 8, i.e. if all of the 8x8 block fits. + "cmp w1, w3\n" + "ccmp w2, w3, 0, eq\n" + // Yes, all of the 8x8 block fits, go to fast path. + "beq 30f\n" + // Not all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write to dst_tmp_buf + "mov x3, %[dst_tmp_buf]\n" + "mov x4, #32\n" + "b 31f\n" + "30:\n" + // Yes, all of the 8x8 block fits. + // Set (x3 address, x4 stride) to write directly to destination matrix. + "mov x3, %[dst_ptr]\n" + "mov x4, x11\n" + "31:\n" + + // Write our 8bit values to the destination described by + // (x3 address, x4 stride). 
+ RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + "str q16, [x3, #0]\n" + "str q17, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v16) + RUY_MAKE_ZERO(v17) + "str q18, [x3, #0]\n" + "str q19, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v18) + RUY_MAKE_ZERO(v19) + "str q20, [x3, #0]\n" + "str q21, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v20) + RUY_MAKE_ZERO(v21) + "str q22, [x3, #0]\n" + "str q23, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v22) + RUY_MAKE_ZERO(v23) + "str q24, [x3, #0]\n" + "str q25, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v24) + RUY_MAKE_ZERO(v25) + "str q26, [x3, #0]\n" + "str q27, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v26) + RUY_MAKE_ZERO(v27) + "str q28, [x3, #0]\n" + "str q29, [x3, #16]\n" + "add x3, x3, x4\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x3]\n") + RUY_MAKE_ZERO(v28) + RUY_MAKE_ZERO(v29) + "str q30, [x3, #0]\n" + "str q31, [x3, #16]\n" + RUY_MAKE_ZERO(v30) + RUY_MAKE_ZERO(v31) + + // If all of the 8x8 block fits, we just finished writing it to the + // destination, so we skip the next part. + "beq 41f\n" + // Not all of the 8x8 block fits in the destination matrix. We just + // wrote it to dst_tmp_buf. Now we perform the slow scalar loop over + // it to copy into the destination matrix the part that fits. + "mov x3, %[dst_tmp_buf]\n" + "mov x4, %[dst_ptr]\n" + "mov w6, #0\n" + "50:\n" + RUY_PREFETCH_STORE("prfm pstl1strm, [x4]\n") + "mov w5, #0\n" + "51:\n" + "ldr w7, [x3, x5, lsl #2]\n" + "str w7, [x4, x5, lsl #2]\n" + "add w5, w5, #1\n" + "cmp w5, w1\n" + "blt 51b\n" + "add w6, w6, #1\n" + "add x3, x3, #32\n" + "add x4, x4, x11\n" + "cmp w6, w2\n" + "blt 50b\n" + "41:\n" + "add %[dst_ptr], %[dst_ptr], #32\n" + // At this point we have completely finished writing values to the + // destination matrix for the current block. + + // Reload some params --- we had used x5 -- x7 for a few other things + // since the last time we had loaded them. + "ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n" + "ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n" + "ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n" + + // Move to the next block of the destination matrix, for the next iter + // of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already + // been updated earlier. + // Have we reached the end row? + "cmp %w[row], w7\n" + "beq 20f\n" // yes, end row. + // Not end row. Move to the next row. + "add %w[row], %w[row], #8\n" + "b 21f\n" + "20:\n" + // Was already at end row. + "mov %w[row], w6\n" // Move back to first row. + "add %w[col], %w[col], #8\n" // Move to the next column. + "add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n" + "mov %[dst_ptr], %[dst_col_ptr]\n" + "21:\n" + + // Main loop exit condition: have we hit the end column? + "cmp %w[col], w8\n" + + // w1 is the number of levels of depth that remain to load + // LHS and RHS data for. Corresponding to the initial ld1 instructions + // above, this is currently depth - 1. 
+ "sub w1, w12, #1\n" + + "ble 1b\n" + + // clang-format on + + : [ lhs_col_ptr ] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr), + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col) + : [ params ] "r"(¶ms), [dst_rows] "r"(params.dst_rows), + [dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf) + : "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", + "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31"); +} +#undef RUY_OFFSET_BIAS +#undef RUY_OFFSET_FLAGS +#undef RUY_OFFSET_LHS_BASE_PTR +#undef RUY_OFFSET_CLAMP_MIN +#undef RUY_OFFSET_CLAMP_MAX +#undef RUY_OFFSET_START_ROW +#undef RUY_OFFSET_LAST_ROW +#undef RUY_OFFSET_LAST_COL +#undef RUY_OFFSET_LHS_STRIDE +#undef RUY_OFFSET_RHS_STRIDE +#undef RUY_OFFSET_DST_STRIDE +#undef RUY_OFFSET_DEPTH +#undef RUY_OFFSET_START_COL +#undef RUY_OFFSET_RHS_BASE_PTR +#undef RUY_OFFSET_DST_BASE_PTR + +#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy diff --git a/ruy/kernel_avx2.cc b/ruy/kernel_avx2.cc new file mode 100644 index 0000000..13fe22b --- /dev/null +++ b/ruy/kernel_avx2.cc @@ -0,0 +1,1664 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/kernel.h" +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) +#include // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) { + // CPU-ID-based checks should disable the path that would reach this point. 
+ RUY_DCHECK(false); +} + +#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +static constexpr int kAvx8bitBlockSize = 8; +static constexpr int kAvx8bitInnerSize = 4; + +namespace { +namespace intrin_utils { + +inline __m256 mm256_n_loadu_epi32(int n, const std::int32_t* src) { + switch (n) { + case 0: + return _mm256_setzero_si256(); + case 1: + return _mm256_setr_m128(_mm_setr_epi32(src[0], 0, 0, 0), + _mm_setzero_si128()); + case 2: + return _mm256_setr_m128(_mm_setr_epi32(src[0], src[1], 0, 0), + _mm_setzero_si128()); + case 3: + return _mm256_setr_m128(_mm_setr_epi32(src[0], src[1], src[2], 0), + _mm_setzero_si128()); + case 4: + return _mm256_castsi128_si256( + _mm_loadu_si128(reinterpret_cast<__m128i const*>(src))); + case 5: + return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], 0, 0, 0); + case 6: + return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], src[5], + 0, 0); + case 7: + return _mm256_setr_epi32(src[0], src[1], src[2], src[3], src[4], src[5], + src[6], 0); + case 8: + return _mm256_loadu_si256(reinterpret_cast<__m256i const*>(src)); + default: + RUY_DCHECK_LT(n, 9); + return _mm256_setzero_si256(); + } +} + +inline void mm256_n_storeu_cvtepi32_epi8(std::uint8_t* dst, int residual_rows, + const __m256 v) { + // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. + const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); + __m256i shuffled_v; + if (residual_rows > 1) { + // This selects 0, 4, 8, 12, 0, 4, 8, 12, ..., but we only use the first 4 + // in each 128-bit lane. + shuffled_v = _mm256_shuffle_epi8(v, repack_perm); + } + switch (residual_rows) { + case 0: + break; + case 1: + dst[0] = _mm256_extract_epi8(v, 0); + break; + case 2: + _mm_storeu_si16(dst, _mm256_extracti128_si256(shuffled_v, 0)); + break; + case 3: { + __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 0); + _mm_storeu_si16(dst, trailing_packed); + dst[2] = _mm_extract_epi8(trailing_packed, 2); + break; + } + case 4: + _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); + break; + case 5: + _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); + dst[4] = _mm256_extract_epi8(shuffled_v, 16); + break; + case 6: + _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); + _mm_storeu_si16(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); + break; + case 7: { + _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); + __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 1); + _mm_storeu_si16(dst + 4, trailing_packed); + dst[6] = _mm_extract_epi8(trailing_packed, 2); + break; + } + case 8: + _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); + _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); + break; + default: + RUY_DCHECK_LE(residual_rows, 8); + break; + } +} + +inline void mm256_storeu_cvtepi32_epi8(std::uint8_t* dst, const __m256 v) { + // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. 
+ const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); + const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); + _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); + _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); +} + +inline void mm256_n_storeu_cvtepi32_epi8(std::int8_t* dst, int residual_rows, + const __m256 v) { + intrin_utils::mm256_n_storeu_cvtepi32_epi8( + reinterpret_cast(dst), residual_rows, v); +} + +inline void mm256_storeu_cvtepi32_epi8(std::int8_t* dst, const __m256 v) { + // Select bytes 0, 4, 8, 12 within each lane, effectively truncating. + const __m256i repack_perm = _mm256_set1_epi32(0x0c080400); + const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); + _mm_storeu_si32(dst, _mm256_extracti128_si256(shuffled_v, 0)); + _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); +} + +inline void mm256_n_storeu_cvtepi32_epi16(std::int16_t* dst, int residual_rows, + const __m256 v) { + // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively + // truncating each 16-bit integer. + const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); + __m256i shuffled_v; + __m128i shuffled_v_low; + if (residual_rows > 1) { + shuffled_v = _mm256_shuffle_epi8(v, repack_perm); + shuffled_v_low = _mm256_extracti128_si256(shuffled_v, 0); + } else { + shuffled_v_low = _mm256_extracti128_si256(v, 0); + } + switch (residual_rows) { + case 0: + break; + case 1: + _mm_storeu_si16(dst, shuffled_v_low); + break; + case 2: + _mm_storeu_si32(dst, shuffled_v_low); + break; + case 3: { + _mm_storeu_si32(dst, shuffled_v_low); + dst[2] = _mm_extract_epi16(shuffled_v_low, 2); + break; + } + case 4: + _mm_storeu_si64(dst, shuffled_v_low); + break; + case 5: + _mm_storeu_si64(dst, shuffled_v_low); + dst[4] = _mm256_extract_epi16(shuffled_v, 8); + break; + case 6: + _mm_storeu_si64(dst, shuffled_v_low); + _mm_storeu_si32(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); + break; + case 7: { + _mm_storeu_si64(dst, shuffled_v_low); + __m128i trailing_packed = _mm256_extracti128_si256(shuffled_v, 1); + _mm_storeu_si32(dst + 4, trailing_packed); + dst[6] = _mm_extract_epi16(trailing_packed, 2); + break; + } + case 8: + _mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0)); + _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); + break; + default: + RUY_DCHECK_LE(residual_rows, 8); + break; + } +} + +inline void mm256_storeu_cvtepi32_epi16(std::int16_t* dst, const __m256 v) { + // Select bytes 0, 1, 4, 5, 8, 9, 12, 13 within each lane, effectively + // truncating each 16-bit integer. 
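The intrin_utils helpers above handle partially filled 8-lane vectors at matrix boundaries: the n-ary load reads only the first n int32 values and zeroes the rest, and the cvtepi32_epi8 stores keep byte 0 of each int32 lane (the epi16 variants keep the low 16 bits). A scalar model of the two patterns, with hypothetical names:

#include <cstdint>
#include <cstring>

// Load the first n of 8 int32 lanes, zeroing the remainder.
void NLoad32(int n, const std::int32_t* src, std::int32_t lanes[8]) {
  std::memset(lanes, 0, 8 * sizeof(std::int32_t));
  std::memcpy(lanes, src, n * sizeof(std::int32_t));
}

// Store the first n lanes truncated to their low byte (int32 -> int8).
void NStoreTruncate32To8(std::int8_t* dst, int n, const std::int32_t lanes[8]) {
  for (int i = 0; i < n; ++i) {
    dst[i] = static_cast<std::int8_t>(lanes[i]);
  }
}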
+ const __m256i repack_perm = _mm256_set1_epi64x(0x0d0c090805040100); + const __m256i shuffled_v = _mm256_shuffle_epi8(v, repack_perm); + _mm_storeu_si64(dst, _mm256_extracti128_si256(shuffled_v, 0)); + _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(shuffled_v, 1)); +} + +inline void mm256_n_storeu_epi32(std::int32_t* dst, int residual_rows, + const __m256 v) { + const __m128i v_low = _mm256_extracti128_si256(v, 0); + switch (residual_rows) { + case 0: + break; + case 1: + _mm_storeu_si32(dst, v_low); + break; + case 2: + _mm_storeu_si64(dst, v_low); + break; + case 3: { + __m128i trailing_packed = v_low; + _mm_storeu_si64(dst, trailing_packed); + dst[2] = _mm_extract_epi32(trailing_packed, 2); + break; + } + case 4: + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + break; + case 5: + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + dst[4] = _mm256_extract_epi32(v, 4); + break; + case 6: + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + _mm_storeu_si64(dst + 4, _mm256_extracti128_si256(v, 1)); + break; + case 7: { + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), v_low); + __m128i trailing_packed = _mm256_extracti128_si256(v, 1); + _mm_storeu_si64(dst + 4, trailing_packed); + dst[6] = _mm_extract_epi32(trailing_packed, 2); + break; + } + case 8: + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); + break; + default: + RUY_DCHECK_LE(residual_rows, 8); + break; + } +} + +inline void mm256_storeu_epi32(std::int32_t* dst, const __m256 v) { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), v); +} + +inline float mm256_get1_ps(const __m256 a, int i) { + __m256i ai = _mm256_castps_si256(a); + int float_val_as_int; + switch (i) { + case 0: + float_val_as_int = _mm256_extract_epi32(ai, 0); + break; + case 1: + float_val_as_int = _mm256_extract_epi32(ai, 1); + break; + case 2: + float_val_as_int = _mm256_extract_epi32(ai, 2); + break; + case 3: + float_val_as_int = _mm256_extract_epi32(ai, 3); + break; + case 4: + float_val_as_int = _mm256_extract_epi32(ai, 4); + break; + case 5: + float_val_as_int = _mm256_extract_epi32(ai, 5); + break; + case 6: + float_val_as_int = _mm256_extract_epi32(ai, 6); + break; + case 7: + float_val_as_int = _mm256_extract_epi32(ai, 7); + break; + default: + RUY_DCHECK_LT(i, 8); + return .0f; + } + return reinterpret_cast(float_val_as_int); +} + +inline __m256 mm256_n_loadu_ps(int i, const float* src) { + switch (i) { + case 0: + return _mm256_setzero_ps(); + case 1: + return _mm256_setr_m128(_mm_setr_ps(src[0], .0f, .0f, .0f), + _mm_setzero_ps()); + case 2: + return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], .0f, .0f), + _mm_setzero_ps()); + case 3: + return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], .0f), + _mm_setzero_ps()); + case 4: + return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], src[3]), + _mm_setzero_ps()); + case 5: + return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], .0f, .0f, + .0f); + case 6: + return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5], .0f, + .0f); + case 7: + return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5], + src[6], .0f); + case 8: + return _mm256_loadu_ps(src); + default: + RUY_DCHECK_LT(i, 9); + return _mm256_setzero_ps(); + } +} + +inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) { + for (int i = 0; i < residual_rows; ++i) { + dst[i] = intrin_utils::mm256_get1_ps(v, i); + } +} +} // namespace intrin_utils +} // namespace + +void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { + 
profiler::ScopeLabel label("Kernel kAvx2 8-bit"); + const std::int8_t splitter_idx_data[32] = { + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15, // + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15 // + }; + + std::int32_t dst_stride; + if ((params.dst_type_id == DstTypeId::kValue) || + (params.dst_type_id == DstTypeId::kValue)) { + dst_stride = params.dst_stride; + } else if (params.dst_type_id == DstTypeId::kValue) { + dst_stride = params.dst_stride / sizeof(std::int16_t); + } else if (params.dst_type_id == DstTypeId::kValue) { + dst_stride = params.dst_stride / sizeof(std::int32_t); + } else { + RUY_DCHECK(false); + } + + int bias_ptr_block_increment = + params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; + + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + void* dst_col_ptr = params.dst_base_ptr; + const std::int32_t* bias_col_ptr = params.bias; + if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { + bias_col_ptr += params.start_row; + } + + for (int col = params.start_col; col <= params.last_col; + col += kAvx8bitBlockSize) { + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + void* dst_ptr = dst_col_ptr; + const std::int32_t* bias_ptr = bias_col_ptr; + + const std::int32_t lhs_zero_point = params.lhs_zero_point; + const bool has_rhs_sums_offsets = + (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; + std::int32_t rhs_sums_offsets[8]; + if (has_rhs_sums_offsets) { + const __m256i rhs_sums_offset_v = _mm256_mullo_epi32( + _mm256_set1_epi32(lhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.rhs_sums[col]))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), + rhs_sums_offset_v); + } + + for (int row = params.start_row; row <= params.last_row; + row += kAvx8bitBlockSize) { + const int residual_rows = + std::min(params.dst_rows - row, kAvx8bitBlockSize); + const int residual_cols = + std::min(params.dst_cols - col, kAvx8bitBlockSize); + + const __m256i splitter_idx = _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(splitter_idx_data)); + + __m256i accum_data_v0; + __m256i accum_data_v1; + __m256i accum_data_v2; + __m256i accum_data_v3; + __m256i accum_data_v4; + __m256i accum_data_v5; + __m256i accum_data_v6; + __m256i accum_data_v7; + + // Initialize with bias. + __m256i initial_accum_data = + intrin_utils::mm256_n_loadu_epi32(residual_rows, bias_ptr); + bias_ptr += bias_ptr_block_increment; + + // Adjustments common across columns. + const std::int32_t rhs_zero_point = params.rhs_zero_point; + if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { + const __m256i lhs_sums_offset = _mm256_mullo_epi32( + _mm256_set1_epi32(rhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); + initial_accum_data = + _mm256_sub_epi32(initial_accum_data, lhs_sums_offset); + } + const std::int32_t prod_zp_depth = params.prod_zp_depth; + if (prod_zp_depth) { + initial_accum_data = _mm256_add_epi32(initial_accum_data, + _mm256_set1_epi32(prod_zp_depth)); + } + + // Adjustments differing across columns. 
+ if (has_rhs_sums_offsets) { + accum_data_v0 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[0])); + accum_data_v1 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[1])); + accum_data_v2 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[2])); + accum_data_v3 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[3])); + accum_data_v4 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[4])); + accum_data_v5 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[5])); + accum_data_v6 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[6])); + accum_data_v7 = _mm256_sub_epi32( + initial_accum_data, _mm256_set1_epi32(rhs_sums_offsets[7])); + } else { + accum_data_v0 = initial_accum_data; + accum_data_v1 = initial_accum_data; + accum_data_v2 = initial_accum_data; + accum_data_v3 = initial_accum_data; + accum_data_v4 = initial_accum_data; + accum_data_v5 = initial_accum_data; + accum_data_v6 = initial_accum_data; + accum_data_v7 = initial_accum_data; + } + + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { + const __m256i lhs_data = + _mm256_load_si256(reinterpret_cast(lhs_ptr)); + const __m256i rhs_data_8bit = + _mm256_load_si256(reinterpret_cast(rhs_ptr)); + + // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. + std::int32_t rhs_data[16]; + const __m128i rhs_data_bottom_lane = + _mm256_castsi256_si128(rhs_data_8bit); + const __m128i rhs_data_top_lane = + _mm256_extracti128_si256(rhs_data_8bit, 1); + const __m256i rhs_16_bit_dup_low = + _mm256_cvtepi8_epi16(rhs_data_bottom_lane); + const __m256i rhs_16_bit_dup_high = + _mm256_cvtepi8_epi16(rhs_data_top_lane); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data), + rhs_16_bit_dup_low); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8), + rhs_16_bit_dup_high); + + // NOTE: There may be opportunities for permuting the data in the + // packing code instead of here. + const __m256i lhs_data_split = + _mm256_shuffle_epi8(lhs_data, splitter_idx); + const __m256i lhs_data_split_expand_bottom = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0)); + const __m256i lhs_data_split_expand_top = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1)); + + // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. + const __m256i lhs_16_bit_low = _mm256_permute2x128_si256( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); + // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. + const __m256i lhs_16_bit_high = _mm256_permute2x128_si256( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); + // Accumulate for column 0. + { + const std::int32_t low_rhs_value = rhs_data[0]; + const std::int32_t high_rhs_value = rhs_data[1]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v0 = _mm256_add_epi32( + accum_data_v0, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v0 = _mm256_add_epi32( + accum_data_v0, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + // Accumulate for column 1. 
+ { + const std::int32_t low_rhs_value = rhs_data[2]; + const std::int32_t high_rhs_value = rhs_data[3]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v1 = _mm256_add_epi32( + accum_data_v1, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v1 = _mm256_add_epi32( + accum_data_v1, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + // Accumulate for column 2. + { + const std::int32_t low_rhs_value = rhs_data[4]; + const std::int32_t high_rhs_value = rhs_data[5]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v2 = _mm256_add_epi32( + accum_data_v2, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v2 = _mm256_add_epi32( + accum_data_v2, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + // Accumulate for column 3. + { + const std::int32_t low_rhs_value = rhs_data[6]; + const std::int32_t high_rhs_value = rhs_data[7]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v3 = _mm256_add_epi32( + accum_data_v3, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v3 = _mm256_add_epi32( + accum_data_v3, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + // Accumulate for column 4. + { + const std::int32_t low_rhs_value = rhs_data[8]; + const std::int32_t high_rhs_value = rhs_data[9]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v4 = _mm256_add_epi32( + accum_data_v4, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v4 = _mm256_add_epi32( + accum_data_v4, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + // Accumulate for column 5. + { + const std::int32_t low_rhs_value = rhs_data[10]; + const std::int32_t high_rhs_value = rhs_data[11]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v5 = _mm256_add_epi32( + accum_data_v5, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v5 = _mm256_add_epi32( + accum_data_v5, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + // Accumulate for column 6. + { + const std::int32_t low_rhs_value = rhs_data[12]; + const std::int32_t high_rhs_value = rhs_data[13]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v6 = _mm256_add_epi32( + accum_data_v6, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v6 = _mm256_add_epi32( + accum_data_v6, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + // Accumulate for column 7. 
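Each of the per-column blocks above consumes a depth-4 slice with two _mm256_madd_epi16 steps: the shuffles arrange the LHS so that adjacent 16-bit values are consecutive depth levels of one row, and each broadcast int32 holds the matching pair of RHS values. A scalar model of one (row, column) update over such a slice, assuming the pairing implied by the shuffles (names are illustrative):

#include <cstdint>

// One madd step: two adjacent 16-bit products summed into an int32 lane.
std::int32_t MaddPair(std::int16_t lhs0, std::int16_t lhs1,
                      std::int16_t rhs0, std::int16_t rhs1) {
  return static_cast<std::int32_t>(lhs0) * rhs0 +
         static_cast<std::int32_t>(lhs1) * rhs1;
}

// Accumulating one entry over a depth block of 4, matching the "low" pair
// (depths 0, 1) and "high" pair (depths 2, 3) used above.
std::int32_t AccumulateDepth4(const std::int8_t lhs[4], const std::int8_t rhs[4],
                              std::int32_t accum) {
  accum += MaddPair(lhs[0], lhs[1], rhs[0], rhs[1]);
  accum += MaddPair(lhs[2], lhs[3], rhs[2], rhs[3]);
  return accum;
}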
+ { + const std::int32_t low_rhs_value = rhs_data[14]; + const std::int32_t high_rhs_value = rhs_data[15]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v7 = _mm256_add_epi32( + accum_data_v7, + _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v7 = _mm256_add_epi32( + accum_data_v7, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + } + + lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + } + + if (params.dst_type_id != DstTypeId::kValue) { + __m256i m_vector; + __m256i e_vector; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. + if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { + m_vector = intrin_utils::mm256_n_loadu_epi32( + residual_rows, ¶ms.multiplier_fixedpoint[row]); + e_vector = intrin_utils::mm256_n_loadu_epi32( + residual_rows, ¶ms.multiplier_exponent[row]); + } else { + // These arrays have size LhsCols, and are pre-filled. + m_vector = _mm256_set1_epi32(params.multiplier_fixedpoint[0]); + e_vector = _mm256_set1_epi32(params.multiplier_exponent[0]); + } + + const __m256i m_64bit_low = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0)); + const __m256i m_64bit_high = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1)); + + const __m256i zero_vector = _mm256_setzero_si256(); + const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector); + const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector); + const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector); + const __m256i final_right_shift = + _mm256_add_epi32(right_shift, _mm256_set1_epi32(31)); + const __m256i final_right_shift_low = _mm256_cvtepi32_epi64( + _mm256_extracti128_si256(final_right_shift, 0)); + const __m256i final_right_shift_high = _mm256_cvtepi32_epi64( + _mm256_extracti128_si256(final_right_shift, 1)); + // Really we want 0x100000000, but use half to avoid overflowing. + const __m256i convert_to_signed_halved = + _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift); + const __m256i convert_to_unsigned_64 = + _mm256_set1_epi64x(0x8000000000000000); + + __m256i post_scaling_offset = _mm256_add_epi32( + convert_to_signed_halved, convert_to_signed_halved); + + const __m256i offset_vector = + _mm256_slli_epi64(_mm256_set1_epi64x(1), 30); + // Really these should be shifted by neg_e_vector, but tests pass when + // using right_shift. + const __m256i offset_vector_low = _mm256_add_epi64( + _mm256_sllv_epi64(offset_vector, + _mm256_cvtepi32_epi64( + _mm256_extracti128_si256(right_shift, 0))), + convert_to_unsigned_64); + const __m256i offset_vector_high = _mm256_add_epi64( + _mm256_sllv_epi64(offset_vector, + _mm256_cvtepi32_epi64( + _mm256_extracti128_si256(right_shift, 1))), + convert_to_unsigned_64); + + if (params.dst_zero_point) { + const __m256i dst_zero_point = + _mm256_set1_epi32(params.dst_zero_point); + // The post-scaling offset is subtracted later, so this has the effect + // of adding the zero point. 
+ post_scaling_offset = + _mm256_sub_epi32(post_scaling_offset, dst_zero_point); + } + +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + RUY_DCHECK(false); +#endif + const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7); + + // We cannot do + // + // scaled_v_low = + // _mm256_srav_epi64(scaled_v_low, final_right_shift_low); + // scaled_v_high = + // _mm256_srav_epi64(scaled_v_high, final_right_shift_high); + // + // since this instruction is not in AVX2. Instead we use + // _mm256_srlv_epi64, but this is an unsigned shift, so we applied + // offsets before (convert_to_unsigned_64) and after + // (convert_to_signed_halved). + // + // The overall process is, for 64-bit scaled accumulator: + // unsigned_accum = signed_accum + 1 << 63; + // unsigned_accum = (unsigned_accum >> right_shift) >> 31; + // signed_accum = unsigned_accum - ((1 << 32) >> right_shift) / 2 * 2; + + // There are various ways to repack the results, in the absence of + // _mm256_cvtepi64_epi32() or anything like it. + // A. + // accum_data_v[j] = + // _mm256_set_epi32(_mm256_extract_epi32(scaled_v_high, 6), + // _mm256_extract_epi32(scaled_v_high, 4), + // _mm256_extract_epi32(scaled_v_high, 2), + // _mm256_extract_epi32(scaled_v_high, 0), + // _mm256_extract_epi32(scaled_v_low, 6), + // _mm256_extract_epi32(scaled_v_low, 4), + // _mm256_extract_epi32(scaled_v_low, 2), + // _mm256_extract_epi32(scaled_v_low, 0)); + // B. + // scaled_v_low = _mm256_shuffle_epi32(scaled_v_low, 0xd8); + // scaled_v_high = _mm256_shuffle_epi32(scaled_v_high, 0xd8); + // accum_data_v[j] = + // _mm256_set_epi64x(_mm256_extract_epi64(scaled_v_high, 2), + // _mm256_extract_epi64(scaled_v_high, 0), + // _mm256_extract_epi64(scaled_v_low, 2), + // _mm256_extract_epi64(scaled_v_low, 0)); + // C. + // scaled_v_low = + // _mm256_permutevar8x32_epi32(scaled_v_low, repack_perm); + // scaled_v_high = + // _mm256_permutevar8x32_epi32(scaled_v_high, repack_perm); + // accum_data_v[j] = + // _mm256_permute2x128_si256(scaled_v_low, scaled_v_high, 0x20); + // + // However, we choose the following because it uses two lighter + // instructions. The permutation does have a longer latency, but this + // loop can be unrolled. + // D. + // scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + // __m256i results = + // _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + // results = _mm256_permutevar8x32_epi32(results, repack_perm); + // accum_data_v[j] = _mm256_sub_epi32(results, post_scaling_offset); + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift); + // Apply the fixed-point part of the multiplier. 
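As the comment block above explains, AVX2 has no 64-bit arithmetic variable shift, so the code biases each 64-bit value into unsigned range, shifts logically, and removes the shifted bias afterwards (the doubled convert_to_signed_halved, i.e. post_scaling_offset). A scalar model of just that trick, ignoring the rounding offset that the vector code folds into offset_vector:

#include <cstdint>

std::int32_t ArithmeticShiftViaLogicalShift(std::int64_t scaled, int right_shift) {
  const int total_shift = 31 + right_shift;
  // Bias by 2^63 so the value is non-negative when viewed as unsigned.
  const std::uint64_t biased = static_cast<std::uint64_t>(scaled) + (1ull << 63);
  // Logical shift; the bias becomes 2^(63 - total_shift).
  const std::uint64_t shifted = biased >> total_shift;
  // Subtract the shifted bias to recover the arithmetic-shift result.
  const std::uint64_t shifted_bias = 1ull << (63 - total_shift);
  return static_cast<std::int32_t>(shifted - shifted_bias);
}

For the shift amounts used here this matches an arithmetic right shift of the scaled accumulator by total_shift.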
+ __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset); + } + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v1, left_shift); + // Apply the fixed-point part of the multiplier. + __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v1 = _mm256_sub_epi32(results, post_scaling_offset); + } + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v2, left_shift); + // Apply the fixed-point part of the multiplier. + __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v2 = _mm256_sub_epi32(results, post_scaling_offset); + } + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v3, left_shift); + // Apply the fixed-point part of the multiplier. 
+ __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v3 = _mm256_sub_epi32(results, post_scaling_offset); + } + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v4, left_shift); + // Apply the fixed-point part of the multiplier. + __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v4 = _mm256_sub_epi32(results, post_scaling_offset); + } + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v5, left_shift); + // Apply the fixed-point part of the multiplier. + __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v5 = _mm256_sub_epi32(results, post_scaling_offset); + } + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v6, left_shift); + // Apply the fixed-point part of the multiplier. 
+ __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v6 = _mm256_sub_epi32(results, post_scaling_offset); + } + { + __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v7, left_shift); + // Apply the fixed-point part of the multiplier. + __m256i scaled_v_low = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)), + m_64bit_low); + __m256i scaled_v_high = _mm256_mul_epi32( + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)), + m_64bit_high); + + scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm256_srlv_epi64(scaled_v_high, final_right_shift_high); + + scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32); + __m256i results = + _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa); + results = _mm256_permutevar8x32_epi32(results, repack_perm); + + accum_data_v7 = _mm256_sub_epi32(results, post_scaling_offset); + } + } + const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max); + const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min); + const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && + (residual_cols == kAvx8bitBlockSize); + + __m256i accum_data_v[kAvx8bitBlockSize]; + if (!store_full_block) { + accum_data_v[0] = accum_data_v0; + accum_data_v[1] = accum_data_v1; + accum_data_v[2] = accum_data_v2; + accum_data_v[3] = accum_data_v3; + accum_data_v[4] = accum_data_v4; + accum_data_v[5] = accum_data_v5; + accum_data_v[6] = accum_data_v6; + accum_data_v[7] = accum_data_v7; + } + + if (params.dst_type_id == DstTypeId::kValue) { + std::int8_t* tmp_ptr = static_cast(dst_ptr); + if (store_full_block) { + accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); + accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); + accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); + accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); + accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); + accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); + accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); + accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); + 
intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[0 * dst_stride], + accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[1 * dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[2 * dst_stride], + accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[3 * dst_stride], + accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[4 * dst_stride], + accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[5 * dst_stride], + accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[6 * dst_stride], + accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[7 * dst_stride], + accum_data_v7); + } else { + for (int j = 0; j < residual_cols; ++j) { + __m256 result = accum_data_v[j]; + result = _mm256_min_epi32(result, clamp_max_v); + result = _mm256_max_epi32(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, + result); + tmp_ptr += dst_stride; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::uint8_t* tmp_ptr = static_cast(dst_ptr); + if (store_full_block) { + accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); + accum_data_v1 = _mm256_min_epi32(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); + accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); + accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); + accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); + accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); + accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); + accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[0], accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[2 * dst_stride], + accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[3 * dst_stride], + accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[4 * dst_stride], + accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[5 * dst_stride], + accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[6 * dst_stride], + accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi8(&tmp_ptr[7 * dst_stride], + accum_data_v7); + } else { + for (int j = 0; j < residual_cols; ++j) { + __m256 result = accum_data_v[j]; + result = _mm256_min_epi32(result, clamp_max_v); + result = _mm256_max_epi32(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows, + result); + tmp_ptr += dst_stride; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::int16_t* tmp_ptr = static_cast(dst_ptr); + if (store_full_block) { + accum_data_v0 = _mm256_min_epi32(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm256_max_epi32(accum_data_v0, clamp_min_v); + accum_data_v1 = 
_mm256_min_epi32(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm256_max_epi32(accum_data_v1, clamp_min_v); + accum_data_v2 = _mm256_min_epi32(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm256_max_epi32(accum_data_v2, clamp_min_v); + accum_data_v3 = _mm256_min_epi32(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm256_max_epi32(accum_data_v3, clamp_min_v); + accum_data_v4 = _mm256_min_epi32(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm256_max_epi32(accum_data_v4, clamp_min_v); + accum_data_v5 = _mm256_min_epi32(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm256_max_epi32(accum_data_v5, clamp_min_v); + accum_data_v6 = _mm256_min_epi32(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm256_max_epi32(accum_data_v6, clamp_min_v); + accum_data_v7 = _mm256_min_epi32(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm256_max_epi32(accum_data_v7, clamp_min_v); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[0], accum_data_v0); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[dst_stride], + accum_data_v1); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[2 * dst_stride], + accum_data_v2); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[3 * dst_stride], + accum_data_v3); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[4 * dst_stride], + accum_data_v4); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[5 * dst_stride], + accum_data_v5); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[6 * dst_stride], + accum_data_v6); + intrin_utils::mm256_storeu_cvtepi32_epi16(&tmp_ptr[7 * dst_stride], + accum_data_v7); + } else { + for (int j = 0; j < residual_cols; ++j) { + __m256 result = accum_data_v[j]; + result = _mm256_min_epi32(result, clamp_max_v); + result = _mm256_max_epi32(result, clamp_min_v); + intrin_utils::mm256_n_storeu_cvtepi32_epi16(tmp_ptr, residual_rows, + result); + tmp_ptr += dst_stride; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + if (store_full_block) { + std::int32_t* tmp_ptr = static_cast(dst_ptr); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[0], accum_data_v0); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[dst_stride], accum_data_v1); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[2 * dst_stride], + accum_data_v2); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[3 * dst_stride], + accum_data_v3); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[4 * dst_stride], + accum_data_v4); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[5 * dst_stride], + accum_data_v5); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[6 * dst_stride], + accum_data_v6); + intrin_utils::mm256_storeu_epi32(&tmp_ptr[7 * dst_stride], + accum_data_v7); + } else { + std::int32_t* dst_block_ptr = static_cast(dst_ptr); + for (int j = 0; j < residual_cols; ++j) { + intrin_utils::mm256_n_storeu_epi32(dst_block_ptr, residual_rows, + accum_data_v[j]); + dst_block_ptr += dst_stride; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; + } // End row-block loop. + + dst_col_ptr = static_cast(static_cast(dst_col_ptr) + + kAvx8bitBlockSize * params.dst_stride); + rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; + } // End col-block loop. 
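All of the store paths above follow the same per-element recipe: clamp the int32 accumulator to [clamp_min, clamp_max] and narrow it to the destination type, except for the int32 destination, which receives the raw accumulators. A scalar summary of the narrowing paths (T stands for the destination scalar type; names are illustrative):

#include <algorithm>
#include <cstdint>

template <typename T>
void ClampAndStore(const std::int32_t* accum, int n, std::int32_t clamp_min,
                   std::int32_t clamp_max, T* dst) {
  for (int i = 0; i < n; ++i) {
    const std::int32_t clamped =
        std::min(std::max(accum[i], clamp_min), clamp_max);
    dst[i] = static_cast<T>(clamped);  // truncating narrow, as in the cvtepi32 stores
  }
}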
+} // NOLINT(readability/fn_size) + +void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label("Kernel kAvx2 8-bit GEMV"); + + RUY_DCHECK_EQ(params.dst_cols, 1); + RUY_DCHECK_EQ(params.last_col, 0); + RUY_DCHECK_EQ(params.start_col, 0); + + const std::int8_t splitter_idx_data[32] = { + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15, // + 0, 1, 4, 5, 8, 9, 12, 13, // + 2, 3, 6, 7, 10, 11, 14, 15 // + }; + + int bias_ptr_block_increment = + params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; + + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + void* dst_col_ptr = params.dst_base_ptr; + const std::int32_t* bias_col_ptr = params.bias; + if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { + bias_col_ptr += params.start_row; + } + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + void* dst_ptr = dst_col_ptr; + const std::int32_t* bias_ptr = bias_col_ptr; + + const std::int32_t lhs_zero_point = params.lhs_zero_point; + const bool has_rhs_sums_offsets = + (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; + std::int32_t rhs_sums_offsets[8]; + if (has_rhs_sums_offsets) { + const __m256i rhs_sums_offset_v = _mm256_mullo_epi32( + _mm256_set1_epi32(lhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.rhs_sums[0]))); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_sums_offsets), + rhs_sums_offset_v); + } + + for (int row = params.start_row; row <= params.last_row; + row += kAvx8bitBlockSize) { + const int residual_rows = + std::min(params.dst_rows - row, kAvx8bitBlockSize); + + const __m256i splitter_idx = + _mm256_loadu_si256(reinterpret_cast<__m256i const*>(splitter_idx_data)); + + __m256i accum_data_v0; + + // Initialize with bias. + __m256i initial_accum_data = + intrin_utils::mm256_n_loadu_epi32(residual_rows, bias_ptr); + bias_ptr += bias_ptr_block_increment; + + // Adjustments common across columns. + const std::int32_t rhs_zero_point = params.rhs_zero_point; + if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { + const __m256i lhs_sums_offset = _mm256_mullo_epi32( + _mm256_set1_epi32(rhs_zero_point), + _mm256_loadu_si256( + reinterpret_cast<__m256i const*>(¶ms.lhs_sums[row]))); + initial_accum_data = + _mm256_sub_epi32(initial_accum_data, lhs_sums_offset); + } + const std::int32_t prod_zp_depth = params.prod_zp_depth; + if (prod_zp_depth) { + initial_accum_data = _mm256_add_epi32(initial_accum_data, + _mm256_set1_epi32(prod_zp_depth)); + } + + // Adjustments differing across columns. + if (has_rhs_sums_offsets) { + accum_data_v0 = _mm256_sub_epi32(initial_accum_data, + _mm256_set1_epi32(rhs_sums_offsets[0])); + } else { + accum_data_v0 = initial_accum_data; + } + + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { + const __m256i lhs_data = + _mm256_load_si256(reinterpret_cast(lhs_ptr)); + const __m128i rhs_data_8bit = _mm_loadu_si32(rhs_ptr); + + // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. + // For simplicity we load 4x the data that we need and process twice the + // data that we need and store only the data we need. + std::int32_t rhs_data[2]; + const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. 
+ _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); + + // NOTE: There may be opportunities for permuting the data in the packing + // code instead of here. + const __m256i lhs_data_split = + _mm256_shuffle_epi8(lhs_data, splitter_idx); + const __m256i lhs_data_split_expand_bottom = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 0)); + const __m256i lhs_data_split_expand_top = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(lhs_data_split, 1)); + + // Take bytes 0, 1, 4, 5, 8, 9, ... expanded to 16-bit. + const __m256i lhs_16_bit_low = _mm256_permute2x128_si256( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x20); + // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit. + const __m256i lhs_16_bit_high = _mm256_permute2x128_si256( + lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31); + // Accumulate for column 0. + const std::int32_t low_rhs_value = rhs_data[0]; + const std::int32_t high_rhs_value = rhs_data[1]; + + const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value); + const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value); + + accum_data_v0 = _mm256_add_epi32( + accum_data_v0, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_data_v0 = _mm256_add_epi32( + accum_data_v0, + _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + + lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + } + + if (params.dst_type_id != DstTypeId::kValue) { + __m256i m_vector; + __m256i e_vector; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. + if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { + m_vector = intrin_utils::mm256_n_loadu_epi32( + residual_rows, ¶ms.multiplier_fixedpoint[row]); + e_vector = intrin_utils::mm256_n_loadu_epi32( + residual_rows, ¶ms.multiplier_exponent[row]); + } else { + // These arrays have size LhsCols, and are pre-filled. + m_vector = _mm256_set1_epi32(params.multiplier_fixedpoint[0]); + e_vector = _mm256_set1_epi32(params.multiplier_exponent[0]); + } + + const __m256i m_64bit_low = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 0)); + const __m256i m_64bit_high = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(m_vector, 1)); + + const __m256i zero_vector = _mm256_setzero_si256(); + const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector); + const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector); + const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector); + const __m256i final_right_shift = + _mm256_add_epi32(right_shift, _mm256_set1_epi32(31)); + const __m256i final_right_shift_low = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0)); + const __m256i final_right_shift_high = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1)); + // Really we want 0x100000000, but use half to avoid overflowing. + const __m256i convert_to_signed_halved = + _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift); + const __m256i convert_to_unsigned_64 = + _mm256_set1_epi64x(0x8000000000000000); + + __m256i post_scaling_offset = + _mm256_add_epi32(convert_to_signed_halved, convert_to_signed_halved); + + const __m256i offset_vector = + _mm256_slli_epi64(_mm256_set1_epi64x(1), 30); + // Really these should be shifted by neg_e_vector, but tests pass when + // using right_shift. 
+      const __m256i offset_vector_low = _mm256_add_epi64(
+          _mm256_sllv_epi64(
+              offset_vector,
+              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 0))),
+          convert_to_unsigned_64);
+      const __m256i offset_vector_high = _mm256_add_epi64(
+          _mm256_sllv_epi64(
+              offset_vector,
+              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 1))),
+          convert_to_unsigned_64);
+
+      if (params.dst_zero_point) {
+        const __m256i dst_zero_point = _mm256_set1_epi32(params.dst_zero_point);
+        // The post-scaling offset is subtracted later, so this has the effect
+        // of adding the zero point.
+        post_scaling_offset =
+            _mm256_sub_epi32(post_scaling_offset, dst_zero_point);
+      }
+
+#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
+      RUY_DCHECK(false);
+#endif
+      const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
+
+      // See GEMM version for details of this process.
+      {
+        __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift);
+        // Apply the fixed-point part of the multiplier.
+        __m256i scaled_v_low = _mm256_mul_epi32(
+            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
+            m_64bit_low);
+        __m256i scaled_v_high = _mm256_mul_epi32(
+            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
+            m_64bit_high);
+
+        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
+        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
+
+        scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
+        scaled_v_high =
+            _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
+
+        scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
+        __m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
+        results = _mm256_permutevar8x32_epi32(results, repack_perm);
+
+        accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset);
+      }
+    }
+    const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
+    const __m256i clamp_min_v = _mm256_set1_epi32(params.clamp_min);
+
+    if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+      std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr);
+      __m256i result = accum_data_v0;
+      result = _mm256_min_epi32(result, clamp_max_v);
+      result = _mm256_max_epi32(result, clamp_min_v);
+      intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows,
+                                                 result);
+      dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
+                                   kAvx8bitBlockSize);
+    } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+      std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr);
+      __m256i result = accum_data_v0;
+      result = _mm256_min_epi32(result, clamp_max_v);
+      result = _mm256_max_epi32(result, clamp_min_v);
+      intrin_utils::mm256_n_storeu_cvtepi32_epi8(tmp_ptr, residual_rows,
+                                                 result);
+      dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
+                                   kAvx8bitBlockSize);
+    } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+      std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+      __m256i result = accum_data_v0;
+      result = _mm256_min_epi32(result, clamp_max_v);
+      result = _mm256_max_epi32(result, clamp_min_v);
+      intrin_utils::mm256_n_storeu_cvtepi32_epi16(tmp_ptr, residual_rows,
+                                                  result);
+      dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
+                                   kAvx8bitBlockSize);
+    } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+      std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
+      intrin_utils::mm256_n_storeu_epi32(dst_block_ptr, residual_rows,
+                                         accum_data_v0);
+      dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
+                                   kAvx8bitBlockSize);
+    } else {
+      RUY_DCHECK(false);
+    }
+
+    lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
+  }  // End row-block loop.
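// Illustrative aside (not ruy code; names are made up): the block above is
// the fixed-point output pipeline, i.e. shift left by the positive part of
// the multiplier exponent, multiply by the fixed-point multiplier in 64 bits,
// add a rounding offset, shift right by 31 plus the negative part of the
// exponent, then apply the destination zero point and clamp. A scalar model
// of one lane, under those assumptions (ties round upward here):
#include <algorithm>
#include <cstdint>

inline std::int32_t RequantizeLaneSketch(std::int32_t accum,
                                         std::int32_t multiplier_fixedpoint,
                                         int exponent,
                                         std::int32_t dst_zero_point,
                                         std::int32_t clamp_min,
                                         std::int32_t clamp_max) {
  const int left_shift = exponent > 0 ? exponent : 0;
  const int right_shift = exponent > 0 ? 0 : -exponent;
  std::int64_t acc = static_cast<std::int64_t>(accum) << left_shift;
  acc *= static_cast<std::int64_t>(multiplier_fixedpoint);
  // Rounding offset: half of the final divisor, mirroring the
  // (1 << 30) << right_shift term folded into offset_vector_low/high above.
  acc += std::int64_t{1} << (30 + right_shift);
  acc >>= 31 + right_shift;
  const std::int32_t with_zero_point =
      static_cast<std::int32_t>(acc) + dst_zero_point;
  return std::min(std::max(with_zero_point, clamp_min), clamp_max);
}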
+ + dst_col_ptr = static_cast(static_cast(dst_col_ptr) + + kAvx8bitBlockSize * params.dst_stride); + rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; +} // NOLINT(readability/fn_size) + +void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) { + profiler::ScopeLabel label("Kernel kAvx2 float"); + + // As parameters are defined, we need to scale by sizeof(float). + const std::int64_t lhs_stride = params.lhs_stride >> 2; + const std::int64_t dst_stride = params.dst_stride >> 2; + const std::int64_t rhs_stride = params.rhs_stride >> 2; + // + int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; + // AVX2 float block size = 8. + const int end_row = std::min(params.dst_rows, params.last_row + 8); + const int end_col = std::min(params.dst_cols, params.last_col + 8); + // + const float* adj_rhs_col_ptr = + params.rhs_base_ptr - params.start_col * rhs_stride; + float* adj_dst_col_ptr = + params.dst_base_ptr - params.start_col * dst_stride - params.start_row; + const float* adj_lhs_col_ptr = + params.lhs_base_ptr - params.start_row * lhs_stride; + const float* bias_col_ptr = params.bias; + + const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max); + const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min); + + int col = params.start_col; + // Loop through cols by float block size, leaving incomplete remainder + for (; col <= end_col - 8; col += 8) { + __m256 accum_data_v[8]; + + const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; + float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; + + for (int row = params.start_row; row < end_row; row += 8) { + const int residual_rows = std::min(end_row - row, 8); + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + const __m256 initial_accum_data = + intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr); + + for (int j = 0; j < 8; ++j) { + accum_data_v[j] = initial_accum_data; + } + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); + // In this version RHS values are loaded individually rather than first + // loading together and then extract with broadcasting. This is because + // AVX flavours and instrinsics and compilers in combination do not + // handle this pattern of extraction very well. + const float* rhs_data = rhs_ptr; + + for (int j = 0; j < 8; ++j) { + const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]); + accum_data_v[j] = + _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); + } + lhs_ptr += 8; + rhs_ptr += 8; + } + + if (residual_rows == 8) { + for (int j = 0; j < 8; ++j) { + float* block_ptr = dst_ptr + j * dst_stride; + accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); + _mm256_storeu_ps(block_ptr, accum_data_v[j]); + } + } else { + for (int j = 0; j < 8; ++j) { + float* block_ptr = dst_ptr + j * dst_stride; + accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); + intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows, + accum_data_v[j]); + } + } + } // End row-block loop. + } // End col-block loop. + + if (col < end_col) { + // Remaining cols in [0, float block size). 
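// Illustrative aside (not ruy code; the helper below is hypothetical): each
// __m256 accumulator j above holds eight consecutive rows of destination
// column col + j, so the 8x8 block loop computes
//   dst[r][c] = clamp(bias[r] + sum_d lhs[r][d] * rhs[d][c]),
// and the residual path that follows performs the same computation but stores
// only residual_cols columns (and masked rows). A scalar reference:
#include <algorithm>

inline void FloatBlockReferenceSketch(const float* lhs, const float* rhs,
                                      const float* bias, int depth,
                                      float clamp_min, float clamp_max,
                                      float* dst /* 8x8, column-major */) {
  for (int c = 0; c < 8; ++c) {
    for (int r = 0; r < 8; ++r) {
      float acc = bias[r];
      for (int d = 0; d < depth; ++d) {
        // Packed operands: 8 LHS rows and 8 RHS columns per depth step.
        acc += lhs[d * 8 + r] * rhs[d * 8 + c];
      }
      dst[c * 8 + r] = std::min(std::max(acc, clamp_min), clamp_max);
    }
  }
}
// The residual-column handling of the kernel continues below.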
+ RUY_DCHECK_GE(end_col - col, 0); + RUY_DCHECK_LT(end_col - col, 8); + + __m256 accum_data_v[8]; + + const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; + float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; + const int residual_cols = std::min(end_col - col, 8); + + for (int row = params.start_row; row < end_row; row += 8) { + const int residual_rows = std::min(end_row - row, 8); + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + const __m256 initial_accum_data = + intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr); + + for (int j = 0; j < 8; ++j) { + accum_data_v[j] = initial_accum_data; + } + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + + for (int j = 0; j < 8; ++j) { + const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]); + accum_data_v[j] = + _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); + } + lhs_ptr += 8; + rhs_ptr += 8; + } + + for (int j = 0; j < residual_cols; ++j) { + float* block_ptr = dst_ptr + j * dst_stride; + accum_data_v[j] = _mm256_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm256_max_ps(accum_data_v[j], clamp_min_v); + intrin_utils::mm256_n_storeu_ps(block_ptr, residual_rows, + accum_data_v[j]); + } + } // End row-block loop. + } // End col-block terminal conditional. +} + +void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) { + profiler::ScopeLabel label("Kernel kAvx2 float GEMV"); + + RUY_DCHECK_EQ(params.dst_cols, 1); + RUY_DCHECK_EQ(params.last_col, 0); + RUY_DCHECK_EQ(params.start_col, 0); + + // As parameters are defined, we need to scale by sizeof(float). + const std::int64_t lhs_stride = params.lhs_stride >> 2; + // + int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; + // AVX2 float block size = 8. + const int end_row = std::min(params.dst_rows, params.last_row + 8); + + float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row; + const float* adj_lhs_col_ptr = + params.lhs_base_ptr - params.start_row * lhs_stride; + const float* bias_col_ptr = params.bias; + + const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max); + const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min); + + __m256 accum_data_v; + + const float* rhs_col_ptr = params.rhs_base_ptr; + float* dst_col_ptr = adj_dst_col_ptr; + + int row = params.start_row; + for (; row <= end_row - 8; row += 8) { + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. 
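// Illustrative aside (not ruy code): in the GEMV variant below only column 0
// of the packed RHS block is consumed, but the packed data still advances by
// the full block of 8 floats per depth step. That is why the unrolled loop
// reads rhs_ptr[0], rhs_ptr[8], rhs_ptr[16], rhs_ptr[24] and then advances
// rhs_ptr by 32. A scalar model of one destination lane, assuming that
// packing layout:
inline float GemvLaneSketch(const float* lhs, const float* rhs, float bias,
                            int depth, int row_lane /* 0..7 */) {
  float acc = bias;
  for (int d = 0; d < depth; ++d) {
    acc += lhs[d * 8 + row_lane] * rhs[d * 8];  // RHS column 0 only.
  }
  return acc;
}
// Bias initialization and the unrolled depth loop follow.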
+ accum_data_v = _mm256_loadu_ps(bias_ptr); + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + int d = 0; + for (; d <= params.depth - 4; d += 4) { + const __m256 lhs_data_0 = _mm256_loadu_ps(lhs_ptr); + const __m256 dup_rhs_element_0 = _mm256_set1_ps(rhs_ptr[0]); + accum_data_v = + _mm256_fmadd_ps(lhs_data_0, dup_rhs_element_0, accum_data_v); + const __m256 dup_rhs_element_1 = _mm256_set1_ps(rhs_ptr[8]); + const __m256 lhs_data_1 = _mm256_loadu_ps(lhs_ptr + 8); + accum_data_v = + _mm256_fmadd_ps(lhs_data_1, dup_rhs_element_1, accum_data_v); + + const __m256 lhs_data_2 = _mm256_loadu_ps(lhs_ptr + 16); + const __m256 dup_rhs_element_2 = _mm256_set1_ps(rhs_ptr[16]); + accum_data_v = + _mm256_fmadd_ps(lhs_data_2, dup_rhs_element_2, accum_data_v); + const __m256 dup_rhs_element_3 = _mm256_set1_ps(rhs_ptr[24]); + const __m256 lhs_data_3 = _mm256_loadu_ps(lhs_ptr + 24); + accum_data_v = + _mm256_fmadd_ps(lhs_data_3, dup_rhs_element_3, accum_data_v); + lhs_ptr += 32; // Loaded 8 * 4 floats. + rhs_ptr += 32; + } + for (; d < params.depth; ++d) { + const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + + const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]); + accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); + lhs_ptr += 8; + rhs_ptr += 8; + } + + accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v); + accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v); + _mm256_storeu_ps(dst_ptr, accum_data_v); + } // End row-block loop. + + if (row < end_row) { + const int residual_rows = end_row - row; + RUY_CHECK_GE(residual_rows, 1); + RUY_CHECK_LT(residual_rows, 8); + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + accum_data_v = intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr); + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + + const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[0]); + accum_data_v = _mm256_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); + lhs_ptr += 8; + rhs_ptr += 8; + } + + accum_data_v = _mm256_min_ps(accum_data_v, clamp_max_v); + accum_data_v = _mm256_max_ps(accum_data_v, clamp_min_v); + intrin_utils::mm256_n_storeu_ps(dst_ptr, residual_rows, accum_data_v); + } // End handling of residual rows. +} + +#endif // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc new file mode 100644 index 0000000..5e771a5 --- /dev/null +++ b/ruy/kernel_avx512.cc @@ -0,0 +1,1820 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include <algorithm>
+#include <cstdint>
+
+#include "ruy/check_macros.h"
+#include "ruy/kernel.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+#include <immintrin.h>  // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
+  // CPU-ID-based checks should disable the path that would reach this point.
+  RUY_DCHECK(false);
+}
+
+void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
+  // CPU-ID-based checks should disable the path that would reach this point.
+  RUY_DCHECK(false);
+}
+
+void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
+  // CPU-ID-based checks should disable the path that would reach this point.
+  RUY_DCHECK(false);
+}
+
+void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) {
+  // CPU-ID-based checks should disable the path that would reach this point.
+  RUY_DCHECK(false);
+}
+
+#else  // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
+  profiler::ScopeLabel label("Kernel kAvx512 8-bit");
+
+  std::int32_t dst_stride;
+  if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) ||
+      (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) {
+    dst_stride = params.dst_stride;
+  } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+    dst_stride = params.dst_stride / sizeof(std::int16_t);
+  } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+    dst_stride = params.dst_stride / sizeof(std::int32_t);
+  } else {
+    RUY_DCHECK(false);
+  }
+
+  int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
+
+  const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+  void* dst_col_ptr = params.dst_base_ptr;
+  const std::int32_t* bias_col_ptr = params.bias;
+  if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+    bias_col_ptr += params.start_row;
+  }
+
+  for (int col = params.start_col; col <= params.last_col; col += 16) {
+    const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+    void* dst_ptr = dst_col_ptr;
+    const std::int32_t* bias_ptr = bias_col_ptr;
+
+    const std::int32_t lhs_zero_point = params.lhs_zero_point;
+    const bool has_rhs_sums_offsets =
+        (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point;
+    std::int32_t rhs_sums_offsets[16];
+    if (has_rhs_sums_offsets) {
+      const __m512i rhs_sums_offset_v =
+          _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point),
+                             _mm512_loadu_epi32(&params.rhs_sums[col]));
+      _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets),
+                          rhs_sums_offset_v);
+    }
+
+    for (int row = params.start_row; row <= params.last_row; row += 16) {
+      const int residual_rows = std::min(params.dst_rows - row, 16);
+      const int residual_cols = std::min(params.dst_cols - col, 16);
+
+      __m512i accum_data_v0;
+      __m512i accum_data_v1;
+      __m512i accum_data_v2;
+      __m512i accum_data_v3;
+      __m512i accum_data_v4;
+      __m512i accum_data_v5;
+      __m512i accum_data_v6;
+      __m512i accum_data_v7;
+      __m512i accum_data_v8;
+      __m512i accum_data_v9;
+      __m512i accum_data_va;
+      __m512i accum_data_vb;
+      __m512i accum_data_vc;
+      __m512i accum_data_vd;
+      __m512i accum_data_ve;
+      __m512i accum_data_vf;
+
+      // Initialize with bias.
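// Illustrative aside (not ruy code): unlike the AVX2 path, partial tiles are
// handled here with a 16-lane mask fed to masked loads/stores such as
// _mm512_maskz_loadu_epi32. The mask enables the low residual_rows lanes and
// zeroes the rest; a worked model of the computation that follows:
#include <cstdint>

inline std::uint16_t ResidualRowMaskSketch(int residual_rows /* 1..16 */) {
  // e.g. residual_rows == 5  ->  0b0000000000011111
  //      residual_rows == 16 ->  0xffff
  return static_cast<std::uint16_t>(
      (static_cast<std::uint32_t>(1) << residual_rows) - 1);
}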
+ const __mmask16 row_mask = + (static_cast(1) << residual_rows) - 1; + __m512i initial_accum_data = _mm512_maskz_loadu_epi32(row_mask, bias_ptr); + bias_ptr += bias_ptr_block_increment; + + const std::int32_t rhs_zero_point = params.rhs_zero_point; + if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { + const __m512i lhs_sums_offset = + _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point), + _mm512_loadu_epi32(¶ms.lhs_sums[row])); + initial_accum_data = + _mm512_sub_epi32(initial_accum_data, lhs_sums_offset); + } + + const std::int32_t prod_zp_depth = params.prod_zp_depth; + if (prod_zp_depth != 0) { + initial_accum_data = _mm512_add_epi32(initial_accum_data, + _mm512_set1_epi32(prod_zp_depth)); + } + + // Adjustments differing across columns. + if (has_rhs_sums_offsets) { + accum_data_v0 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[0])); + accum_data_v1 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[1])); + accum_data_v2 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[2])); + accum_data_v3 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[3])); + accum_data_v4 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[4])); + accum_data_v5 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[5])); + accum_data_v6 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[6])); + accum_data_v7 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[7])); + accum_data_v8 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[8])); + accum_data_v9 = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[9])); + accum_data_va = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[10])); + accum_data_vb = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[11])); + accum_data_vc = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[12])); + accum_data_vd = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[13])); + accum_data_ve = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[14])); + accum_data_vf = _mm512_sub_epi32( + initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[15])); + } else { + accum_data_v0 = initial_accum_data; + accum_data_v1 = initial_accum_data; + accum_data_v2 = initial_accum_data; + accum_data_v3 = initial_accum_data; + accum_data_v4 = initial_accum_data; + accum_data_v5 = initial_accum_data; + accum_data_v6 = initial_accum_data; + accum_data_v7 = initial_accum_data; + accum_data_v8 = initial_accum_data; + accum_data_v9 = initial_accum_data; + accum_data_va = initial_accum_data; + accum_data_vb = initial_accum_data; + accum_data_vc = initial_accum_data; + accum_data_vd = initial_accum_data; + accum_data_ve = initial_accum_data; + accum_data_vf = initial_accum_data; + } + + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += 4) { + const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr); + __m512i rhs_data_8bit = _mm512_loadu_epi8(rhs_ptr); + + // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. 
+ std::int32_t rhs_data[32]; + const __m256i rhs_data_bottom_lane = + _mm512_castsi512_si256(rhs_data_8bit); + const __m256i rhs_data_top_lane = + _mm512_extracti32x8_epi32(rhs_data_8bit, 1); + const __m512i rhs_16_bit_dup_low = + _mm512_cvtepi8_epi16(rhs_data_bottom_lane); + const __m512i rhs_16_bit_dup_high = + _mm512_cvtepi8_epi16(rhs_data_top_lane); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data), + rhs_16_bit_dup_low); + _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data + 16), + rhs_16_bit_dup_high); + + // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. + const __m512i lhs_16_bit_low = + _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data)); + // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit. + const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16( + _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16))); + + // Process column 0. + { + __m512i accum_v = accum_data_v0; + constexpr int index = 0; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v0 = accum_v; + } + // Process column 1. + { + __m512i accum_v = accum_data_v1; + constexpr int index = 2; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v1 = accum_v; + } + // Process column 2. + { + __m512i accum_v = accum_data_v2; + constexpr int index = 4; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v2 = accum_v; + } + // Process column 3. + { + __m512i accum_v = accum_data_v3; + constexpr int index = 6; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v3 = accum_v; + } + // Process column 4. + { + __m512i accum_v = accum_data_v4; + constexpr int index = 8; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v4 = accum_v; + } + // Process column 5. 
+ { + __m512i accum_v = accum_data_v5; + constexpr int index = 10; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v5 = accum_v; + } + // Process column 6. + { + __m512i accum_v = accum_data_v6; + constexpr int index = 12; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v6 = accum_v; + } + // Process column 7. + { + __m512i accum_v = accum_data_v7; + constexpr int index = 14; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v7 = accum_v; + } + // Process column 8. + { + __m512i accum_v = accum_data_v8; + constexpr int index = 16; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v8 = accum_v; + } + // Process column 9. + { + __m512i accum_v = accum_data_v9; + constexpr int index = 18; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v9 = accum_v; + } + // Process column 10. + { + __m512i accum_v = accum_data_va; + constexpr int index = 20; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_va = accum_v; + } + // Process column 11. + { + __m512i accum_v = accum_data_vb; + constexpr int index = 22; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_vb = accum_v; + } + // Process column 12. 
+ { + __m512i accum_v = accum_data_vc; + constexpr int index = 24; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_vc = accum_v; + } + // Process column 13. + { + __m512i accum_v = accum_data_vd; + constexpr int index = 26; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_vd = accum_v; + } + // Process column 14. + { + __m512i accum_v = accum_data_ve; + constexpr int index = 28; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_ve = accum_v; + } + // Process column 15. + { + __m512i accum_v = accum_data_vf; + constexpr int index = 30; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_vf = accum_v; + } + + lhs_ptr += 16 * 4; + rhs_ptr += 16 * 4; + } + + if (params.dst_type_id != DstTypeId::kValue) { + __m512i m_vector; + __m512i e_vector; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. + if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { + m_vector = _mm512_maskz_loadu_epi32( + row_mask, ¶ms.multiplier_fixedpoint[row]); + e_vector = _mm512_maskz_loadu_epi32(row_mask, + ¶ms.multiplier_exponent[row]); + } else { + // These arrays have size LhsCols, and are pre-filled. + m_vector = _mm512_set1_epi32(params.multiplier_fixedpoint[0]); + e_vector = _mm512_set1_epi32(params.multiplier_exponent[0]); + } + + const __m512i m_64bit_low = + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0)); + const __m512i m_64bit_high = + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1)); + + const __m512i zero_vector = _mm512_setzero_epi32(); + const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector); + const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector); + const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector); + const __m512i final_right_shift = + _mm512_add_epi32(right_shift, _mm512_set1_epi32(31)); + const __m512i final_right_shift_low = _mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(final_right_shift, 0)); + const __m512i final_right_shift_high = _mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(final_right_shift, 1)); + + const __m512i offset_vector = + _mm512_slli_epi64(_mm512_set1_epi64(1), 30); + // Really these should be shifted by neg_e_vector, but tests pass when + // using right_shift. 
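// Illustrative aside (not ruy code): the offset built just below equals half
// of the divisor that the later arithmetic shift by (31 + right_shift)
// applies, so the sequence "add offset, then shift" rounds the scaled value
// rather than truncating it:
//   offset = (1 << 30) << right_shift = 2^(30 + right_shift)
//          = 2^(31 + right_shift) / 2.
#include <cstdint>

constexpr std::int64_t RoundingOffsetSketch(int right_shift) {
  return std::int64_t{1} << (30 + right_shift);
}
static_assert(RoundingOffsetSketch(0) == (std::int64_t{1} << 31) / 2,
              "the offset is half of the final divisor when right_shift == 0");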
+ const __m512i offset_vector_low = _mm512_sllv_epi64( + offset_vector, + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0))); + const __m512i offset_vector_high = _mm512_sllv_epi64( + offset_vector, + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1))); + + // Shift and round column 0. + { + accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v0, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v0, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v0 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v0 = _mm512_inserti32x8( + accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 1. + { + accum_data_v1 = _mm512_sllv_epi32(accum_data_v1, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v1, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v1, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v1 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v1 = _mm512_inserti32x8( + accum_data_v1, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 2. + { + accum_data_v2 = _mm512_sllv_epi32(accum_data_v2, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v2, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v2, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v2 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v2 = _mm512_inserti32x8( + accum_data_v2, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 3. + { + accum_data_v3 = _mm512_sllv_epi32(accum_data_v3, left_shift); + // Apply the fixed-point part of the multiplier. 
+ __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v3, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v3, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v3 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v3 = _mm512_inserti32x8( + accum_data_v3, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 4. + { + accum_data_v4 = _mm512_sllv_epi32(accum_data_v4, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v4, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v4, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v4 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v4 = _mm512_inserti32x8( + accum_data_v4, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 5. + { + accum_data_v5 = _mm512_sllv_epi32(accum_data_v5, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v5, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v5, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v5 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v5 = _mm512_inserti32x8( + accum_data_v5, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 6. + { + accum_data_v6 = _mm512_sllv_epi32(accum_data_v6, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v6, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v6, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v6 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v6 = _mm512_inserti32x8( + accum_data_v6, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 7. 
+ { + accum_data_v7 = _mm512_sllv_epi32(accum_data_v7, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v7, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v7, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v7 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v7 = _mm512_inserti32x8( + accum_data_v7, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 8. + { + accum_data_v8 = _mm512_sllv_epi32(accum_data_v8, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v8, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v8, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v8 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v8 = _mm512_inserti32x8( + accum_data_v8, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 9. + { + accum_data_v9 = _mm512_sllv_epi32(accum_data_v9, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v9, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_v9, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v9 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v9 = _mm512_inserti32x8( + accum_data_v9, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 10. + { + accum_data_va = _mm512_sllv_epi32(accum_data_va, left_shift); + // Apply the fixed-point part of the multiplier. 
+ __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_va, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_va, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_va = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_va = _mm512_inserti32x8( + accum_data_va, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 11. + { + accum_data_vb = _mm512_sllv_epi32(accum_data_vb, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vb, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vb, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_vb = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_vb = _mm512_inserti32x8( + accum_data_vb, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 12. + { + accum_data_vc = _mm512_sllv_epi32(accum_data_vc, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vc, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vc, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_vc = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_vc = _mm512_inserti32x8( + accum_data_vc, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 13. + { + accum_data_vd = _mm512_sllv_epi32(accum_data_vd, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vd, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vd, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_vd = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_vd = _mm512_inserti32x8( + accum_data_vd, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 14. 
+ { + accum_data_ve = _mm512_sllv_epi32(accum_data_ve, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_ve, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_ve, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_ve = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_ve = _mm512_inserti32x8( + accum_data_ve, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } + // Shift and round column 15. + { + accum_data_vf = _mm512_sllv_epi32(accum_data_vf, left_shift); + // Apply the fixed-point part of the multiplier. + __m512i scaled_v_low = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vf, 0)), + m_64bit_low); + __m512i scaled_v_high = + _mm512_mul_epi32(_mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(accum_data_vf, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = + _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_vf = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_vf = _mm512_inserti32x8( + accum_data_vf, _mm512_cvtepi64_epi32(scaled_v_high), 1); + } +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + RUY_DCHECK(false); +#endif + + if (params.dst_zero_point != 0) { + __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point); + accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point); + accum_data_v1 = _mm512_add_epi32(accum_data_v1, dst_zero_point); + accum_data_v2 = _mm512_add_epi32(accum_data_v2, dst_zero_point); + accum_data_v3 = _mm512_add_epi32(accum_data_v3, dst_zero_point); + accum_data_v4 = _mm512_add_epi32(accum_data_v4, dst_zero_point); + accum_data_v5 = _mm512_add_epi32(accum_data_v5, dst_zero_point); + accum_data_v6 = _mm512_add_epi32(accum_data_v6, dst_zero_point); + accum_data_v7 = _mm512_add_epi32(accum_data_v7, dst_zero_point); + accum_data_v8 = _mm512_add_epi32(accum_data_v8, dst_zero_point); + accum_data_v9 = _mm512_add_epi32(accum_data_v9, dst_zero_point); + accum_data_va = _mm512_add_epi32(accum_data_va, dst_zero_point); + accum_data_vb = _mm512_add_epi32(accum_data_vb, dst_zero_point); + accum_data_vc = _mm512_add_epi32(accum_data_vc, dst_zero_point); + accum_data_vd = _mm512_add_epi32(accum_data_vd, dst_zero_point); + accum_data_ve = _mm512_add_epi32(accum_data_ve, dst_zero_point); + accum_data_vf = _mm512_add_epi32(accum_data_vf, dst_zero_point); + } + } + + const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max); + const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min); + + const bool store_full_block = + (residual_rows == 16) && (residual_cols == 16); + + __m512i accum_data_v[16]; + + // In most cases we would make this conditional on (!store_full_block) and + // unwind the clamp-and-store loop, but the benefit appears small. 
+ { + accum_data_v[0] = accum_data_v0; + accum_data_v[1] = accum_data_v1; + accum_data_v[2] = accum_data_v2; + accum_data_v[3] = accum_data_v3; + accum_data_v[4] = accum_data_v4; + accum_data_v[5] = accum_data_v5; + accum_data_v[6] = accum_data_v6; + accum_data_v[7] = accum_data_v7; + accum_data_v[8] = accum_data_v8; + accum_data_v[9] = accum_data_v9; + accum_data_v[10] = accum_data_va; + accum_data_v[11] = accum_data_vb; + accum_data_v[12] = accum_data_vc; + accum_data_v[13] = accum_data_vd; + accum_data_v[14] = accum_data_ve; + accum_data_v[15] = accum_data_vf; + } + + if (params.dst_type_id == DstTypeId::kValue) { + std::int8_t* tmp_ptr = static_cast(dst_ptr); + const int block_col_offset = dst_stride; + if (store_full_block) { + for (int j = 0; j < 16; ++j) { + __m512i result = accum_data_v[j]; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm_storeu_epi8(tmp_ptr + j * block_col_offset, + _mm512_cvtepi32_epi8(result)); + } + } else { + for (int j = 0; j < residual_cols; ++j) { + __m512i result = accum_data_v[j]; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask, + _mm512_cvtepi32_epi8(result)); + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + 16); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::uint8_t* tmp_ptr = static_cast(dst_ptr); + const int block_col_offset = dst_stride; + if (store_full_block) { + for (int j = 0; j < residual_cols; ++j) { + __m512i result = accum_data_v[j]; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm_storeu_epi8(tmp_ptr + j * block_col_offset, + _mm512_cvtepi32_epi8(result)); + } + } else { + for (int j = 0; j < residual_cols; ++j) { + __m512i result = accum_data_v[j]; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask, + _mm512_cvtepi32_epi8(result)); + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + 16); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::int16_t* tmp_ptr = static_cast(dst_ptr); + const int block_col_offset = dst_stride; + if (store_full_block) { + for (int j = 0; j < 16; ++j) { + __m512i result = accum_data_v[j]; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm256_storeu_epi16(tmp_ptr + j * block_col_offset, + _mm512_cvtepi32_epi16(result)); + } + } else { + for (int j = 0; j < residual_cols; ++j) { + __m512i result = accum_data_v[j]; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm256_mask_storeu_epi16(tmp_ptr + j * block_col_offset, row_mask, + _mm512_cvtepi32_epi16(result)); + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + 16); + } else if (params.dst_type_id == DstTypeId::kValue) { + if (store_full_block) { + std::int32_t* tmp_ptr = static_cast(dst_ptr); + for (int j = 0; j < 16; ++j) { + _mm512_storeu_epi32(tmp_ptr + j * dst_stride, accum_data_v[j]); + } + } else { + std::int32_t* tmp_ptr = static_cast(dst_ptr); + for (int j = 0; j < residual_cols; ++j) { + _mm512_mask_storeu_epi32(tmp_ptr + j * dst_stride, row_mask, + accum_data_v[j]); + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + 16); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += 16 * params.lhs_stride; + } // End row-block loop. 
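// Illustrative aside (not ruy code; the helper is hypothetical): the store
// step above clamps each 32-bit lane to [clamp_min, clamp_max], narrows it
// with _mm512_cvtepi32_epi8 (which truncates, safe after clamping), and uses
// row_mask to write only the valid rows of a partial tile. A scalar model for
// an int8 destination column:
#include <algorithm>
#include <cstdint>

inline void StoreColumnInt8Sketch(const std::int32_t* accum /* 16 lanes */,
                                  int residual_rows, std::int32_t clamp_min,
                                  std::int32_t clamp_max, std::int8_t* dst) {
  for (int r = 0; r < residual_rows; ++r) {
    const std::int32_t clamped =
        std::min(std::max(accum[r], clamp_min), clamp_max);
    dst[r] = static_cast<std::int8_t>(clamped);
  }
}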
+ + dst_col_ptr = static_cast(static_cast(dst_col_ptr) + + 16 * params.dst_stride); + rhs_col_ptr += 16 * params.rhs_stride; + } // End col-block loop. +} // NOLINT(readability/fn_size) + +void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { + profiler::ScopeLabel label("Kernel kAvx512 8-bit GEMV"); + + RUY_DCHECK_EQ(params.dst_cols, 1); + RUY_DCHECK_EQ(params.last_col, 0); + RUY_DCHECK_EQ(params.start_col, 0); + + std::int32_t dst_stride; + if ((params.dst_type_id == DstTypeId::kValue) || + (params.dst_type_id == DstTypeId::kValue)) { + dst_stride = params.dst_stride; + } else if (params.dst_type_id == DstTypeId::kValue) { + dst_stride = params.dst_stride / sizeof(std::int16_t); + } else if (params.dst_type_id == DstTypeId::kValue) { + dst_stride = params.dst_stride / sizeof(std::int32_t); + } else { + RUY_DCHECK(false); + } + + int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0; + + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + void* dst_col_ptr = params.dst_base_ptr; + const std::int32_t* bias_col_ptr = params.bias; + if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { + bias_col_ptr += params.start_row; + } + + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + void* dst_ptr = dst_col_ptr; + const std::int32_t* bias_ptr = bias_col_ptr; + + const std::int32_t lhs_zero_point = params.lhs_zero_point; + const bool has_rhs_sums_offsets = + (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; + std::int32_t rhs_sums_offsets[16]; + if (has_rhs_sums_offsets) { + const __m512i rhs_sums_offset_v = + _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point), + _mm512_loadu_epi32(¶ms.rhs_sums[0])); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets), + rhs_sums_offset_v); + } + + for (int row = params.start_row; row <= params.last_row; row += 16) { + const int residual_rows = std::min(params.dst_rows - row, 16); + + __m512i accum_data_v0; + + // Initialize with bias. + const __mmask16 row_mask = + (static_cast(1) << residual_rows) - 1; + __m512i initial_accum_data = _mm512_maskz_loadu_epi32(row_mask, bias_ptr); + bias_ptr += bias_ptr_block_increment; + + const std::int32_t rhs_zero_point = params.rhs_zero_point; + if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { + const __m512i lhs_sums_offset = + _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point), + _mm512_loadu_epi32(¶ms.lhs_sums[row])); + initial_accum_data = + _mm512_sub_epi32(initial_accum_data, lhs_sums_offset); + } + + const std::int32_t prod_zp_depth = params.prod_zp_depth; + if (prod_zp_depth != 0) { + initial_accum_data = _mm512_add_epi32(initial_accum_data, + _mm512_set1_epi32(prod_zp_depth)); + } + + // Adjustments differing across columns. + if (has_rhs_sums_offsets) { + accum_data_v0 = _mm512_sub_epi32(initial_accum_data, + _mm512_set1_epi32(rhs_sums_offsets[0])); + } else { + accum_data_v0 = initial_accum_data; + } + + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += 4) { + const __m512i lhs_data = _mm512_loadu_epi8(lhs_ptr); + const __m128i rhs_data_8bit = _mm_loadu_epi8(rhs_ptr); + + // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. + // For simplicity we load 4x the data that we need and process twice the + // data that we need and store only the data we need. 
+ std::int32_t rhs_data[2]; + const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); + + // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. + const __m512i lhs_16_bit_low = + _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data)); + // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit. + const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16( + _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16))); + + // Process column 0. + __m512i accum_v = accum_data_v0; + constexpr int index = 0; + + const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); + const __m512i rhs_16_bit_dup_high = + _mm512_set1_epi32(rhs_data[index + 1]); + + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); + accum_v = _mm512_add_epi32( + accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); + accum_data_v0 = accum_v; + + lhs_ptr += 16 * 4; + rhs_ptr += 16 * 4; + } + + if (params.dst_type_id != DstTypeId::kValue) { + __m512i m_vector; + __m512i e_vector; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. + if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { + m_vector = _mm512_maskz_loadu_epi32(row_mask, + ¶ms.multiplier_fixedpoint[row]); + e_vector = _mm512_maskz_loadu_epi32(row_mask, + ¶ms.multiplier_exponent[row]); + } else { + // These arrays have size LhsCols, and are pre-filled. + m_vector = _mm512_set1_epi32(params.multiplier_fixedpoint[0]); + e_vector = _mm512_set1_epi32(params.multiplier_exponent[0]); + } + + const __m512i m_64bit_low = + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0)); + const __m512i m_64bit_high = + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1)); + + const __m512i zero_vector = _mm512_setzero_epi32(); + const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector); + const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector); + const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector); + const __m512i final_right_shift = + _mm512_add_epi32(right_shift, _mm512_set1_epi32(31)); + const __m512i final_right_shift_low = _mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(final_right_shift, 0)); + const __m512i final_right_shift_high = _mm512_cvtepi32_epi64( + _mm512_extracti32x8_epi32(final_right_shift, 1)); + + const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30); + // Really these should be shifted by neg_e_vector, but tests pass when + // using right_shift. + const __m512i offset_vector_low = _mm512_sllv_epi64( + offset_vector, + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0))); + const __m512i offset_vector_high = _mm512_sllv_epi64( + offset_vector, + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1))); + + // Shift and round column 0. + accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift); + // Apply the fixed-point part of the multiplier. 
+ __m512i scaled_v_low = _mm512_mul_epi32( + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 0)), + m_64bit_low); + __m512i scaled_v_high = _mm512_mul_epi32( + _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)), + m_64bit_high); + + scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); + scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); + + scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); + scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high); + + accum_data_v0 = + _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); + accum_data_v0 = _mm512_inserti32x8( + accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1); +#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING) + RUY_DCHECK(false); +#endif + + if (params.dst_zero_point != 0) { + __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point); + accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point); + } + } + + const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max); + const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min); + + if (params.dst_type_id == DstTypeId::kValue) { + std::int8_t* tmp_ptr = static_cast(dst_ptr); + __m512i result = accum_data_v0; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result)); + dst_ptr = static_cast(static_cast(dst_ptr) + 16); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::uint8_t* tmp_ptr = static_cast(dst_ptr); + __m512i result = accum_data_v0; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result)); + dst_ptr = static_cast(static_cast(dst_ptr) + 16); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::int16_t* tmp_ptr = static_cast(dst_ptr); + __m512i result = accum_data_v0; + result = _mm512_min_epi32(result, clamp_max_v); + result = _mm512_max_epi32(result, clamp_min_v); + _mm256_mask_storeu_epi16(tmp_ptr, row_mask, + _mm512_cvtepi32_epi16(result)); + dst_ptr = static_cast(static_cast(dst_ptr) + 16); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::int32_t* tmp_ptr = static_cast(dst_ptr); + _mm512_mask_storeu_epi32(tmp_ptr, row_mask, accum_data_v0); + dst_ptr = static_cast(static_cast(dst_ptr) + 16); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += 16 * params.lhs_stride; + } // End row-block loop. +} // NOLINT(readability/fn_size) + +void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) { + profiler::ScopeLabel label("Kernel kAvx512 float"); + + // As parameters are defined, we need to scale by sizeof(float). + const std::int64_t lhs_stride = params.lhs_stride >> 2; + const std::int64_t dst_stride = params.dst_stride >> 2; + const std::int64_t rhs_stride = params.rhs_stride >> 2; + + int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 
+      1 : 0;
+  const int end_row = std::min(params.dst_rows, params.last_row + 16);
+  const int end_col = std::min(params.dst_cols, params.last_col + 16);
+
+  const float* adj_rhs_col_ptr =
+      params.rhs_base_ptr - params.start_col * rhs_stride;
+  float* adj_dst_col_ptr =
+      params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
+  const float* adj_lhs_col_ptr =
+      params.lhs_base_ptr - params.start_row * lhs_stride;
+  const float* bias_col_ptr = params.bias;
+
+  const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
+  const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
+
+  int col = params.start_col;
+  for (; col <= end_col - 16; col += 16) {
+    const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride;
+    float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride;
+
+    int row = params.start_row;
+    for (; row <= end_row - 16; row += 16) {
+      const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
+      float* dst_ptr = dst_col_ptr + row;
+      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
+
+      // Initialize with bias.
+      const __m512 initial_accum_data = _mm512_loadu_ps(bias_ptr);
+
+      // Process block in two halves, split by columns.
+      {
+        constexpr int mmm = 0;
+
+        __m512 accum_data_v0 = initial_accum_data;
+        __m512 accum_data_v1 = initial_accum_data;
+        __m512 accum_data_v2 = initial_accum_data;
+        __m512 accum_data_v3 = initial_accum_data;
+        __m512 accum_data_v4 = initial_accum_data;
+        __m512 accum_data_v5 = initial_accum_data;
+        __m512 accum_data_v6 = initial_accum_data;
+        __m512 accum_data_v7 = initial_accum_data;
+
+        const float* lhs_ptr = lhs_col_ptr;
+        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
+        for (int d = 0; d < (params.depth - 1); ++d) {
+          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
+          // In this version, RHS values are loaded individually rather than
+          // loaded together and then extracted with broadcasting. This is
+          // because the combination of AVX variants, intrinsics and compilers
+          // does not handle that extraction pattern well.
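The comment above explains why each RHS value is read as a scalar and broadcast with _mm512_set1_ps instead of being extracted from a vector load. Taken in isolation, the per-column update is a single broadcast plus FMA; the standalone sketch below (hypothetical helper name, assuming AVX-512F) shows just that step.

#include <immintrin.h>

// One inner step of the float kernel in isolation (not ruy code): update the
// accumulator for destination column j by lhs_column * rhs_row[j], with the
// RHS scalar loaded individually and broadcast to all 16 lanes.
static inline __m512 ColumnFmaStep(__m512 accum_j, __m512 lhs_column,
                                   const float* rhs_row, int j) {
  const __m512 rhs_broadcast = _mm512_set1_ps(rhs_row[j]);
  return _mm512_fmadd_ps(lhs_column, rhs_broadcast, accum_j);
}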
+ const float* rhs_data = rhs_ptr; + lhs_ptr += 16; + rhs_ptr += 16; + + { + const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + } + { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + { + const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + { + float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; + accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); + _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0); + accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); + _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1); + accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); + _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2); + accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); + _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3); + accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); + _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4); + accum_data_v5 = 
_mm512_min_ps(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); + _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5); + accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); + _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6); + accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); + _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); + } + } + } // Inner half-block loop, unrolled, first iteration. + { + constexpr int mmm = 1; + + __m512 accum_data_v0 = initial_accum_data; + __m512 accum_data_v1 = initial_accum_data; + __m512 accum_data_v2 = initial_accum_data; + __m512 accum_data_v3 = initial_accum_data; + __m512 accum_data_v4 = initial_accum_data; + __m512 accum_data_v5 = initial_accum_data; + __m512 accum_data_v6 = initial_accum_data; + __m512 accum_data_v7 = initial_accum_data; + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr + 8 * mmm; + for (int d = 0; d < (params.depth - 1); ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + lhs_ptr += 16; + rhs_ptr += 16; + { + const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + } + { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + { + const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, 
accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + { + float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; + accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); + _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0); + accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); + _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1); + accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); + _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2); + accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); + _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3); + accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); + _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4); + accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); + _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5); + accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); + _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6); + accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); + _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); + } + } + } // Inner half-block loop, unrolled, second iteration. + } // End row-block loop. + + // The unrolling within this conditional may be somewhat pointless. It + // depends on the kinds of models. + if (row < end_row) { + const int residual_rows = end_row - row; + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + const __mmask16 row_mask = + (static_cast(1) << residual_rows) - 1; + const __m512 initial_accum_data = + _mm512_maskz_loadu_ps(row_mask, bias_ptr); + + // Process block in two halves, split by columns. 
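As in the full-block path above, the residual-row block below is still processed as two halves of eight destination columns (the mmm loop), with one 512-bit accumulator of 16 rows per column, so eight accumulators are live per half. The compact sketch below shows that blocking structure with hypothetical names, full blocks only, and no residual masking.

#include <cstdint>
#include <immintrin.h>

// Sketch (not ruy code) of the blocking used here: 16 destination rows live in
// one __m512 accumulator per column, and the 16 destination columns are
// covered as two halves of 8.
void Float16x16BlockSketch(const float* lhs_col_ptr, const float* rhs_col_ptr,
                           float* dst_ptr, int depth, std::int64_t dst_stride,
                           __m512 initial_accum, __m512 clamp_min_v,
                           __m512 clamp_max_v) {
  for (int mmm = 0; mmm < 2; ++mmm) {  // Two halves of 8 columns each.
    __m512 accum[8];
    for (int j = 0; j < 8; ++j) accum[j] = initial_accum;
    const float* lhs_ptr = lhs_col_ptr;
    const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
    for (int d = 0; d < depth; ++d) {
      const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
      for (int j = 0; j < 8; ++j) {
        accum[j] =
            _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_ptr[j]), accum[j]);
      }
      lhs_ptr += 16;
      rhs_ptr += 16;
    }
    for (int j = 0; j < 8; ++j) {
      const __m512 clamped =
          _mm512_max_ps(_mm512_min_ps(accum[j], clamp_max_v), clamp_min_v);
      _mm512_storeu_ps(dst_ptr + (mmm * 8 + j) * dst_stride, clamped);
    }
  }
}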
+ for (int mmm = 0; mmm < 2; ++mmm) { + __m512 accum_data_v0 = initial_accum_data; + __m512 accum_data_v1 = initial_accum_data; + __m512 accum_data_v2 = initial_accum_data; + __m512 accum_data_v3 = initial_accum_data; + __m512 accum_data_v4 = initial_accum_data; + __m512 accum_data_v5 = initial_accum_data; + __m512 accum_data_v6 = initial_accum_data; + __m512 accum_data_v7 = initial_accum_data; + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr + 8 * mmm; + for (int d = 0; d < (params.depth - 1); ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + lhs_ptr += 16; + rhs_ptr += 16; + { + const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + } + { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + { + const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); + accum_data_v0 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); + const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); + accum_data_v1 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); + const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); + accum_data_v2 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); + const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); + accum_data_v3 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); + const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); + accum_data_v4 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); + const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); + accum_data_v5 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); + const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); + accum_data_v6 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); + const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); + accum_data_v7 = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); + } + { + float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; + accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); + accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask, + accum_data_v0); + accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); + accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 1 * 
dst_stride, row_mask, + accum_data_v1); + accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); + accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask, + accum_data_v2); + accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); + accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask, + accum_data_v3); + accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); + accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask, + accum_data_v4); + accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); + accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask, + accum_data_v5); + accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); + accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask, + accum_data_v6); + accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); + accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); + _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask, + accum_data_v7); + } + } + } // Inner half-block loop. + } // Residual rows, main col-block loop. + } // End col-block loop. + + if (col < end_col) { + RUY_DCHECK_GE(end_col - col, 0); + RUY_DCHECK_LT(end_col - col, 16); + + __m512 accum_data_v[8]; + + const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; + float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; + + for (int row = params.start_row; row < end_row; row += 16) { + const int residual_rows = std::min(end_row - row, 16); + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + const __mmask16 row_mask = + (static_cast(1) << residual_rows) - 1; + const __m512 initial_accum_data = + _mm512_maskz_loadu_ps(row_mask, bias_ptr); + + // Process block in two halves, split by columns. 
+ for (int mmm = 0; mmm < 2; ++mmm) { + for (int j = 0; j < 8; ++j) { + accum_data_v[j] = initial_accum_data; + } + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr + 8 * mmm; + for (int d = 0; d < params.depth; ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float* rhs_data = rhs_ptr; + + for (int j = 0; j < 8; ++j) { + const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]); + accum_data_v[j] = + _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); + } + lhs_ptr += 16; + rhs_ptr += 16; + } + + const int residual_cols = std::min(end_col - col - 8 * mmm, 8); + + if (residual_rows == 16) { + if (residual_cols == 8) { + for (int j = 0; j < 8; ++j) { + float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; + accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); + _mm512_storeu_ps(block_ptr, accum_data_v[j]); + } + } else { + for (int j = 0; j < residual_cols; ++j) { + float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; + accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); + _mm512_storeu_ps(block_ptr, accum_data_v[j]); + } + } + } else { + for (int j = 0; j < residual_cols; ++j) { + float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; + accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); + accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); + _mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]); + } + } + } // Inner half-block loop. + } // End row-block loop. + } // Residual cols. +} + +void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) { + profiler::ScopeLabel label("Kernel kAvx512 float GEMV"); + + RUY_DCHECK_EQ(params.dst_cols, 1); + RUY_DCHECK_EQ(params.last_col, 0); + RUY_DCHECK_EQ(params.start_col, 0); + + // As parameters are defined, we need to scale by sizeof(float). + const std::int64_t lhs_stride = params.lhs_stride >> 2; + + int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; + const int end_row = std::min(params.dst_rows, params.last_row + 16); + + float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row; + const float* adj_lhs_col_ptr = + params.lhs_base_ptr - params.start_row * lhs_stride; + const float* bias_col_ptr = params.bias; + + const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max); + const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min); + + __m512 accum_data_v; + + const float* rhs_col_ptr = params.rhs_base_ptr; + float* dst_col_ptr = adj_dst_col_ptr; + + int row = params.start_row; + for (; row <= end_row - 16; row += 16) { + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + accum_data_v = _mm512_loadu_ps(bias_ptr); + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float rhs_data = *rhs_ptr; + + const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); + accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); + lhs_ptr += 16; + rhs_ptr += 16; + } + + accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); + accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); + _mm512_storeu_ps(dst_ptr, accum_data_v); + } // End row-block loop. 
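The residual-row handling that follows builds a 16-bit lane mask from the number of remaining rows and uses masked loads and stores, so no memory outside the destination is read or written. A minimal standalone sketch of that masking pattern (hypothetical names, assuming 1 <= residual_rows < 16):

#include <cstdint>
#include <immintrin.h>

// Sketch (not ruy code) of the masked residual-row pattern used below: lanes
// [0, residual_rows) are active; inactive lanes are zeroed on load and left
// untouched on store.
inline void MaskedCopy16(const float* src, float* dst, int residual_rows) {
  const __mmask16 row_mask =
      (static_cast<std::uint32_t>(1) << residual_rows) - 1;
  const __m512 v = _mm512_maskz_loadu_ps(row_mask, src);
  _mm512_mask_storeu_ps(dst, row_mask, v);
}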
+ + if (row < end_row) { + const int residual_rows = end_row - row; + RUY_CHECK_GE(residual_rows, 1); + RUY_CHECK_LT(residual_rows, 16); + + const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; + float* dst_ptr = dst_col_ptr + row; + const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; + + // Initialize with bias. + const __mmask16 row_mask = + (static_cast(1) << residual_rows) - 1; + accum_data_v = _mm512_maskz_loadu_ps(row_mask, bias_ptr); + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); + const float rhs_data = *rhs_ptr; + + const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); + accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); + lhs_ptr += 16; + rhs_ptr += 16; + } + + accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); + accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); + _mm512_mask_storeu_ps(dst_ptr, row_mask, accum_data_v); + } // End handling of residual rows. +} + +#endif // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy diff --git a/ruy/kernel_avxvnni.cc b/ruy/kernel_avxvnni.cc new file mode 100644 index 0000000..4513b20 --- /dev/null +++ b/ruy/kernel_avxvnni.cc @@ -0,0 +1,435 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/kernel.h" +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) +#include // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +#else // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +static constexpr int kAvxFloatBlockSize = 16; +static constexpr int kAvx8bitBlockSize = 16; +static constexpr int kAvx8bitInnerSize = 4; + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. +void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params) { + profiler::ScopeLabel label("Kernel kAvxVnni 8-bit (UNFINISHED)"); + + std::int32_t accum_data[kAvx8bitBlockSize][kAvx8bitBlockSize]; + + int bias_ptr_block_increment = + params.flags & RUY_ASM_FLAG_HAS_BIAS ? 
kAvx8bitBlockSize : 0; + + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + void* dst_col_ptr = params.dst_base_ptr; + const std::int32_t* bias_col_ptr = params.bias; + if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { + bias_col_ptr += params.start_row; + } + + for (int col = params.start_col; col <= params.last_col; + col += kAvx8bitBlockSize) { + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + void* dst_ptr = dst_col_ptr; + const std::int32_t* bias_ptr = bias_col_ptr; + + for (int row = params.start_row; row <= params.last_row; + row += kAvx8bitBlockSize) { + const int residual_rows = + std::min(params.dst_rows - row, kAvx8bitBlockSize); + const int residual_cols = + std::min(params.dst_cols - col, kAvx8bitBlockSize); + + // Initialize with bias. + std::int32_t initial_accum_data[kAvx8bitBlockSize]; + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + initial_accum_data[i] = 0; + } + for (int i = 0; i < residual_rows; ++i) { + initial_accum_data[i] = bias_ptr[i]; + } + + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] = initial_accum_data[i]; + } + } + bias_ptr += bias_ptr_block_increment; + + std::int8_t lhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; + std::int8_t rhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + for (int x = 0; x < kAvx8bitInnerSize; ++x) { + lhs_data[i][x] = lhs_ptr[i * kAvx8bitInnerSize + x]; + rhs_data[i][x] = rhs_ptr[i * kAvx8bitInnerSize + x]; + } + } + + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + for (int x = 0; x < kAvx8bitInnerSize; ++x) { + accum_data[j][i] += lhs_data[i][x] * rhs_data[j][x]; + } + } + } + + lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + } + + if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] -= + params.rhs_zero_point * params.lhs_sums[row + i]; + } + } + } + if ((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && params.lhs_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] -= + params.lhs_zero_point * params.rhs_sums[col + j]; + } + } + } + if (params.lhs_zero_point && params.rhs_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] += params.prod_zp_depth; + } + } + } + + if (params.dst_type_id != DstTypeId::kValue) { + std::int32_t m_vector[kAvx8bitBlockSize]; + std::int32_t e_vector[kAvx8bitBlockSize]; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. + if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { + int i = 0; + for (; i < residual_rows; ++i) { + m_vector[i] = params.multiplier_fixedpoint[row + i]; + e_vector[i] = params.multiplier_exponent[row + i]; + } + for (; i < kAvx8bitBlockSize; ++i) { + m_vector[i] = m_vector[0]; + e_vector[i] = e_vector[0]; + } + } else { + // These arrays have size LhsCols, and are pre-filled. 
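The lhs_sums / rhs_sums / prod_zp_depth adjustments applied a little earlier in this kernel come from expanding the zero-point-offset product: sum over d of (l_d - zl)(r_d - zr) equals sum(l_d * r_d) - zr * sum(l_d) - zl * sum(r_d) + zl * zr * depth. The scalar illustration below spells out that identity; the names are illustrative and not part of the ruy API.

#include <cstdint>

// Scalar illustration (not ruy code) of the zero-point correction identity.
std::int32_t CorrectedAccum(const std::int8_t* lhs, const std::int8_t* rhs,
                            int depth, std::int32_t lhs_zero_point,
                            std::int32_t rhs_zero_point) {
  std::int32_t raw_accum = 0;
  std::int32_t lhs_sum = 0;
  std::int32_t rhs_sum = 0;
  for (int d = 0; d < depth; ++d) {
    raw_accum += static_cast<std::int32_t>(lhs[d]) * rhs[d];
    lhs_sum += lhs[d];
    rhs_sum += rhs[d];
  }
  return raw_accum - rhs_zero_point * lhs_sum - lhs_zero_point * rhs_sum +
         lhs_zero_point * rhs_zero_point * depth;
}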
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) { + m_vector[i] = params.multiplier_fixedpoint[i]; + e_vector[i] = params.multiplier_exponent[i]; + } + } + + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] = MultiplyByQuantizedMultiplier( + accum_data[j][i], m_vector[i], e_vector[i]); + } + } + + if (params.dst_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] += params.dst_zero_point; + } + } + } + + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] = + std::min(accum_data[j][i], params.clamp_max); + accum_data[j][i] = + std::max(accum_data[j][i], params.clamp_min); + } + } + } + + const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && + (residual_cols == kAvx8bitBlockSize); + + if (params.dst_type_id == DstTypeId::kValue) { + std::int8_t* tmp_ptr = + store_full_block + ? static_cast(dst_ptr) + : const_cast( + reinterpret_cast(params.dst_tmp_buf)); + const int block_col_offset = + store_full_block ? params.dst_stride / sizeof(std::int8_t) + : kAvx8bitBlockSize; + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + + if (!store_full_block) { + const std::int8_t* block_ptr = + reinterpret_cast(params.dst_tmp_buf); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + static_cast( + dst_ptr)[j * params.dst_stride / sizeof(std::int8_t) + i] = + block_ptr[i]; + } + block_ptr += kAvx8bitBlockSize; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::uint8_t* tmp_ptr = store_full_block + ? static_cast(dst_ptr) + : const_cast( + reinterpret_cast( + params.dst_tmp_buf)); + const int block_col_offset = + store_full_block ? 
params.dst_stride : kAvx8bitBlockSize; + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + + if (!store_full_block) { + const std::uint8_t* block_ptr = + reinterpret_cast(params.dst_tmp_buf); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + static_cast( + dst_ptr)[j * params.dst_stride / sizeof(std::uint8_t) + i] = + block_ptr[i]; + } + block_ptr += kAvx8bitBlockSize; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + if (store_full_block) { + std::int16_t* tmp_ptr = static_cast(dst_ptr); + const int block_col_offset = params.dst_stride / sizeof(std::int16_t); + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + } else { + std::int16_t* tmp_ptr = const_cast( + reinterpret_cast(params.dst_tmp_buf)); + const int block_col_offset = kAvx8bitBlockSize; + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + const std::int16_t* block_ptr = + reinterpret_cast(params.dst_tmp_buf); + std::int16_t* dst_block_ptr = static_cast(dst_ptr); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + dst_block_ptr[i] = block_ptr[i]; + } + dst_block_ptr += params.dst_stride / sizeof(std::int16_t); + block_ptr += kAvx8bitBlockSize; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + if (store_full_block) { + std::int32_t* tmp_ptr = static_cast(dst_ptr); + const int block_col_offset = params.dst_stride / sizeof(std::int32_t); + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + } else { + std::int32_t* dst_block_ptr = static_cast(dst_ptr); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + dst_block_ptr[i] = accum_data[j][i]; + } + dst_block_ptr += params.dst_stride / sizeof(std::int32_t); + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; + } // End row-block loop. + + dst_col_ptr = static_cast(static_cast(dst_col_ptr) + + kAvx8bitBlockSize * params.dst_stride); + rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; + } // End col-block loop. +} // NOLINT(readability/fn_size) + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. +void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params) { + profiler::ScopeLabel label("Kernel kAvxVnni float (UNFINISHED)"); + + float lhs_data[kAvxFloatBlockSize]; + float rhs_data[kAvxFloatBlockSize]; + float accum_data[kAvxFloatBlockSize][kAvxFloatBlockSize]; + int bias_ptr_block_increment = + params.flags & RUY_ASM_FLAG_HAS_BIAS ? 
kAvxFloatBlockSize : 0; + + const float* rhs_col_ptr = params.rhs_base_ptr; + float* dst_col_ptr = params.dst_base_ptr; + const float* bias_col_ptr = params.bias; + if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { + bias_col_ptr += params.start_row; + } + + for (int col = params.start_col; col <= params.last_col; + col += kAvxFloatBlockSize) { + const float* lhs_col_ptr = params.lhs_base_ptr; + float* dst_ptr = dst_col_ptr; + const float* bias_ptr = bias_col_ptr; + + for (int row = params.start_row; row <= params.last_row; + row += kAvxFloatBlockSize) { + const int residual_rows = + std::min(params.dst_rows - row, kAvxFloatBlockSize); + const int residual_cols = + std::min(params.dst_cols - col, kAvxFloatBlockSize); + + // Initialize with bias. + float initial_accum_data[kAvxFloatBlockSize]; + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + initial_accum_data[i] = 0.0f; + } + for (int i = 0; i < residual_rows; ++i) { + initial_accum_data[i] = bias_ptr[i]; + } + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + accum_data[j][i] = initial_accum_data[i]; + } + } + bias_ptr += bias_ptr_block_increment; + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + lhs_data[i] = lhs_ptr[i]; + rhs_data[i] = rhs_ptr[i]; + } + + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + accum_data[j][i] += lhs_data[i] * rhs_data[j]; + } + } + + lhs_ptr += kAvxFloatBlockSize; + rhs_ptr += kAvxFloatBlockSize; + } + + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + accum_data[j][i] = + std::min(accum_data[j][i], params.clamp_max); + accum_data[j][i] = + std::max(accum_data[j][i], params.clamp_min); + } + } + + const bool store_full_block = (residual_rows == kAvxFloatBlockSize) && + (residual_cols == kAvxFloatBlockSize); + + { + float* block_ptr = + store_full_block ? dst_ptr : const_cast(params.dst_tmp_buf); + const int block_col_offset = store_full_block + ? params.dst_stride / sizeof(float) + : kAvxFloatBlockSize; + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + block_ptr[i] = accum_data[j][i]; + } + block_ptr += block_col_offset; + } + } + if (!store_full_block) { + const float* block_ptr = params.dst_tmp_buf; + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + dst_ptr[j * params.dst_stride / sizeof(float) + i] = block_ptr[i]; + } + block_ptr += kAvxFloatBlockSize; + } + } + + lhs_col_ptr += kAvxFloatBlockSize * params.lhs_stride / sizeof(float); + dst_ptr += kAvxFloatBlockSize; + } // End row-block loop. + + dst_col_ptr += kAvxFloatBlockSize * params.dst_stride / sizeof(float); + rhs_col_ptr += kAvxFloatBlockSize * params.rhs_stride / sizeof(float); + } // End col-block loop. +} + +#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy diff --git a/ruy/kernel_common.h b/ruy/kernel_common.h new file mode 100644 index 0000000..0cd123f --- /dev/null +++ b/ruy/kernel_common.h @@ -0,0 +1,481 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_ + +#include +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/common.h" +#include "ruy/internal_matrix.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/side_pair.h" +#include "ruy/size_util.h" +#include "ruy/spec.h" +#include "ruy/tune.h" + +namespace ruy { + +template +struct Kernel {}; + +template +void RunKernelTyped(Tuning tuning, const PackedMatrix& lhs, + const PackedMatrix& rhs, const Spec& spec, + int start_row, int start_col, int end_row, int end_col, + Matrix* dst) { + using Kernel = Kernel; + Kernel kernel(tuning); +#if !defined(NDEBUG) || !RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL) + using LhsLayout = typename Kernel::LhsLayout; + using RhsLayout = typename Kernel::RhsLayout; +#endif + // end_row and end_col may be larger than dst dimensions. + // that is because kernels write directly to the destination matrix, whose + // dimensions may not be a multiple of the kernel dimensions, and we try to + // keep this annoyance localized as an implementation detail in kernels, + // by allowing to pass rounded-up values down as far as possible. + // These assertions encode the contract. + RUY_DCHECK_LE(0, start_row); + RUY_DCHECK_LE(start_row, end_row); + RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols); + RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0); + RUY_DCHECK_LE(0, start_col); + RUY_DCHECK_LE(start_col, end_col); + RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols); + RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0); +#if RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL) + kernel.Run(lhs, rhs, spec, start_row, start_col, end_row, end_col, dst); +#else + for (int col = start_col; col < end_col; col += RhsLayout::kCols) { + int block_end_col = std::min(col + RhsLayout::kCols, end_col); + for (int row = start_row; row < end_row; row += LhsLayout::kCols) { + int block_end_row = std::min(row + LhsLayout::kCols, end_row); + kernel.Run(lhs, rhs, spec, row, col, block_end_row, block_end_col, dst); + } + } +#endif +} + +// Main entry point for kernels. +template +void RunKernel(Tuning tuning, const SidePair& src, void* spec, + const SidePair& start, const SidePair& end, + DMatrix* dst) { + Matrix mdst = ToMatrix(*dst); + RunKernelTyped( + tuning, ToPackedMatrix(src[Side::kLhs]), + ToPackedMatrix(src[Side::kRhs]), + *static_cast(spec), start[Side::kLhs], start[Side::kRhs], + end[Side::kLhs], end[Side::kRhs], &mdst); +} + +// Copied from gemmlowp/fixedpoint. +inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, + std::int32_t b) { + bool overflow = a == b && a == std::numeric_limits::min(); + std::int64_t a_64(a); + std::int64_t b_64(b); + std::int64_t ab_64 = a_64 * b_64; + std::int32_t nudge = ab_64 >= 0 ? 
(1 << 30) : (1 - (1 << 30)); + std::int32_t ab_x2_high32 = + static_cast((ab_64 + nudge) / (1ll << 31)); + return overflow ? std::numeric_limits::max() : ab_x2_high32; +} + +inline std::int32_t RoundingDivideByPOT(std::int32_t numerator, int exponent) { + std::int32_t sign = numerator >= 0 ? 1 : -1; + std::int32_t abs_numerator = std::abs(numerator); + std::int32_t mask = (1LL << exponent) - 1; + std::int32_t remainder = abs_numerator & mask; + std::int32_t threshold = mask >> 1; + std::int32_t abs_result = + (abs_numerator >> exponent) + (remainder > threshold ? 1 : 0); + return sign * abs_result; +} + +// Copied from TF Lite code. +inline std::int32_t MultiplyByQuantizedMultiplier( + std::int32_t x, std::int32_t quantized_multiplier, int shift) { + int left_shift = shift > 0 ? shift : 0; + int right_shift = shift > 0 ? 0 : -shift; + return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( + x * (1 << left_shift), quantized_multiplier), + right_shift); +} + +// Helper to apply a fixed-point multiplier. Only 'applicable' if AccumScalar +// is int32 (i.e. in all cases except floating-point) and if the destination is +// not int32 (i.e. unless the user wants to get raw accumulators). +template ::value && + !std::is_same::value> +struct ApplyMultiplierImpl {}; + +// Specialization in non-applicable case: do nothing, just check that values +// are default. +template +struct ApplyMultiplierImpl { + using AccumScalar = typename Spec::AccumScalar; + using DstScalar = typename Spec::DstScalar; + static void Run(const Spec& spec, int row, AccumScalar* accum) { + RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0); + RUY_DCHECK_EQ(spec.multiplier_exponent, 0); + } +}; + +template +struct ApplyMultiplierImpl { + using AccumScalar = typename Spec::AccumScalar; + using DstScalar = typename Spec::DstScalar; + static void Run(const Spec& spec, int row, AccumScalar* accum) { + AccumScalar m = spec.multiplier_fixedpoint_perchannel + ? spec.multiplier_fixedpoint_perchannel[row] + : spec.multiplier_fixedpoint; + int e = spec.multiplier_exponent_perchannel + ? spec.multiplier_exponent_perchannel[row] + : spec.multiplier_exponent; + *accum = MultiplyByQuantizedMultiplier(*accum, m, e); + } +}; + +template +void ApplyMultiplier(const Spec& spec, int row, + typename Spec::AccumScalar* accum) { + ApplyMultiplierImpl::Run(spec, row, accum); +} + +template +struct Kernel { + using AccumScalar = typename Spec::AccumScalar; + using LhsLayout = typename Spec::StandardCppKernelLhsLayout; + using RhsLayout = typename Spec::StandardCppKernelRhsLayout; + explicit Kernel(Tuning) {} + void Run(const PackedMatrix& lhs, + const PackedMatrix& rhs, const Spec& spec, int start_row, + int start_col, int end_row, int end_col, + Matrix* dst) const { + // See the comment in RunKernelTyped. end_row may be larger than + // dst->layout.rows. It's the responsibility of the kernel to avoid + // overrunning dst boundaries, which we do here by computing + // clamped_end_row. 
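To make the fixed-point helpers defined above concrete: a positive real multiplier is typically represented as a Q31 fixed-point value together with a power-of-two exponent, and MultiplyByQuantizedMultiplier applies both with round-to-nearest behaviour. The sketch below shows one plausible way to derive and use such a pair; QuantizeMultiplier is a hypothetical helper, not part of this header.

#include <cmath>
#include <cstdint>

// Hypothetical helper (not part of ruy): split a positive real multiplier
// into a Q31 fixed-point value plus a power-of-two exponent, as expected by
// MultiplyByQuantizedMultiplier above.
inline void QuantizeMultiplier(double real_multiplier,
                               std::int32_t* fixedpoint, int* exponent) {
  const double q = std::frexp(real_multiplier, exponent);  // q in [0.5, 1)
  std::int64_t q_fixed =
      static_cast<std::int64_t>(std::lround(q * (1ll << 31)));
  if (q_fixed == (1ll << 31)) {  // q rounded up to 1.0; renormalize.
    q_fixed /= 2;
    ++*exponent;
  }
  *fixedpoint = static_cast<std::int32_t>(q_fixed);
}

// Example usage (e.g. real_multiplier = lhs_scale * rhs_scale / dst_scale):
//   std::int32_t fp; int exp;
//   QuantizeMultiplier(0.0015, &fp, &exp);  // fp ~= 0.768 * 2^31, exp == -9
//   std::int32_t scaled = MultiplyByQuantizedMultiplier(accum, fp, exp);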
+ int clamped_end_row = std::min(end_row, dst->layout.rows); + int clamped_end_col = std::min(end_col, dst->layout.cols); + RUY_DCHECK_LE(0, start_row); + RUY_DCHECK_LE(start_row, clamped_end_row); + RUY_DCHECK_LE(clamped_end_row, dst->layout.rows); + RUY_DCHECK_LE(clamped_end_row, end_row); + RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols); + RUY_DCHECK_LE(0, start_col); + RUY_DCHECK_LE(start_col, clamped_end_col); + RUY_DCHECK_LE(clamped_end_col, dst->layout.cols); + RUY_DCHECK_LE(clamped_end_col, end_col); + RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols); + profiler::ScopeLabel label("Kernel (Standard Cpp)"); + const int depth = lhs.layout.rows; + for (int i = start_row; i < clamped_end_row; i++) { + for (int j = start_col; j < clamped_end_col; j++) { + using AccumScalar = typename Spec::AccumScalar; + AccumScalar accum = 0; + for (int k = 0; k < depth; k++) { + AccumScalar lhs_val = Element(lhs, k, i); + AccumScalar rhs_val = Element(rhs, k, j); + accum += lhs_val * rhs_val; + } + if (spec.bias) { + accum += spec.bias[i]; + } + if (lhs.zero_point) { + accum -= lhs.zero_point * rhs.sums[j]; + } + if (rhs.zero_point) { + accum -= rhs.zero_point * lhs.sums[i]; + } + if (lhs.zero_point && rhs.zero_point) { + accum += lhs.zero_point * rhs.zero_point * depth; + } + ApplyMultiplier(spec, i, &accum); + accum += dst->zero_point; + accum = std::min(accum, spec.clamp_max); + accum = std::max(accum, spec.clamp_min); + *ElementPtr(dst, i, j) = static_cast(accum); + } + } + } +}; + +#define RUY_INHERIT_KERNEL(PARENT, CHILD) \ + template \ + struct Kernel \ + : Kernel { \ + explicit Kernel(Tuning tuning) \ + : Kernel(tuning) {} \ + }; + +#if RUY_PLATFORM(NEON) +RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon) +RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod) +#elif RUY_PLATFORM(X86) +RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kSse42) +RUY_INHERIT_KERNEL(Path::kSse42, Path::kAvx2) +RUY_INHERIT_KERNEL(Path::kAvx2, Path::kAvx512) +RUY_INHERIT_KERNEL(Path::kAvx512, Path::kAvxVnni) +#endif + +// KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code. +// +// In other cases, we still define (empty) versions, so that dummy kernels +// can use the classes in function signatures. 
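The RUY_INHERIT_KERNEL macro above implements the fallback scheme used for these paths: until a faster path gets its own Kernel specialization, it simply inherits the previous path's. The self-contained illustration below uses hypothetical names rather than the actual macro arguments.

// Illustration (not ruy code) of the inheritance pattern behind
// RUY_INHERIT_KERNEL: a path without its own specialization reuses the
// previous path's kernel.
enum class IsaPath { kStandardCpp, kAvx2, kAvx512 };

template <IsaPath ThePath>
struct KernelImpl;

template <>
struct KernelImpl<IsaPath::kStandardCpp> {
  void Run() const { /* portable reference implementation */ }
};

// kAvx2 "inherits" the standard C++ kernel until a tuned one exists.
template <>
struct KernelImpl<IsaPath::kAvx2> : KernelImpl<IsaPath::kStandardCpp> {};

// Likewise, kAvx512 falls back to whatever kAvx2 provides; replacing this
// specialization later upgrades the path without touching call sites.
template <>
struct KernelImpl<IsaPath::kAvx512> : KernelImpl<IsaPath::kAvx2> {};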
+#if ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \ + RUY_OPT_ENABLED(RUY_OPT_ASM)) || \ + RUY_PLATFORM(X86) + +#define RUY_ASM_FLAG_HAS_BIAS 0x1 +#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2 +#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4 +#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8 +#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10 + +#define RUY_ASM_TYPE_ID_UINT8 1 +#define RUY_ASM_TYPE_ID_INT8 2 +#define RUY_ASM_TYPE_ID_INT16 3 +#define RUY_ASM_TYPE_ID_INT32 4 + +template +struct DstTypeId {}; + +template <> +struct DstTypeId { + static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8; +}; + +template <> +struct DstTypeId { + static constexpr int kValue = RUY_ASM_TYPE_ID_INT8; +}; + +template <> +struct DstTypeId { + static constexpr int kValue = RUY_ASM_TYPE_ID_INT16; +}; + +template <> +struct DstTypeId { + static constexpr int kValue = RUY_ASM_TYPE_ID_INT32; +}; + +template +struct KernelParams8bit { + static constexpr int kMaxDstTypeSize = 4; + + const std::int32_t* bias; + const std::int32_t* lhs_sums; + const std::int32_t* rhs_sums; + const std::int8_t* lhs_base_ptr; + const std::int32_t* multiplier_fixedpoint; + const std::int32_t* multiplier_exponent; + const std::int8_t* rhs_base_ptr; + void* dst_base_ptr; + std::int32_t lhs_zero_point; + std::int32_t rhs_zero_point; + std::int32_t dst_zero_point; + std::int32_t prod_zp_depth; + std::int32_t start_row; + std::int32_t start_col; + std::int32_t last_row; + std::int32_t last_col; + std::int32_t dst_rows; + std::int32_t dst_cols; + std::int32_t lhs_stride; + std::int32_t rhs_stride; + std::int32_t dst_stride; + std::int32_t depth; + std::int32_t clamp_min; + std::int32_t clamp_max; + std::uint8_t flags; + std::uint8_t dst_type_id; + const std::int32_t zero_data[LhsCols] = {0}; + std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize]; + std::int32_t multiplier_fixedpoint_buf[LhsCols]; + std::int32_t multiplier_exponent_buf[LhsCols]; +}; + +template +void MakeKernelParams8bit(const PackedMatrix& lhs, + const PackedMatrix& rhs, + const BasicSpec& spec, + int start_row, int start_col, int end_row, + int end_col, Matrix* dst, + KernelParams8bit* params) { + using Params = KernelParams8bit; + + static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, ""); + + const int depth = lhs.layout.rows; + RUY_DCHECK_EQ(start_row % LhsCols, 0); + RUY_DCHECK_EQ(start_col % RhsCols, 0); + RUY_DCHECK_EQ(end_row % LhsCols, 0); + RUY_DCHECK_EQ(end_col % RhsCols, 0); + + params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; + params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; + params->flags = 0; + params->bias = params->zero_data; + if (spec.bias) { + params->bias = spec.bias; + params->flags |= RUY_ASM_FLAG_HAS_BIAS; + } + if (lhs.sums) { + params->lhs_sums = lhs.sums; + params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS; + } + if (rhs.sums) { + params->rhs_sums = rhs.sums; + params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS; + } + params->start_row = start_row; + params->start_col = start_col; + params->last_row = end_row - LhsCols; + params->last_col = end_col - RhsCols; + params->lhs_stride = lhs.layout.stride; + params->rhs_stride = rhs.layout.stride; + params->dst_stride = sizeof(DstScalar) * dst->layout.stride; + params->lhs_zero_point = lhs.zero_point; + params->rhs_zero_point = rhs.zero_point; + params->dst_zero_point = dst->zero_point; + params->depth = depth; + params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth; + if (spec.multiplier_fixedpoint_perchannel) { + params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT; + params->flags |= 
RUY_ASM_FLAG_HAS_PERCHANNEL; + params->multiplier_fixedpoint = spec.multiplier_fixedpoint_perchannel; + params->multiplier_exponent = spec.multiplier_exponent_perchannel; + } else { + if (spec.multiplier_exponent > 0) { + params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT; + } + params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf; + params->multiplier_exponent = params->multiplier_exponent_buf; + for (int i = 0; i < LhsCols; i++) { + params->multiplier_fixedpoint_buf[i] = spec.multiplier_fixedpoint; + params->multiplier_exponent_buf[i] = spec.multiplier_exponent; + } + } + params->clamp_min = spec.clamp_min; + params->clamp_max = spec.clamp_max; + params->dst_rows = dst->layout.rows; + params->dst_cols = dst->layout.cols; + + RUY_DCHECK_LT(params->last_row, params->dst_rows); + RUY_DCHECK_LT(params->last_col, params->dst_cols); + + params->dst_type_id = DstTypeId::kValue; + params->dst_base_ptr = + dst->data.get() + start_col * dst->layout.stride + start_row; +} + +template +struct KernelParamsFloat { + const float* lhs_base_ptr; + const float* rhs_base_ptr; + float* dst_base_ptr; + const float* bias; + std::int32_t start_row; + std::int32_t start_col; + std::int32_t last_row; + std::int32_t last_col; + std::int32_t dst_rows; + std::int32_t dst_cols; + std::int32_t lhs_stride; + std::int32_t rhs_stride; + std::int32_t dst_stride; + std::int32_t depth; + float clamp_min; + float clamp_max; + std::uint8_t flags; + const float zero_data[LhsCols] = {0}; + float dst_tmp_buf[LhsCols * RhsCols]; +}; + +template +inline void MakeKernelParamsFloat(const PackedMatrix& lhs, + const PackedMatrix& rhs, + const BasicSpec& spec, + int start_row, int start_col, int end_row, + int end_col, Matrix* dst, + KernelParamsFloat* params) { + const int depth = lhs.layout.rows; + RUY_DCHECK_EQ(start_row % LhsCols, 0); + RUY_DCHECK_EQ(start_col % RhsCols, 0); + RUY_DCHECK_EQ(end_row % LhsCols, 0); + RUY_DCHECK_EQ(end_col % RhsCols, 0); + + params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; + params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; + params->dst_base_ptr = + dst->data.get() + start_col * dst->layout.stride + start_row; + + std::uint8_t flags = 0; + params->bias = params->zero_data; + if (spec.bias) { + params->bias = spec.bias; + flags |= RUY_ASM_FLAG_HAS_BIAS; + } + params->flags = flags; + params->start_row = start_row; + params->start_col = start_col; + params->last_row = end_row - LhsCols; + params->last_col = end_col - RhsCols; + params->lhs_stride = sizeof(float) * lhs.layout.stride; + params->rhs_stride = sizeof(float) * rhs.layout.stride; + params->dst_stride = sizeof(float) * dst->layout.stride; + params->depth = depth; + params->clamp_min = spec.clamp_min; + params->clamp_max = spec.clamp_max; + params->dst_rows = dst->layout.rows; + params->dst_cols = dst->layout.cols; + + RUY_DCHECK_LT(params->last_row, params->dst_rows); + RUY_DCHECK_LT(params->last_col, params->dst_cols); +} + +#else // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && + // RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86) + +template +struct KernelParams8bit {}; + +template +struct KernelParamsFloat {}; + +#endif // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && + // RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86) + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_ diff --git a/ruy/kernel_sse42.cc b/ruy/kernel_sse42.cc new file mode 100644 index 0000000..747ca1c --- /dev/null +++ b/ruy/kernel_sse42.cc @@ -0,0 +1,428 @@ +/* Copyright 2019 
Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/kernel.h" +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) +#include // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +void Kernel8bitSse42(const KernelParams8bit<8, 8>& params) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void KernelFloatSse42(const KernelParamsFloat<8, 8>& params) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +static constexpr int kAvxFloatBlockSize = 8; +static constexpr int kAvx8bitBlockSize = 8; +static constexpr int kAvx8bitInnerSize = 4; + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. +void Kernel8bitSse42(const KernelParams8bit<8, 8>& params) { + profiler::ScopeLabel label("Kernel kSse42 8-bit (UNFINISHED)"); + std::int32_t accum_data[kAvx8bitBlockSize][kAvx8bitBlockSize]; + + int bias_ptr_block_increment = + params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; + + const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + void* dst_col_ptr = params.dst_base_ptr; + const std::int32_t* bias_col_ptr = params.bias; + if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { + bias_col_ptr += params.start_row; + } + + for (int col = params.start_col; col <= params.last_col; + col += kAvx8bitBlockSize) { + const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; + void* dst_ptr = dst_col_ptr; + const std::int32_t* bias_ptr = bias_col_ptr; + + for (int row = params.start_row; row <= params.last_row; + row += kAvx8bitBlockSize) { + const int residual_rows = + std::min(params.dst_rows - row, kAvx8bitBlockSize); + const int residual_cols = + std::min(params.dst_cols - col, kAvx8bitBlockSize); + + // Initialize with bias. 
+ std::int32_t initial_accum_data[kAvx8bitBlockSize]; + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + initial_accum_data[i] = 0; + } + for (int i = 0; i < residual_rows; ++i) { + initial_accum_data[i] = bias_ptr[i]; + } + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] = initial_accum_data[i]; + } + } + bias_ptr += bias_ptr_block_increment; + + std::int8_t lhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; + std::int8_t rhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize]; + const std::int8_t* lhs_ptr = lhs_col_ptr; + const std::int8_t* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + for (int x = 0; x < kAvx8bitInnerSize; ++x) { + lhs_data[i][x] = lhs_ptr[i * kAvx8bitInnerSize + x]; + rhs_data[i][x] = rhs_ptr[i * kAvx8bitInnerSize + x]; + } + } + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + for (int x = 0; x < kAvx8bitInnerSize; ++x) { + accum_data[j][i] += lhs_data[i][x] * rhs_data[j][x]; + } + } + } + lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + } + + if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] -= + params.rhs_zero_point * params.lhs_sums[row + i]; + } + } + } + if ((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && params.lhs_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] -= + params.lhs_zero_point * params.rhs_sums[col + j]; + } + } + } + if (params.lhs_zero_point && params.rhs_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] += params.prod_zp_depth; + } + } + } + + if (params.dst_type_id != DstTypeId::kValue) { + std::int32_t m_vector[kAvx8bitBlockSize]; + std::int32_t e_vector[kAvx8bitBlockSize]; + // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. + if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { + int i = 0; + for (; i < residual_rows; ++i) { + m_vector[i] = params.multiplier_fixedpoint[row + i]; + e_vector[i] = params.multiplier_exponent[row + i]; + } + for (; i < kAvx8bitBlockSize; ++i) { + m_vector[i] = m_vector[0]; + e_vector[i] = e_vector[0]; + } + } else { + // These arrays have size LhsCols, and are pre-filled. + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + m_vector[i] = params.multiplier_fixedpoint[i]; + e_vector[i] = params.multiplier_exponent[i]; + } + } + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] = MultiplyByQuantizedMultiplier( + accum_data[j][i], m_vector[i], e_vector[i]); + } + } + + if (params.dst_zero_point) { + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] += params.dst_zero_point; + } + } + } + + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + accum_data[j][i] = + std::min(accum_data[j][i], params.clamp_max); + accum_data[j][i] = + std::max(accum_data[j][i], params.clamp_min); + } + } + } + + const bool store_full_block = (residual_rows == kAvx8bitBlockSize) && + (residual_cols == kAvx8bitBlockSize); + + if (params.dst_type_id == DstTypeId::kValue) { + std::int8_t* tmp_ptr = + store_full_block + ? 
static_cast(dst_ptr) + : const_cast( + reinterpret_cast(params.dst_tmp_buf)); + const int block_col_offset = + store_full_block ? params.dst_stride / sizeof(std::int8_t) + : kAvx8bitBlockSize; + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + + if (!store_full_block) { + const std::int8_t* block_ptr = + reinterpret_cast(params.dst_tmp_buf); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + static_cast( + dst_ptr)[j * params.dst_stride / sizeof(std::int8_t) + i] = + block_ptr[i]; + } + block_ptr += kAvx8bitBlockSize; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + std::uint8_t* tmp_ptr = store_full_block + ? static_cast(dst_ptr) + : const_cast( + reinterpret_cast( + params.dst_tmp_buf)); + const int block_col_offset = + store_full_block ? params.dst_stride : kAvx8bitBlockSize; + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + + if (!store_full_block) { + const std::uint8_t* block_ptr = + reinterpret_cast(params.dst_tmp_buf); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + static_cast( + dst_ptr)[j * params.dst_stride / sizeof(std::uint8_t) + i] = + block_ptr[i]; + } + block_ptr += kAvx8bitBlockSize; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + if (store_full_block) { + std::int16_t* tmp_ptr = static_cast(dst_ptr); + const int block_col_offset = params.dst_stride / sizeof(std::int16_t); + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + } else { + std::int16_t* tmp_ptr = const_cast( + reinterpret_cast(params.dst_tmp_buf)); + const int block_col_offset = kAvx8bitBlockSize; + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + const std::int16_t* block_ptr = + reinterpret_cast(params.dst_tmp_buf); + std::int16_t* dst_block_ptr = static_cast(dst_ptr); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + dst_block_ptr[i] = block_ptr[i]; + } + dst_block_ptr += params.dst_stride / sizeof(std::int16_t); + block_ptr += kAvx8bitBlockSize; + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else if (params.dst_type_id == DstTypeId::kValue) { + if (store_full_block) { + std::int32_t* tmp_ptr = static_cast(dst_ptr); + const int block_col_offset = params.dst_stride / sizeof(std::int32_t); + for (int j = 0; j < kAvx8bitBlockSize; ++j) { + for (int i = 0; i < kAvx8bitBlockSize; ++i) { + tmp_ptr[i] = accum_data[j][i]; + } + tmp_ptr += block_col_offset; + } + } else { + std::int32_t* dst_block_ptr = static_cast(dst_ptr); + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + dst_block_ptr[i] = accum_data[j][i]; + } + dst_block_ptr += params.dst_stride / sizeof(std::int32_t); + } + } + dst_ptr = static_cast(static_cast(dst_ptr) + + kAvx8bitBlockSize); + } else { + RUY_DCHECK(false); + } + + lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride; + } // End row-block loop. 
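+    // Note on the quantization math in the row-block loop above: the
+    // RUY_ASM_FLAG_HAS_LHS_SUMS / RUY_ASM_FLAG_HAS_RHS_SUMS / prod_zp_depth
+    // adjustments implement the expansion
+    //   sum_k (lhs_k - lhs_zp) * (rhs_k - rhs_zp)
+    //     = sum_k lhs_k * rhs_k
+    //       - rhs_zp * sum_k lhs_k - lhs_zp * sum_k rhs_k
+    //       + depth * lhs_zp * rhs_zp,
+    // so the kernel only needs the raw int8 dot product plus precomputed
+    // per-row/per-column sums and the constant depth * lhs_zp * rhs_zp term.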
+ + dst_col_ptr = static_cast(static_cast(dst_col_ptr) + + kAvx8bitBlockSize * params.dst_stride); + rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; + } // End col-block loop. +} // NOLINT(readability/fn_size) + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. +void KernelFloatSse42(const KernelParamsFloat<8, 8>& params) { + profiler::ScopeLabel label("Kernel kSse42 float (UNFINISHED)"); + + float lhs_data[kAvxFloatBlockSize]; + float rhs_data[kAvxFloatBlockSize]; + float accum_data[kAvxFloatBlockSize][kAvxFloatBlockSize]; + int bias_ptr_block_increment = + params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvxFloatBlockSize : 0; + + const float* rhs_col_ptr = params.rhs_base_ptr; + float* dst_col_ptr = params.dst_base_ptr; + const float* bias_col_ptr = params.bias; + if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { + bias_col_ptr += params.start_row; + } + + for (int col = params.start_col; col <= params.last_col; + col += kAvxFloatBlockSize) { + const float* lhs_col_ptr = params.lhs_base_ptr; + float* dst_ptr = dst_col_ptr; + const float* bias_ptr = bias_col_ptr; + + for (int row = params.start_row; row <= params.last_row; + row += kAvxFloatBlockSize) { + const int residual_rows = + std::min(params.dst_rows - row, kAvxFloatBlockSize); + const int residual_cols = + std::min(params.dst_cols - col, kAvxFloatBlockSize); + + // Initialize with bias. + float initial_accum_data[kAvxFloatBlockSize]; + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + initial_accum_data[i] = 0.0f; + } + for (int i = 0; i < residual_rows; ++i) { + initial_accum_data[i] = bias_ptr[i]; + } + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + accum_data[j][i] = initial_accum_data[i]; + } + } + bias_ptr += bias_ptr_block_increment; + + const float* lhs_ptr = lhs_col_ptr; + const float* rhs_ptr = rhs_col_ptr; + for (int d = 0; d < params.depth; ++d) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + lhs_data[i] = lhs_ptr[i]; + rhs_data[i] = rhs_ptr[i]; + } + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + accum_data[j][i] += lhs_data[i] * rhs_data[j]; + } + } + lhs_ptr += kAvxFloatBlockSize; + rhs_ptr += kAvxFloatBlockSize; + } + + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + accum_data[j][i] = + std::min(accum_data[j][i], params.clamp_max); + accum_data[j][i] = + std::max(accum_data[j][i], params.clamp_min); + } + } + + const bool store_full_block = (residual_rows == kAvxFloatBlockSize) && + (residual_cols == kAvxFloatBlockSize); + + { + float* block_ptr = + store_full_block ? dst_ptr : const_cast(params.dst_tmp_buf); + const int block_col_offset = store_full_block + ? 
params.dst_stride / sizeof(float) + : kAvxFloatBlockSize; + for (int j = 0; j < kAvxFloatBlockSize; ++j) { + for (int i = 0; i < kAvxFloatBlockSize; ++i) { + block_ptr[i] = accum_data[j][i]; + } + block_ptr += block_col_offset; + } + } + if (!store_full_block) { + const float* block_ptr = params.dst_tmp_buf; + for (int j = 0; j < residual_cols; ++j) { + for (int i = 0; i < residual_rows; ++i) { + dst_ptr[j * params.dst_stride / sizeof(float) + i] = block_ptr[i]; + } + block_ptr += kAvxFloatBlockSize; + } + } + + lhs_col_ptr += kAvxFloatBlockSize * params.lhs_stride / sizeof(float); + dst_ptr += kAvxFloatBlockSize; + } // End row-block loop. + + dst_col_ptr += kAvxFloatBlockSize * params.dst_stride / sizeof(float); + rhs_col_ptr += kAvxFloatBlockSize * params.rhs_stride / sizeof(float); + } // End col-block loop. +} + +#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h new file mode 100644 index 0000000..dbcf42b --- /dev/null +++ b/ruy/kernel_x86.h @@ -0,0 +1,222 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_ + +#include + +#include "ruy/common.h" +#include "ruy/internal_matrix.h" +#include "ruy/kernel_common.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/spec.h" +#include "ruy/tune.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. 
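The header below dispatches to a specific micro-kernel by partially specializing `Kernel` on a `Path` tag, so the choice of implementation is resolved at compile time. The following self-contained sketch illustrates only that C++ pattern; the names (`ToyPath`, `ToyKernel`) are illustrative and not part of ruy.

#include <cstdio>

// Toy path tags and a toy kernel template, showing the
// specialization-on-a-path-tag pattern used by ruy's Kernel below.
enum class ToyPath { kStandardCpp, kSse42 };

template <ToyPath tPath, typename Scalar>
struct ToyKernel {
  static void Run() { std::puts("generic fallback kernel"); }
};

// Partial specialization selected when the SSE4.2 path tag is used.
template <typename Scalar>
struct ToyKernel<ToyPath::kSse42, Scalar> {
  static void Run() { std::puts("SSE4.2-specific kernel"); }
};

int main() {
  ToyKernel<ToyPath::kStandardCpp, float>::Run();  // prints the fallback line
  ToyKernel<ToyPath::kSse42, float>::Run();        // prints the SSE4.2 line
  return 0;
}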
+//
+void Kernel8bitSse42(const KernelParams8bit<8, 8>& params);
+
+template <typename DstScalar>
+struct Kernel<Path::kSse42, std::int8_t, std::int8_t, DstScalar,
+              BasicSpec<std::int32_t, DstScalar>> {
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PackedMatrix<std::int8_t>& lhs,
+           const PackedMatrix<std::int8_t>& rhs,
+           const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
+           int start_col, int end_row, int end_col,
+           Matrix<DstScalar>* dst) const {
+    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row,
+                         end_col, dst, &params);
+    Kernel8bitSse42(params);
+  }
+};
+
+void KernelFloatSse42(const KernelParamsFloat<8, 8>& params);
+
+template <>
+struct Kernel<Path::kSse42, float, float, float, BasicSpec<float, float>> {
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+           const BasicSpec<float, float>& spec, int start_row, int start_col,
+           int end_row, int end_col, Matrix<float>* dst) const {
+    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
+                          end_col, dst, &params);
+    KernelFloatSse42(params);
+  }
+};
+
+void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
+void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params);
+
+template <typename DstScalar>
+struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, DstScalar,
+              BasicSpec<std::int32_t, DstScalar>> {
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PackedMatrix<std::int8_t>& lhs,
+           const PackedMatrix<std::int8_t>& rhs,
+           const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
+           int start_col, int end_row, int end_col,
+           Matrix<DstScalar>* dst) const {
+    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row,
+                         end_col, dst, &params);
+    if (dst->layout.cols == 1) {
+      Kernel8bitAvx512SingleCol(params);
+    } else {
+      Kernel8bitAvx512(params);
+    }
+  }
+};
+
+void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params);
+void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& param);
+
+template <>
+struct Kernel<Path::kAvx512, float, float, float, BasicSpec<float, float>> {
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+           const BasicSpec<float, float>& spec, int start_row, int start_col,
+           int end_row, int end_col, Matrix<float>* dst) const {
+    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
+                          end_col, dst, &params);
+    if (dst->layout.cols == 1) {
+      KernelFloatAvx512SingleCol(params);
+    } else {
+      KernelFloatAvx512(params);
+    }
+  }
+};
+
+void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params);
+void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params);
+
+template <typename DstScalar>
+struct Kernel<Path::kAvx2, std::int8_t, std::int8_t, DstScalar,
+              BasicSpec<std::int32_t, DstScalar>> {
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PackedMatrix<std::int8_t>& lhs,
+           const PackedMatrix<std::int8_t>& rhs,
+           const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
+           int start_col, int end_row, int end_col,
+           Matrix<DstScalar>* dst) const {
+    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row,
+                         end_col, dst, &params);
+    if (dst->layout.cols == 1) {
+      Kernel8bitAvx2SingleCol(params);
+    } else {
+      Kernel8bitAvx2(params);
+    }
+  }
+};
+
+void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params);
+void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params);
+
+template <>
+struct Kernel<Path::kAvx2, float, float, float, BasicSpec<float, float>> {
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+           const BasicSpec<float, float>& spec, int start_row, int start_col,
+           int end_row, int end_col, Matrix<float>* dst) const {
+    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
+                          end_col, dst, &params);
+    if (dst->layout.cols == 1) {
+      KernelFloatAvx2SingleCol(params);
+    } else {
+      KernelFloatAvx2(params);
+    }
+  }
+};
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
+// Optimization is not finished. In particular the dimensions of the kernel
+// blocks can be changed as desired.
+//
+void Kernel8bitAvxVnni(const KernelParams8bit<16, 16>& params);
+
+template <typename DstScalar>
+struct Kernel<Path::kAvxVnni, std::int8_t, std::int8_t, DstScalar,
+              BasicSpec<std::int32_t, DstScalar>> {
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+  using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PackedMatrix<std::int8_t>& lhs,
+           const PackedMatrix<std::int8_t>& rhs,
+           const BasicSpec<std::int32_t, DstScalar>& spec, int start_row,
+           int start_col, int end_row, int end_col,
+           Matrix<DstScalar>* dst) const {
+    KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParams8bit(lhs, rhs, spec, start_row, start_col, end_row,
+                         end_col, dst, &params);
+    Kernel8bitAvxVnni(params);
+  }
+};
+
+void KernelFloatAvxVnni(const KernelParamsFloat<16, 16>& params);
+
+template <>
+struct Kernel<Path::kAvxVnni, float, float, float, BasicSpec<float, float>> {
+  Tuning tuning = Tuning::kAuto;
+  using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+  using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
+  explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+  void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+           const BasicSpec<float, float>& spec, int start_row, int start_col,
+           int end_row, int end_col, Matrix<float>* dst) const {
+    KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
+    MakeKernelParamsFloat(lhs, rhs, spec, start_row, start_col, end_row,
+                          end_col, dst, &params);
+    KernelFloatAvxVnni(params);
+  }
+};
+
+#endif  // RUY_PLATFORM(X86)
+
+}  // namespace ruy
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_X86_H_
diff --git a/ruy/matrix.h b/ruy/matrix.h
new file mode 100644
index 0000000..2dcb081
--- /dev/null
+++ b/ruy/matrix.h
@@ -0,0 +1,182 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_
+
+#include <cstddef>
+#include <cstdint>  // IWYU pragma: keep
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+
+namespace ruy {
+
+// Layout storage order. Here and elsewhere, 'col' is short for 'column'.
+// 'column-major' means that each column is contiguous in memory.
+enum class Order : std::uint8_t { kColMajor, kRowMajor };
+
+// Describes the shape and storage layout of a matrix.
+struct Layout final {
+  std::int32_t rows = 0;
+  std::int32_t cols = 0;
+  // Stride is the offset between two adjacent matrix elements
+  // in the non-contiguous direction.
+  std::int32_t stride = 0;
+  Order order = Order::kColMajor;
+};
+
+namespace detail {
+
+// Thin wrapper around a pointer that tracks its constness dynamically.
+//
+// This is our take on the C++ problem of enforcing constness of data
+// wrapped in a container class: it's not worth the hassle of trying to
+// make it fully work at compile-time.
+// Instead, we only enforce constness at runtime, and to make it
+// zero-overhead, we only enforce it in debug builds.
+template <typename T>
+class ConstCheckingPtr final {
+ public:
+  using element_type = T;
+
+  // Convenience methods. Most `set` calls go through these.
+  ConstCheckingPtr& operator=(T* ptr) {
+    set(ptr);
+    return *this;
+  }
+  ConstCheckingPtr& operator=(const T* ptr) {
+    set(ptr);
+    return *this;
+  }
+  ConstCheckingPtr& operator=(std::nullptr_t) {
+    set(static_cast<T*>(nullptr));
+    return *this;
+  }
+
+  // Core accessors. These encapsulate the main logic:
+  // - for `set`, the constness of the argument determines whether the
+  //   internal pointer should be tracked as const/mutable.
+  // - for `get`, the constness of `this` determines whether the call
+  //   counts as a const or mutable use of the internal pointer.
+  void set(T* ptr) {
+    ptr_ = ptr;
+    set_mutable(true);
+  }
+  void set(const T* ptr) {
+    ptr_ = ptr;
+    set_mutable(false);
+  }
+  T* get() /* NOT const */ {
+    assert_mutable();
+    return const_cast<T*>(ptr_);
+  }
+  const T* get() const { return ptr_; }
+
+ private:
+  static_assert(!std::is_const<T>::value, "");
+  const T* ptr_ = nullptr;
+#ifndef NDEBUG
+  bool is_mutable_ = true;
+  void set_mutable(bool val) { is_mutable_ = val; }
+  void assert_mutable() { RUY_DCHECK(is_mutable_); }
+#else
+  void set_mutable(bool) {}
+  void assert_mutable() {}
+#endif
+};
+
+}  // namespace detail
+
+// A Matrix is really what Eigen and gemmlowp would have called a 'matrix map':
+// it merely wraps existing data as a matrix. It doesn't own any buffer.
+// Scalar may be any floating-point or integral type. When integral, it may be
+// signed or unsigned.
+template <typename Scalar>
+struct Matrix final {
+  Matrix& operator=(const Matrix& other) {
+    data = other.data;
+    cacheable = other.cacheable;
+    layout = other.layout;
+    zero_point = other.zero_point;
+    return *this;
+  }
+
+  // The underlying buffer wrapped by this matrix.
+  detail::ConstCheckingPtr<Scalar> data;
+  // The shape and data layout of this matrix.
+  Layout layout;
+  // The zero_point, i.e. which Scalar value is to be interpreted as zero.
+  // When Scalar is floating-point, this must be 0.
+  Scalar zero_point = 0;
+  // Clients of Ruy must set this flag to enable any caching behavior. Doesn't
+  // impact numerical results, but caching can impact observable metrics like
+  // latency, memory usage, power, etc.
+  bool cacheable = false;
+};
+
+inline void MakeSimpleLayout(int rows, int cols, Order order, Layout* layout) {
+  layout->rows = rows;
+  layout->cols = cols;
+  layout->order = order;
+  layout->stride = order == Order::kColMajor ? rows : cols;
+}
+
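A minimal usage sketch built only from the declarations above (`Matrix`, `Order`, `MakeSimpleLayout`); the manual index arithmetic simply restates the column-major stride convention documented in `Layout`, and is not how ruy itself accesses elements.

#include <vector>

#include "ruy/matrix.h"

int main() {
  const int rows = 3, cols = 2;
  std::vector<float> buffer(rows * cols, 0.0f);

  ruy::Matrix<float> m;
  // For kColMajor, MakeSimpleLayout sets stride == rows, i.e. columns are
  // contiguous and element (r, c) lives at data[r + c * stride].
  ruy::MakeSimpleLayout(rows, cols, ruy::Order::kColMajor, &m.layout);
  m.data = buffer.data();  // non-const set: the wrapped data stays mutable

  const int r = 2, c = 1;
  buffer[r + c * m.layout.stride] = 1.0f;
  return 0;
}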
+// Opaque data structure representing a pre-packed matrix, as obtained from
+// Ruy's advanced API.
+struct PrepackedMatrix {
+  void* data = nullptr;
+  std::size_t data_size = 0;
+  void* sums = nullptr;
+  std::size_t sums_size = 0;
+};
+
+template <typename StreamType, typename Scalar>
+StreamType& operator<<(StreamType& stream, const Matrix<Scalar>& mat) {
+  for (int row = 0; row < mat.layout.rows; row++) {
+    for (int col = 0; col < mat.layout.cols; col++) {
+      stream << static_cast<double>(Element(mat, row, col)) << " ";
+    }
+    stream << "\n";
+  }
+  return stream;
+}
+
+// Compile-time version of KernelLayout, used to declare kernel layouts in a
+// way that can be consumed by compile-time logic.
+// See how partial specializations of Kernel use it to declare their layouts.
+// The only reason why this is currently part of the public API is to
+// allow testing various layouts for the Path::kStandardCpp kernel, as a
+// testing-only feature. See Spec::StandardCppKernelLhsLayout.
+template <Order tOrder, int tRows, int tCols>
+struct FixedKernelLayout {
+  static constexpr Order kOrder = tOrder;
+  static constexpr int kRows = tRows;
+  static constexpr int kCols = tCols;
+};
+
+#if (__cplusplus < 201703L)
+// A static constexpr data member is automatically inline and should not
+// require redeclaration without an initializer. This is actually deprecated
+// from C++17 onwards. Clang with -O0 without this can fail to link.
+template <Order tOrder, int tRows, int tCols>
+constexpr int FixedKernelLayout<tOrder, tRows, tCols>::kCols;
+template <Order tOrder, int tRows, int tCols>
+constexpr int FixedKernelLayout<tOrder, tRows, tCols>::kRows;
+#endif
+
+}  // namespace ruy
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_MATRIX_H_
diff --git a/ruy/opt_set.h b/ruy/opt_set.h
new file mode 100644
index 0000000..fef0107
--- /dev/null
+++ b/ruy/opt_set.h
@@ -0,0 +1,51 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_
+
+// RUY_OPT_SET is a compile-time API that Ruy provides for enabling/disabling
+// certain optimizations. It should be used by defining that macro on the
+// compiler command line.
+//
+// Each bit in RUY_OPT_SET controls a particular optimization done in Ruy.
+#define RUY_OPT_INTRINSICS 0x1
+#define RUY_OPT_ASM 0x2
+#define RUY_OPT_TUNING 0x4
+#define RUY_OPT_FAT_KERNEL 0x8
+#define RUY_OPT_NATIVE_ROUNDING 0x10
+#define RUY_OPT_AVOID_ALIASING 0x20
+#define RUY_OPT_MAX_STREAMING 0x40
+#define RUY_OPT_PACK_AHEAD 0x80
+#define RUY_OPT_PREFETCH_LOAD 0x100
+#define RUY_OPT_PREFETCH_STORE 0x200
+#define RUY_OPT_FRACTAL_Z 0x400
+#define RUY_OPT_FRACTAL_U 0x800
+#define RUY_OPT_FRACTAL_HILBERT 0x1000
+
+#if !defined(RUY_OPT_SET)
+#ifdef RUY_OPTIMIZE_FOR_MATMUL_BENCHMARK
+// Load prefetching is detrimental in matrix multiplication benchmarks.
+// Store prefetching is not.
+#define RUY_OPT_SET (~RUY_OPT_PREFETCH_LOAD)
+#else
+// Default to all optimizations.
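+// (For example, building with -DRUY_OPT_SET=0x2 on the compiler command line
+// keeps only RUY_OPT_ASM enabled, and -DRUY_OPT_SET=0 disables every
+// optimization listed above, since RUY_OPT_ENABLED simply tests the
+// corresponding bit.)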
+#define RUY_OPT_SET (~0) +#endif +#endif + +#define RUY_OPT_ENABLED(ruy_opt) ((RUY_OPT_SET & ruy_opt) != 0) + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_OPT_SET_H_ diff --git a/ruy/pack.h b/ruy/pack.h new file mode 100644 index 0000000..e066663 --- /dev/null +++ b/ruy/pack.h @@ -0,0 +1,98 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// # What is "packing"? +// +// Before feeding data to the gemm kernels (the parts of Ruy that do lots +// of multiply-add operations), Ruy first performs a data transformation (which +// we call "packing") on the input matrices. This transformation has two main +// goals: +// - rearrange data into blocks that are a convenient size/layout for the gemm +// kernels to consume. This helps make the memory access pattern of the gemm +// kernel simpler and more contiguous, and puts the data in a layout most +// convenient for specific arithmetic instructions in the gemm kernel. +// - compute row/column sums needed for handling quantization with non-symmetric +// zero points. +// +// # Simplified algorithmic analysis of packing +// +// Packing is a relatively simple transformation which does a small constant +// amount of work on each element of an input matrix, and hence for an NxM +// matrix performs O(N*M) work. If N and M are of the same order, then this is +// O(N^2) work. +// +// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. +// Note that if N, K, and M are all the same order, then the number of +// multiply-accumulate operations is O(N^3). +// +// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the +// case of all dimensions being roughly the same order. +// +// # Packing cost can be significant +// +// When matrix * matrix multiplications begin to look more like matrix * vector +// multiplications, packing cost can become significant. We sometimes call these +// cases "gemv-like". +// +// Continuing the algorithmic analysis above, if we consider a case where an +// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the +// situation is different. In this case, the multiply-accumulate work is only +// quadratic, so the quadratic cost of packing can be come significant. +// +// Another way to say this is that the cost of packing an input matrix (either +// the LHS or RHS) is amortized across the non-depth dimension of the opposite +// input matrix. Thus, when the LHS has very few rows or the RHS has very few +// columns, the cost of packing the opposite input matrix can become +// significant. +// +// As a rough rule of thumb, the cost of packing starts to become significant +// when either N or M is below 32 (and other dimensions are hundreds), with very +// significant packing costs at 8 or below. This varies by data type, Path, and +// tuning, so these numbers are only rough guides. 
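+//
+// As a worked illustration of the asymptotics above (round numbers, not
+// measurements): with N = K = 1000 and M = 8, packing the LHS touches
+// N*K = 1e6 elements while the kernels perform N*K*M = 8e6 multiply-adds,
+// so packing is within an order of magnitude of the arithmetic itself;
+// with M = 1000 instead, the kernels perform 1e9 multiply-adds and the same
+// 1e6-element packing cost becomes negligible.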
+// +// One practical use case that is affected by this is inference of +// fully connected neural network layers with a low batch size. The weight +// matrix (which is a constant for inference) is the one affected by significant +// packing cost. +// +// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack +// input matrices that are affected by significant packing costs. +// +// # Implementation notes +// +// Ruy's packing routines always operate on a range of columns and can be +// applied to either the LHS or RHS. This is possible because Ruy internally +// implements a TrMul, so the accumulation along depth is done along columns of +// both the LHS and RHS (whereas for a normal Mul the accumulation along depth +// for the LHS is along rows). As another example, we are always computing +// column sums for quantization (and never row sums, since the LHS is +// transposed). + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_ + +#include "ruy/platform.h" + +// IWYU pragma: begin_exports +#if RUY_PLATFORM(NEON) +#include "ruy/pack_arm.h" +#elif RUY_PLATFORM(X86) +#include "ruy/pack_x86.h" +#else +#include "ruy/pack_common.h" +#endif +// IWYU pragma: end_exports + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_H_ diff --git a/ruy/pack_arm.cc b/ruy/pack_arm.cc new file mode 100644 index 0000000..8b68a39 --- /dev/null +++ b/ruy/pack_arm.cc @@ -0,0 +1,1936 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include "ruy/common.h" +#include "ruy/opt_set.h" +#include "ruy/pack.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +namespace ruy { + +#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + std::int8_t* packed_ptr, int start_col, int end_col, + std::int32_t* sums_ptr, int input_xor) { + profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); + asm volatile( + // clang-format off + "dup v26.16b, %w[input_xor]\n" + "mov w1, #0\n" + "dup v28.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v31.4s, wzr\n" + + "and w2, %w[rows], #-16\n" + "cmp w1, w2\n" + "beq 3f\n" + + "add w1, w1, #16\n" + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "cmp w1, w2\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + "beq 2f\n" + + "1:\n" + + "add w1, w1, #16\n" + "eor v4.16b, v0.16b, v26.16b\n" + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "eor v5.16b, v1.16b, v26.16b\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "eor v6.16b, v2.16b, v26.16b\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "eor v7.16b, v3.16b, v26.16b\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + + "saddlp v16.8h, v4.16b\n" + "str q4, [%[packed_ptr], #0]\n" + "saddlp v17.8h, v5.16b\n" + "str q5, [%[packed_ptr], #16]\n" + "saddlp v18.8h, v6.16b\n" + "str q6, [%[packed_ptr], #32]\n" + "saddlp v19.8h, v7.16b\n" + "str q7, [%[packed_ptr], #48]\n" + "sadalp v28.4s, v16.8h\n" + "cmp w1, w2\n" + "sadalp v29.4s, v17.8h\n" + "add %[packed_ptr], %[packed_ptr], #64\n" + "sadalp v30.4s, v18.8h\n" + "sadalp v31.4s, v19.8h\n" + + "bne 1b\n" + + "2:\n" + + "eor v4.16b, v0.16b, v26.16b\n" + "eor v5.16b, v1.16b, v26.16b\n" + "eor v6.16b, v2.16b, v26.16b\n" + "eor v7.16b, v3.16b, v26.16b\n" + + "saddlp v16.8h, v4.16b\n" + "str q4, [%[packed_ptr], #0]\n" + "saddlp v17.8h, v5.16b\n" + "str q5, [%[packed_ptr], #16]\n" + "saddlp v18.8h, v6.16b\n" + "str q6, [%[packed_ptr], #32]\n" + "saddlp v19.8h, v7.16b\n" + "str q7, [%[packed_ptr], #48]\n" + "sadalp v28.4s, v16.8h\n" + "sadalp v29.4s, v17.8h\n" + "sadalp v30.4s, v18.8h\n" + "sadalp v31.4s, v19.8h\n" + + "add %[packed_ptr], %[packed_ptr], #64\n" + + "3:\n" + + "ands w2, %w[rows], #15\n" + "beq 4f\n" + "dup v0.16b, %w[src_zero_point]\n" + "dup v1.16b, %w[src_zero_point]\n" + "dup v2.16b, %w[src_zero_point]\n" + "dup v3.16b, %w[src_zero_point]\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ + "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ + "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ + "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) + RUY_LOAD_ONE_ROW(4) + RUY_LOAD_ONE_ROW(5) + RUY_LOAD_ONE_ROW(6) + RUY_LOAD_ONE_ROW(7) + RUY_LOAD_ONE_ROW(8) + RUY_LOAD_ONE_ROW(9) + RUY_LOAD_ONE_ROW(10) + RUY_LOAD_ONE_ROW(11) + RUY_LOAD_ONE_ROW(12) + RUY_LOAD_ONE_ROW(13) + RUY_LOAD_ONE_ROW(14) + RUY_LOAD_ONE_ROW(15) +#undef RUY_LOAD_ONE_ROW + "5:\n" + + "eor v4.16b, v0.16b, v26.16b\n" + "eor v5.16b, v1.16b, v26.16b\n" + "eor v6.16b, v2.16b, v26.16b\n" + "eor v7.16b, v3.16b, v26.16b\n" + + 
"saddlp v16.8h, v4.16b\n" + "saddlp v17.8h, v5.16b\n" + "saddlp v18.8h, v6.16b\n" + "saddlp v19.8h, v7.16b\n" + "sadalp v28.4s, v16.8h\n" + "sadalp v29.4s, v17.8h\n" + "sadalp v30.4s, v18.8h\n" + "sadalp v31.4s, v19.8h\n" + + "str q4, [%[packed_ptr], #0]\n" + "str q5, [%[packed_ptr], #16]\n" + "str q6, [%[packed_ptr], #32]\n" + "str q7, [%[packed_ptr], #48]\n" + "add %[packed_ptr], %[packed_ptr], #64\n" + + "4:\n" + + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + "cmp %[sums_ptr], #0\n" + "beq 6f\n" + "st1 {v28.4s}, [%[sums_ptr]], #16\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), + [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), + [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) + : [ src_inc0 ] "r"(static_cast(src_inc0)), + [ src_inc1 ] "r"(static_cast(src_inc1)), + [ src_inc2 ] "r"(static_cast(src_inc2)), + [ src_inc3 ] "r"(static_cast(src_inc3)), + [ rows ] "r"(src_rows), [ src_zero_point ] "r"(src_zero_point), + [ input_xor ] "r"(input_xor) + : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31"); +} +#endif + +#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +#define RUY_OFFSET_SRC_PTR0 0 +#define RUY_OFFSET_SRC_PTR1 4 +#define RUY_OFFSET_SRC_PTR2 8 +#define RUY_OFFSET_SRC_PTR3 12 +#define RUY_OFFSET_SUMS_PTR 16 +#define RUY_OFFSET_PACKED_PTR 20 +#define RUY_OFFSET_SRC_INC0 24 +#define RUY_OFFSET_SRC_INC1 28 +#define RUY_OFFSET_SRC_INC2 32 +#define RUY_OFFSET_SRC_INC3 36 +#define RUY_OFFSET_SRC_ROWS 40 +#define RUY_OFFSET_SRC_ZERO_POINT 44 +#define RUY_OFFSET_INPUT_XOR 48 + +template +void CheckOffsetsInPackParams8bit(const Params&) { + static_assert(offsetof(Params, src_ptr0) == RUY_OFFSET_SRC_PTR0, ""); + static_assert(offsetof(Params, src_ptr1) == RUY_OFFSET_SRC_PTR1, ""); + static_assert(offsetof(Params, src_ptr2) == RUY_OFFSET_SRC_PTR2, ""); + static_assert(offsetof(Params, src_ptr3) == RUY_OFFSET_SRC_PTR3, ""); + static_assert(offsetof(Params, sums_ptr) == RUY_OFFSET_SUMS_PTR, ""); + static_assert(offsetof(Params, packed_ptr) == RUY_OFFSET_PACKED_PTR, ""); + static_assert(offsetof(Params, src_inc0) == RUY_OFFSET_SRC_INC0, ""); + static_assert(offsetof(Params, src_inc1) == RUY_OFFSET_SRC_INC1, ""); + static_assert(offsetof(Params, src_inc2) == RUY_OFFSET_SRC_INC2, ""); + static_assert(offsetof(Params, src_inc3) == RUY_OFFSET_SRC_INC3, ""); + static_assert(offsetof(Params, src_rows) == RUY_OFFSET_SRC_ROWS, ""); + static_assert(offsetof(Params, src_zero_point) == RUY_OFFSET_SRC_ZERO_POINT, + ""); + static_assert(offsetof(Params, input_xor) == RUY_OFFSET_INPUT_XOR, ""); +} + +// Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9. +// No attempt made at making this code efficient on in-order cores yet. 
+void Pack8bitNeonOutOfOrder4Cols(const PackParams8bit& params) { + CheckOffsetsInPackParams8bit(params); + profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); + const void* src_ptr0 = params.src_ptr0; + const void* src_ptr1 = params.src_ptr1; + const void* src_ptr2 = params.src_ptr2; + const void* src_ptr3 = params.src_ptr3; + const int src_inc0 = params.src_inc0; + const int src_inc1 = params.src_inc1; + const int src_inc2 = params.src_inc2; + const int src_inc3 = params.src_inc3; + const std::int8_t* packed_ptr = params.packed_ptr; + + asm volatile( + // clang-format off + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n" + "vdup.8 q11, r2\n" + "mov r1, #0\n" + // Zero-out the accumulators + "vmov.i32 q12, #0\n" + "vmov.i32 q13, #0\n" + "vmov.i32 q14, #0\n" + "vmov.i32 q15, #0\n" + + // Round down src_rows to nearest multiple of 16. + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" + "and r2, r3, #-16\n" + "cmp r1, r2\n" + "beq 3f\n" + + "1:\n" + "add r1, r1, #16\n" + /* Load q0 */ + "vld1.8 {d0, d1}, [%[src_ptr0]]\n" + "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n" + RUY_PREFETCH_LOAD("pld [%[src_ptr0]]\n") + + /* Load q1 */ + "vld1.8 {d2, d3}, [%[src_ptr1]]\n" + "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n" + RUY_PREFETCH_LOAD("pld [%[src_ptr1]]\n") + + "veor.8 q4, q0, q11\n" + "veor.8 q5, q1, q11\n" + + // Pairwise add in to 16b accumulators. + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + // Pairwise add accumulate into 32b accumulators. + // q12 and q13 contain 4x32b accumulators + "vpadal.s16 q12, q8\n" + "vpadal.s16 q13, q9\n" + + // Now do the same for src_ptr2 and src_ptr3. + "vld1.8 {d0, d1}, [%[src_ptr2]]\n" + "add %[src_ptr2], %[src_ptr2], %[src_inc2]\n" + RUY_PREFETCH_LOAD("pld [%[src_ptr2]]\n") + + "vld1.8 {d2, d3}, [%[src_ptr3]]\n" + "add %[src_ptr3], %[src_ptr3], %[src_inc3]\n" + RUY_PREFETCH_LOAD("pld [%[src_ptr3]]\n") + + "veor.8 q4, q0, q11\n" + "veor.8 q5, q1, q11\n" + + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + // Pairwise add accumulate into 32b accumulators. + // q14 and q15 contain 4x32b accumulators + "vpadal.s16 q14, q8\n" + "vpadal.s16 q15, q9\n" + + "cmp r1, r2\n" + "bne 1b\n" + + "3:\n" + + // Now pack the last (num_rows % 16) rows. + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" + "ands r2, r3, #15\n" + "beq 4f\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n" + "vdup.8 q0, r3\n" + "vdup.8 q1, r3\n" + +// First, read/accumulate/write for src_ptr0 and src_ptr1. 
+#define RUY_LOAD_ONE_ROW1(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \ + "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \ + + RUY_LOAD_ONE_ROW1(0, 0) + RUY_LOAD_ONE_ROW1(1, 1) + RUY_LOAD_ONE_ROW1(2, 2) + RUY_LOAD_ONE_ROW1(3, 3) + RUY_LOAD_ONE_ROW1(4, 4) + RUY_LOAD_ONE_ROW1(5, 5) + RUY_LOAD_ONE_ROW1(6, 6) + RUY_LOAD_ONE_ROW1(7, 7) +#undef RUY_LOAD_ONE_ROW1 + +#define RUY_LOAD_ONE_ROW2(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \ + "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \ + + RUY_LOAD_ONE_ROW2(8, 0) + RUY_LOAD_ONE_ROW2(9, 1) + RUY_LOAD_ONE_ROW2(10, 2) + RUY_LOAD_ONE_ROW2(11, 3) + RUY_LOAD_ONE_ROW2(12, 4) + RUY_LOAD_ONE_ROW2(13, 5) + RUY_LOAD_ONE_ROW2(14, 6) + RUY_LOAD_ONE_ROW2(15, 7) +#undef RUY_LOAD_ONE_ROW2 + + "5:\n" + + "veor.16 q4, q0, q11\n" + "veor.16 q5, q1, q11\n" + + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + // Pairwise add accumulate to 4x32b accumulators. + "vpadal.s16 q12, q8\n" + "vpadal.s16 q13, q9\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + // Reset to src_zero for src_ptr2 and src_ptr3. + "vdup.8 q0, r3\n" + "vdup.8 q1, r3\n" + +// Next, read/accumulate/write for src_ptr2 and src_ptr3. +#define RUY_LOAD_ONE_ROW1(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d0[" #R "]}, [%[src_ptr2]]!\n" \ + "vld1.8 { d2[" #R "]}, [%[src_ptr3]]!\n" \ + + RUY_LOAD_ONE_ROW1(0, 0) + RUY_LOAD_ONE_ROW1(1, 1) + RUY_LOAD_ONE_ROW1(2, 2) + RUY_LOAD_ONE_ROW1(3, 3) + RUY_LOAD_ONE_ROW1(4, 4) + RUY_LOAD_ONE_ROW1(5, 5) + RUY_LOAD_ONE_ROW1(6, 6) + RUY_LOAD_ONE_ROW1(7, 7) +#undef RUY_LOAD_ONE_ROW1 + +#define RUY_LOAD_ONE_ROW2(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d1[" #R "]}, [%[src_ptr2]]!\n" \ + "vld1.8 { d3[" #R "]}, [%[src_ptr3]]!\n" \ + + RUY_LOAD_ONE_ROW2(8, 0) + RUY_LOAD_ONE_ROW2(9, 1) + RUY_LOAD_ONE_ROW2(10, 2) + RUY_LOAD_ONE_ROW2(11, 3) + RUY_LOAD_ONE_ROW2(12, 4) + RUY_LOAD_ONE_ROW2(13, 5) + RUY_LOAD_ONE_ROW2(14, 6) + RUY_LOAD_ONE_ROW2(15, 7) +#undef RUY_LOAD_ONE_ROW2 + + "5:\n" + + "veor.16 q4, q0, q11\n" + "veor.16 q5, q1, q11\n" + + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + // Pairwise add accumulate to 4x32b accumulators. + "vpadal.s16 q14, q8\n" + "vpadal.s16 q15, q9\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + "4:\n" + // Pairwise add 32-bit accumulators + "vpadd.i32 d24, d24, d25\n" + "vpadd.i32 d26, d26, d27\n" + "vpadd.i32 d28, d28, d29\n" + "vpadd.i32 d30, d30, d31\n" + // Final 32-bit values per row + "vpadd.i32 d25, d24, d26\n" + "vpadd.i32 d27, d28, d30\n" + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n" + "cmp r3, #0\n" + "beq 6f\n" + "vst1.32 {d25}, [r3]!\n" + "vst1.32 {d27}, [r3]!\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), + [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3) + : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1), + [ src_inc2 ] "r"(src_inc2), [ src_inc3 ] "r"(src_inc3), + [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(¶ms) + : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3", + "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13"); +} + +// Packing code for out-of-order ARMv7 CPUs like the Krait 400 or A9. +// No attempt made at making this code efficient on in-order cores yet. +// This version differs from the above in that we only handle two columns +// at a time. 
+void Pack8bitNeonOutOfOrder2Cols(const PackParams8bit& params) { + CheckOffsetsInPackParams8bit(params); + profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); + const void* src_ptr0 = params.src_ptr0; + const void* src_ptr1 = params.src_ptr1; + const int src_inc0 = params.src_inc0; + const int src_inc1 = params.src_inc1; + const std::int8_t* packed_ptr = params.packed_ptr; + + asm volatile( + // clang-format off + + "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_INPUT_XOR) "]\n" + "vdup.8 q11, r2\n" + "mov r1, #0\n" + // Zero-out the accumulators + "vmov.i32 q12, #0\n" + "vmov.i32 q13, #0\n" + + // Round down src_rows to nearest multiple of 16. + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" + "and r2, r3, #-16\n" + "cmp r1, r2\n" + "beq 3f\n" + + "1:\n" + "add r1, r1, #16\n" + /* Load q0 */ + "vld1.8 {d0, d1}, [%[src_ptr0]]\n" + "add %[src_ptr0], %[src_ptr0], %[src_inc0]\n" + + /* Load q1 */ + "vld1.8 {d2, d3}, [%[src_ptr1]]\n" + "add %[src_ptr1], %[src_ptr1], %[src_inc1]\n" + + "veor.8 q4, q0, q11\n" + "veor.8 q5, q1, q11\n" + + // Pairwise add in to 16b accumulators. + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + // Pairwise add accumulate into 32b accumulators. + // q12 and q13 contain 4x32b accumulators + "vpadal.s16 q12, q8\n" + "vpadal.s16 q13, q9\n" + + "cmp r1, r2\n" + + "bne 1b\n" + + "3:\n" + + // Now pack the last (num_rows % 16) rows. + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ROWS) "]\n" + "ands r2, r3, #15\n" + "beq 4f\n" + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SRC_ZERO_POINT) "]\n" + "vdup.8 q0, r3\n" + "vdup.8 q1, r3\n" + +// Read/accumulate/write for src_ptr0 and src_ptr1. +#define RUY_LOAD_ONE_ROW1(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d0[" #R "]}, [%[src_ptr0]]!\n" \ + "vld1.8 { d2[" #R "]}, [%[src_ptr1]]!\n" \ + + RUY_LOAD_ONE_ROW1(0, 0) + RUY_LOAD_ONE_ROW1(1, 1) + RUY_LOAD_ONE_ROW1(2, 2) + RUY_LOAD_ONE_ROW1(3, 3) + RUY_LOAD_ONE_ROW1(4, 4) + RUY_LOAD_ONE_ROW1(5, 5) + RUY_LOAD_ONE_ROW1(6, 6) + RUY_LOAD_ONE_ROW1(7, 7) +#undef RUY_LOAD_ONE_ROW1 + +#define RUY_LOAD_ONE_ROW2(I, R) \ + "cmp r2, #" #I "\n" \ + "beq 5f\n" \ + "vld1.8 { d1[" #R "]}, [%[src_ptr0]]!\n" \ + "vld1.8 { d3[" #R "]}, [%[src_ptr1]]!\n" \ + + RUY_LOAD_ONE_ROW2(8, 0) + RUY_LOAD_ONE_ROW2(9, 1) + RUY_LOAD_ONE_ROW2(10, 2) + RUY_LOAD_ONE_ROW2(11, 3) + RUY_LOAD_ONE_ROW2(12, 4) + RUY_LOAD_ONE_ROW2(13, 5) + RUY_LOAD_ONE_ROW2(14, 6) + RUY_LOAD_ONE_ROW2(15, 7) +#undef RUY_LOAD_ONE_ROW2 + + "5:\n" + + "veor.16 q4, q0, q11\n" + "veor.16 q5, q1, q11\n" + + "vpaddl.s8 q8, q4\n" + "vpaddl.s8 q9, q5\n" + + + // Pairwise add accumulate to 4x32b accumulators. 
+ "vpadal.s16 q12, q8\n" + "vpadal.s16 q13, q9\n" + + "vst1.32 {q4}, [%[packed_ptr]]!\n" + "vst1.32 {q5}, [%[packed_ptr]]!\n" + + "4:\n" + + // Pairwise add 32-bit accumulators + "vpadd.i32 d24, d24, d25\n" + "vpadd.i32 d26, d26, d27\n" + // Final 32-bit values per row + "vpadd.i32 d25, d24, d26\n" + + "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_SUMS_PTR) "]\n" + "cmp r3, #0\n" + "beq 6f\n" + "vst1.32 {d25}, [r3]!\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1) + : [ src_inc0 ] "r"(src_inc0), [ src_inc1 ] "r"(src_inc1), + [ packed_ptr ] "r"(packed_ptr), [ params ] "r"(¶ms) + : "cc", "memory", "r1", "r2", "r3", "q0", "q1", "q2", "q3", + "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13"); +} + +#undef RUY_OFFSET_SRC_PTR0 +#undef RUY_OFFSET_SRC_PTR1 +#undef RUY_OFFSET_SRC_PTR2 +#undef RUY_OFFSET_SRC_PTR32 +#undef RUY_OFFSET_SUMS_PTR +#undef RUY_OFFSET_PACKED_PTR0 +#undef RUY_OFFSET_SRC_INC0 +#undef RUY_OFFSET_SRC_INC1 +#undef RUY_OFFSET_SRC_INC2 +#undef RUY_OFFSET_SRC_INC3 +#undef RUY_OFFSET_SRC_ROWS +#undef RUY_OFFSET_SRC_ZERO_POINT +#undef RUY_OFFSET_INPUT_XOR + +#endif // RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, int src_inc3, + int src_rows, int src_zero_point, + std::int8_t* packed_ptr, int start_col, int end_col, + std::int32_t* sums_ptr, int input_xor) { + profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)"); + asm volatile( + // clang-format off + "dup v26.16b, %w[input_xor]\n" + "mov w1, #0\n" + "dup v28.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v31.4s, wzr\n" + + "and w2, %w[rows], #-16\n" + "cmp w1, w2\n" + "beq 3f\n" + "ldr x10, [%[src_ptr0], #8]\n" + "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" + "ldr x11, [%[src_ptr1], #8]\n" + "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" + "ldr x12, [%[src_ptr2], #8]\n" + "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" + "ldr x13, [%[src_ptr3], #8]\n" + "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") + "add w1, w1, #16\n" + "cmp w1, w2\n" + + "beq 2f\n" + + "1:\n" + "add w1, w1, #16\n" + "ins v0.d[1], x10\n" + "ldr x10, [%[src_ptr0], #8]\n" + "ins v1.d[1], x11\n" + "ldr x11, [%[src_ptr1], #8]\n" + "ins v2.d[1], x12\n" + "ldr x12, [%[src_ptr2], #8]\n" + "ins v3.d[1], x13\n" + "ldr x13, [%[src_ptr3], #8]\n" + "eor v4.16b, v0.16b, v26.16b\n" + "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" + "eor v5.16b, v1.16b, v26.16b\n" + "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" + "eor v6.16b, v2.16b, v26.16b\n" + "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" + "eor v7.16b, v3.16b, v26.16b\n" + "ld1 {v3.8b}, [%[src_ptr3]], 
%[src_inc3]\n" + "saddlp v16.8h, v4.16b\n" + "str q4, [%[packed_ptr], #0]\n" + "saddlp v17.8h, v5.16b\n" + "str q5, [%[packed_ptr], #16]\n" + "saddlp v18.8h, v6.16b\n" + "str q6, [%[packed_ptr], #32]\n" + "saddlp v19.8h, v7.16b\n" + "str q7, [%[packed_ptr], #48]\n" + "sadalp v28.4s, v16.8h\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") + "cmp w1, w2\n" + "sadalp v29.4s, v17.8h\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") + "add %[packed_ptr], %[packed_ptr], #64\n" + "sadalp v30.4s, v18.8h\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") + "sadalp v31.4s, v19.8h\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") + + "bne 1b\n" + + "2:\n" + "ins v0.d[1], x10\n" + "ins v1.d[1], x11\n" + "ins v2.d[1], x12\n" + "ins v3.d[1], x13\n" + "eor v4.16b, v0.16b, v26.16b\n" + "eor v5.16b, v1.16b, v26.16b\n" + "eor v6.16b, v2.16b, v26.16b\n" + "eor v7.16b, v3.16b, v26.16b\n" + + "saddlp v16.8h, v4.16b\n" + "str q4, [%[packed_ptr], #0]\n" + "saddlp v17.8h, v5.16b\n" + "str q5, [%[packed_ptr], #16]\n" + "saddlp v18.8h, v6.16b\n" + "str q6, [%[packed_ptr], #32]\n" + "saddlp v19.8h, v7.16b\n" + "str q7, [%[packed_ptr], #48]\n" + "sadalp v28.4s, v16.8h\n" + "sadalp v29.4s, v17.8h\n" + "sadalp v30.4s, v18.8h\n" + "sadalp v31.4s, v19.8h\n" + + "add %[packed_ptr], %[packed_ptr], #64\n" + + "3:\n" + + "ands w2, %w[rows], #15\n" + "beq 4f\n" + "dup v0.16b, %w[src_zero_point]\n" + "dup v1.16b, %w[src_zero_point]\n" + "dup v2.16b, %w[src_zero_point]\n" + "dup v3.16b, %w[src_zero_point]\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ + "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ + "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ + "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) + RUY_LOAD_ONE_ROW(4) + RUY_LOAD_ONE_ROW(5) + RUY_LOAD_ONE_ROW(6) + RUY_LOAD_ONE_ROW(7) + RUY_LOAD_ONE_ROW(8) + RUY_LOAD_ONE_ROW(9) + RUY_LOAD_ONE_ROW(10) + RUY_LOAD_ONE_ROW(11) + RUY_LOAD_ONE_ROW(12) + RUY_LOAD_ONE_ROW(13) + RUY_LOAD_ONE_ROW(14) + RUY_LOAD_ONE_ROW(15) +#undef RUY_LOAD_ONE_ROW + "5:\n" + + "eor v4.16b, v0.16b, v26.16b\n" + "eor v5.16b, v1.16b, v26.16b\n" + "eor v6.16b, v2.16b, v26.16b\n" + "eor v7.16b, v3.16b, v26.16b\n" + + "saddlp v16.8h, v4.16b\n" + "saddlp v17.8h, v5.16b\n" + "saddlp v18.8h, v6.16b\n" + "saddlp v19.8h, v7.16b\n" + "sadalp v28.4s, v16.8h\n" + "sadalp v29.4s, v17.8h\n" + "sadalp v30.4s, v18.8h\n" + "sadalp v31.4s, v19.8h\n" + + "str q4, [%[packed_ptr], #0]\n" + "str q5, [%[packed_ptr], #16]\n" + "str q6, [%[packed_ptr], #32]\n" + "str q7, [%[packed_ptr], #48]\n" + "add %[packed_ptr], %[packed_ptr], #64\n" + + "4:\n" + + "addp v28.4s, v28.4s, v29.4s\n" + "addp v30.4s, v30.4s, v31.4s\n" + "addp v28.4s, v28.4s, v30.4s\n" + + "cmp %[sums_ptr], #0\n" + "beq 6f\n" + "st1 {v28.4s}, [%[sums_ptr]], #16\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), + [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), + [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) + : [ src_inc0 ] "r"(static_cast(src_inc0)), [ src_inc1 ] "r"(static_cast(src_inc1)), + [ src_inc2 ] "r"(static_cast(src_inc2)), [ src_inc3 ] "r"(static_cast(src_inc3)), + [ rows ] "r"(src_rows), + [ src_zero_point ] "r"(src_zero_point), + [input_xor] "r"(input_xor) + : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", 
"v10", "v11", "v12", "v13", "v14", "v15", + "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} + +void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + std::int8_t* packed_ptr, int start_col, + int end_col, std::int32_t* sums_ptr, + int input_xor) { + profiler::ScopeLabel label( + "Pack (kNeonDotprod, optimized for in-order cores)"); + asm volatile( + // clang-format off + "dup v26.16b, %w[input_xor]\n" + "mov w1, #1\n" + "dup v27.16b, w1\n" + "mov w1, #0\n" + "dup v28.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v31.4s, wzr\n" + + "and w2, %w[rows], #-16\n" + "cmp w1, w2\n" + "beq 3f\n" + "ldr x10, [%[src_ptr0], #8]\n" + "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" + "ldr x11, [%[src_ptr1], #8]\n" + "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" + "ldr x12, [%[src_ptr2], #8]\n" + "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" + "ldr x13, [%[src_ptr3], #8]\n" + "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") + "add w1, w1, #16\n" + "cmp w1, w2\n" + + "beq 2f\n" + + "1:\n" + "add w1, w1, #16\n" + "ins v0.d[1], x10\n" + "ldr x10, [%[src_ptr0], #8]\n" + "ins v1.d[1], x11\n" + "ldr x11, [%[src_ptr1], #8]\n" + "ins v2.d[1], x12\n" + "ldr x12, [%[src_ptr2], #8]\n" + "ins v3.d[1], x13\n" + "ldr x13, [%[src_ptr3], #8]\n" + + "eor v4.16b, v0.16b, v26.16b\n" + "ld1 {v0.8b}, [%[src_ptr0]], %[src_inc0]\n" + "eor v5.16b, v1.16b, v26.16b\n" + "ld1 {v1.8b}, [%[src_ptr1]], %[src_inc1]\n" + "eor v6.16b, v2.16b, v26.16b\n" + "ld1 {v2.8b}, [%[src_ptr2]], %[src_inc2]\n" + "eor v7.16b, v3.16b, v26.16b\n" + "ld1 {v3.8b}, [%[src_ptr3]], %[src_inc3]\n" + + "trn1 v16.4s, v4.4s, v5.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") + "trn2 v17.4s, v4.4s, v5.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") + "trn1 v18.4s, v6.4s, v7.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") + "trn2 v19.4s, v6.4s, v7.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + "cmp w1, w2\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + "str q20, [%[packed_ptr], #0]\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + "str q21, [%[packed_ptr], #32]\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + "str q22, [%[packed_ptr], #64]\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + "str q23, [%[packed_ptr], #96]\n" + + "add %[packed_ptr], %[packed_ptr], #128\n" + + "bne 1b\n" + + "2:\n" + "ins v0.d[1], x10\n" + "ins v1.d[1], x11\n" + "ins v2.d[1], 
x12\n" + "ins v3.d[1], x13\n" + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + "str q20, [%[packed_ptr], #0]\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + "str q21, [%[packed_ptr], #32]\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + "str q22, [%[packed_ptr], #64]\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "3:\n" + + "ands w2, %w[rows], #15\n" + "beq 4f\n" + "dup v0.16b, %w[src_zero_point]\n" + "dup v1.16b, %w[src_zero_point]\n" + "dup v2.16b, %w[src_zero_point]\n" + "dup v3.16b, %w[src_zero_point]\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ + "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ + "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ + "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) + RUY_LOAD_ONE_ROW(4) + RUY_LOAD_ONE_ROW(5) + RUY_LOAD_ONE_ROW(6) + RUY_LOAD_ONE_ROW(7) + RUY_LOAD_ONE_ROW(8) + RUY_LOAD_ONE_ROW(9) + RUY_LOAD_ONE_ROW(10) + RUY_LOAD_ONE_ROW(11) + RUY_LOAD_ONE_ROW(12) + RUY_LOAD_ONE_ROW(13) + RUY_LOAD_ONE_ROW(14) + RUY_LOAD_ONE_ROW(15) +#undef RUY_LOAD_ONE_ROW + "5:\n" + + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + "str q20, [%[packed_ptr], #0]\n" + "cmp w2, #4\n" + "ble 4f\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + "str q21, [%[packed_ptr], #32]\n" + "cmp w2, #8\n" + "ble 4f\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + "str q22, [%[packed_ptr], #64]\n" + "cmp w2, #12\n" + "ble 4f\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "4:\n" + + "add v28.4s, v28.4s, v29.4s\n" + "add v30.4s, v30.4s, v31.4s\n" + "add v28.4s, v28.4s, v30.4s\n" + + "cmp %[sums_ptr], #0\n" + "beq 6f\n" + "st1 {v28.4s}, [%[sums_ptr]], #16\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2), + [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr), [sums_ptr] "+r"(sums_ptr) + : [ src_inc0 ] "r"(static_cast(src_inc0)), [ src_inc1 ] "r"(static_cast(src_inc1)), + [ src_inc2 ] "r"(static_cast(src_inc2)), [ src_inc3 ] "r"(static_cast(src_inc3)), + [rows] "r"(src_rows), + [src_zero_point] "r"(static_cast(src_zero_point)), + [input_xor] "r"(input_xor) + : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", 
"v30", "v31"); +} + +void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, + int src_zero_point, std::int8_t* packed_ptr, + int start_col, int end_col, + std::int32_t* sums_ptr, int input_xor) { + profiler::ScopeLabel label( + "Pack (kNeonDotprod, optimized for out-of-order cores)"); + asm volatile( + // clang-format off + "dup v26.16b, %w[input_xor]\n" + "mov w1, #1\n" + "dup v27.16b, w1\n" + "mov w1, #0\n" + "dup v28.4s, wzr\n" + "dup v29.4s, wzr\n" + "dup v30.4s, wzr\n" + "dup v31.4s, wzr\n" + +#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) + "and w2, %w[rows], #-64\n" + "cmp w1, w2\n" + "beq 9f\n" + + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n" + "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n" + "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #64\n" + "cmp w1, w2\n" + "beq 8f\n" + + "7:\n" + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "eor v4.16b, v4.16b, v26.16b\n" + "eor v5.16b, v5.16b, v26.16b\n" + "eor v6.16b, v6.16b, v26.16b\n" + "eor v7.16b, v7.16b, v26.16b\n" + + "trn1 v16.4s, v4.4s, v5.4s\n" + "trn2 v17.4s, v4.4s, v5.4s\n" + "trn1 v18.4s, v6.4s, v7.4s\n" + "trn2 v19.4s, v6.4s, v7.4s\n" + + "ld1 {v4.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v5.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v6.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v7.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], 
#64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "eor v8.16b, v8.16b, v26.16b\n" + "eor v9.16b, v9.16b, v26.16b\n" + "eor v10.16b, v10.16b, v26.16b\n" + "eor v11.16b, v11.16b, v26.16b\n" + + "trn1 v16.4s, v8.4s, v9.4s\n" + "trn2 v17.4s, v8.4s, v9.4s\n" + "trn1 v18.4s, v10.4s, v11.4s\n" + "trn2 v19.4s, v10.4s, v11.4s\n" + + "ld1 {v8.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v9.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v10.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v11.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "eor v12.16b, v12.16b, v26.16b\n" + "eor v13.16b, v13.16b, v26.16b\n" + "eor v14.16b, v14.16b, v26.16b\n" + "eor v15.16b, v15.16b, v26.16b\n" + + "trn1 v16.4s, v12.4s, v13.4s\n" + "trn2 v17.4s, v12.4s, v13.4s\n" + "trn1 v18.4s, v14.4s, v15.4s\n" + "trn2 v19.4s, v14.4s, v15.4s\n" + + "ld1 {v12.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v13.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v14.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v15.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "cmp w1, w2\n" + "bne 7b\n" + + "8:\n" + + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "eor v4.16b, v4.16b, v26.16b\n" + "eor v5.16b, v5.16b, v26.16b\n" + "eor v6.16b, v6.16b, v26.16b\n" + "eor v7.16b, v7.16b, v26.16b\n" + + "trn1 v16.4s, v4.4s, v5.4s\n" + "trn2 v17.4s, v4.4s, v5.4s\n" + "trn1 v18.4s, v6.4s, v7.4s\n" + "trn2 v19.4s, v6.4s, v7.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, 
v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "eor v8.16b, v8.16b, v26.16b\n" + "eor v9.16b, v9.16b, v26.16b\n" + "eor v10.16b, v10.16b, v26.16b\n" + "eor v11.16b, v11.16b, v26.16b\n" + + "trn1 v16.4s, v8.4s, v9.4s\n" + "trn2 v17.4s, v8.4s, v9.4s\n" + "trn1 v18.4s, v10.4s, v11.4s\n" + "trn2 v19.4s, v10.4s, v11.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "eor v12.16b, v12.16b, v26.16b\n" + "eor v13.16b, v13.16b, v26.16b\n" + "eor v14.16b, v14.16b, v26.16b\n" + "eor v15.16b, v15.16b, v26.16b\n" + + "trn1 v16.4s, v12.4s, v13.4s\n" + "trn2 v17.4s, v12.4s, v13.4s\n" + "trn1 v18.4s, v14.4s, v15.4s\n" + "trn2 v19.4s, v14.4s, v15.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "9:\n" +#endif // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING) + "and w2, %w[rows], #-16\n" + "cmp w1, w2\n" + "beq 3f\n" + + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + "cmp w1, w2\n" + "beq 2f\n" + + "1:\n" + + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "ld1 {v0.16b}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.16b}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.16b}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.16b}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #16\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "cmp w1, w2\n" + "bne 1b\n" + + "2:\n" + + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor 
v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "3:\n" + + "ands w2, %w[rows], #15\n" + "beq 4f\n" + "dup v0.16b, %w[src_zero_point]\n" + "dup v1.16b, %w[src_zero_point]\n" + "dup v2.16b, %w[src_zero_point]\n" + "dup v3.16b, %w[src_zero_point]\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.b }[" #R "], [%[src_ptr0]], #1\n" \ + "ld1 { v1.b }[" #R "], [%[src_ptr1]], #1\n" \ + "ld1 { v2.b }[" #R "], [%[src_ptr2]], #1\n" \ + "ld1 { v3.b }[" #R "], [%[src_ptr3]], #1\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) + RUY_LOAD_ONE_ROW(4) + RUY_LOAD_ONE_ROW(5) + RUY_LOAD_ONE_ROW(6) + RUY_LOAD_ONE_ROW(7) + RUY_LOAD_ONE_ROW(8) + RUY_LOAD_ONE_ROW(9) + RUY_LOAD_ONE_ROW(10) + RUY_LOAD_ONE_ROW(11) + RUY_LOAD_ONE_ROW(12) + RUY_LOAD_ONE_ROW(13) + RUY_LOAD_ONE_ROW(14) + RUY_LOAD_ONE_ROW(15) +#undef RUY_LOAD_ONE_ROW + "5:\n" + + "eor v0.16b, v0.16b, v26.16b\n" + "eor v1.16b, v1.16b, v26.16b\n" + "eor v2.16b, v2.16b, v26.16b\n" + "eor v3.16b, v3.16b, v26.16b\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + ".word 0x4e9b969c // sdot v28.4s, v20.16b, v27.16b\n" + "str q20, [%[packed_ptr], #0]\n" + "cmp w2, #4\n" + "ble 4f\n" + ".word 0x4e9b96be // sdot v30.4s, v21.16b, v27.16b\n" + "str q21, [%[packed_ptr], #32]\n" + "cmp w2, #8\n" + "ble 4f\n" + ".word 0x4e9b96dd // sdot v29.4s, v22.16b, v27.16b\n" + "str q22, [%[packed_ptr], #64]\n" + "cmp w2, #12\n" + "ble 4f\n" + ".word 0x4e9b96ff // sdot v31.4s, v23.16b, v27.16b\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "4:\n" + + "add v28.4s, v28.4s, v29.4s\n" + "add v30.4s, v30.4s, v31.4s\n" + "add v28.4s, v28.4s, v30.4s\n" + + "cmp %[sums_ptr], #0\n" + "beq 6f\n" + "st1 {v28.4s}, [%[sums_ptr]], #16\n" + "6:\n" + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), + [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), + [ packed_ptr ] "+r"(packed_ptr), [ sums_ptr ] "+r"(sums_ptr) + : [ src_inc0 ] "r"(static_cast(src_inc0)), + [ src_inc1 ] "r"(static_cast(src_inc1)), + [ src_inc2 ] "r"(static_cast(src_inc2)), + [ src_inc3 ] "r"(static_cast(src_inc3)), + [ rows ] "r"(src_rows), + [ src_zero_point ] "r"(static_cast(src_zero_point)), + [ input_xor ] "r"(input_xor) + : "cc", "memory", "x1", "x2", "v0", "v1", "v2", "v3", "v4", "v5", "v6", + "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31"); +} + +#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if RUY_PLATFORM(NEON_64) && 
RUY_OPT_ENABLED(RUY_OPT_ASM) +void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + float* packed_ptr, int start_col, int end_col) { + profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); + asm volatile( + // clang-format off + "mov w1, #0\n" + + "and w2, %w[rows], #-4\n" + "cmp w1, w2\n" + "beq 3f\n" + "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" + "add w1, w1, #4\n" + "cmp w1, w2\n" + + "beq 2f\n" + + "1:\n" + "add w1, w1, #4\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + "cmp w1, w2\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + + "add %[packed_ptr], %[packed_ptr], #128\n" + + "bne 1b\n" + + "2:\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "3:\n" + + "ands w2, %w[rows], #3\n" + "beq 4f\n" + "dup v0.16b, wzr\n" + "dup v1.16b, wzr\n" + "dup v2.16b, wzr\n" + "dup v3.16b, wzr\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \ + "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \ + "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \ + "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) +#undef RUY_LOAD_ONE_ROW + "5:\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + "mov x1, #32\n" + +#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ + "cmp w2, #" #ROW "\n" \ + "beq 4f\n" \ + "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n" + + RUY_STORE_ONE_ROW(0, v20) + RUY_STORE_ONE_ROW(1, v21) + RUY_STORE_ONE_ROW(2, v22) + RUY_STORE_ONE_ROW(3, v23) + +#undef RUY_STORE_ONE_ROW + + "4:\n" + + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), + [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), + [ packed_ptr ] "+r"(packed_ptr) + : [ src_inc0 ] "r"(static_cast(src_inc0)), + [ src_inc1 ] "r"(static_cast(src_inc1)), + [ src_inc2 ] "r"(static_cast(src_inc2)), + [ src_inc3 ] "r"(static_cast(src_inc3)), + [ rows ] "r"(src_rows) + : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", 
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); +} +#endif + +#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) +void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc, int src_rows, int src_zero_point, + float* packed_ptr, int start_col, int end_col, + int output_stride) { + profiler::ScopeLabel label("Pack (kNeon, optimized for out-of-order cores)"); + asm volatile( + // clang-format off + "mov r1, #0\n" + "and r2, %[rows], #-4\n" + "cmp r1, r2\n" + "beq 3f\n" +#define RUY_LOAD_FOUR_BY_FOUR() \ + /* Load q0 */ \ + "vld1.32 {d0, d1}, [%[src_ptr0]]\n" \ + /* if src_inc0 != 0, add 16 to src_ptr0 */ \ + "and r3, %[src_inc], #1\n" \ + "add %[src_ptr0], %[src_ptr0], r3, lsl #4\n"\ + /* Load q1 */ \ + "vld1.32 {d2, d3}, [%[src_ptr1]]\n" \ + /* if src_inc1 != 0, add 16 to src_ptr0 */ \ + "and r3, %[src_inc], #2\n" \ + "add %[src_ptr1], %[src_ptr1], r3, lsl #3\n"\ + /* Load q2 */ \ + "vld1.32 {d4, d5}, [%[src_ptr2]]\n" \ + /* if src_inc2 != 0, add 16 to src_ptr0 */ \ + "and r3, %[src_inc], #4\n" \ + "add %[src_ptr2], %[src_ptr2], r3, lsl #2\n"\ + /* Load q3 */ \ + "vld1.32 {d6, d7}, [%[src_ptr3]]\n" \ + /* if src_inc3 != 0, add 16 to src_ptr0 */ \ + "and r3, %[src_inc], #8\n" \ + "add %[src_ptr3], %[src_ptr3], r3, lsl #1\n"\ + + RUY_LOAD_FOUR_BY_FOUR() + "add r1, r1, #4\n" + "cmp r1, r2\n" + + "beq 2f\n" + + "1:\n" + "add r1, r1, #4\n" + + // Transpose 4x4 matrix. + "vzip.32 q0, q1\n" + "vzip.32 q2, q3\n" + + "vtrn.32 q0, q2\n" + "vtrn.32 q1, q3\n" + + "vzip.32 q0, q2\n" + "vzip.32 q1, q3\n" + + "vmov q8, q0\n" + "vmov q9, q1\n" + "vmov q10, q2\n" + "vmov q11, q3\n" + + RUY_LOAD_FOUR_BY_FOUR() +#undef RUY_LOAD_FOUR_BY_FOUR + +#define RUY_STORE_FOUR_BY_FOUR() \ + /* Store q8, q10, q9, q11 */ \ + /* q8 = d16, d17 */ \ + "vst1.32 {d16, d17}, [%[packed_ptr]]\n" \ + /* q10 = d20, d21 */ \ + "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ + "vst1.32 {d20, d21}, [%[packed_ptr]]\n" \ + /* q9 = d18, d19 */ \ + "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ + "vst1.32 {d18, d19}, [%[packed_ptr]]\n" \ + /* q11 = d22, d23 */ \ + "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ + "vst1.32 {d22, d23}, [%[packed_ptr]]\n" \ + "add %[packed_ptr], %[packed_ptr], %[stride]\n" \ + + RUY_STORE_FOUR_BY_FOUR() + "cmp r1, r2\n" + + "bne 1b\n" + + "2:\n" + + // Transpose 4x4 matrix. 
+ "vzip.32 q0, q1\n" + "vzip.32 q2, q3\n" + + "vtrn.32 q0, q2\n" + "vtrn.32 q1, q3\n" + + "vzip.32 q0, q2\n" + "vzip.32 q1, q3\n" + + "vmov q8, q0\n" + "vmov q9, q1\n" + "vmov q10, q2\n" + "vmov q11, q3\n" + + RUY_STORE_FOUR_BY_FOUR() +#undef RUY_STORE_FOUR_BY_FOUR + "3:\n" + + "ands r2, %[rows], #3\n" + "beq 4f\n" + "mov r0, #0\n" + // Zero out q0 - q3 + "vdup.32 q0, r0\n" + "vdup.32 q1, r0\n" + "vdup.32 q2, r0\n" + "vdup.32 q3, r0\n" +#define RUY_LOAD_ONE_ROW_FIRST_HALF(R, I) \ + "cmp r2, #" #R "\n" \ + "beq 5f\n" \ + "vld1.32 { d0[" #I "] }, [%[src_ptr0]]!\n" \ + "vld1.32 { d2[" #I "] }, [%[src_ptr1]]!\n" \ + "vld1.32 { d4[" #I "] }, [%[src_ptr2]]!\n" \ + "vld1.32 { d6[" #I "] }, [%[src_ptr3]]!\n" + +#define RUY_LOAD_ONE_ROW_SECOND_HALF(R, I) \ + "vld1.32 { d1[" #I "] }, [%[src_ptr0]]!\n" \ + "vld1.32 { d3[" #I "] }, [%[src_ptr1]]!\n" \ + "vld1.32 { d5[" #I "] }, [%[src_ptr2]]!\n" \ + "vld1.32 { d7[" #I "] }, [%[src_ptr3]]!\n" + + RUY_LOAD_ONE_ROW_FIRST_HALF(0, 0) + RUY_LOAD_ONE_ROW_FIRST_HALF(1, 1) + RUY_LOAD_ONE_ROW_SECOND_HALF(2, 0) + RUY_LOAD_ONE_ROW_SECOND_HALF(3, 1) +#undef RUY_LOAD_ONE_ROW_SECOND_HALF +#undef RUY_LOAD_ONE_ROW_FIRST_HALF + "5:\n" + + // Transpose 4x4 matrix. + "vzip.32 q0, q1\n" + "vzip.32 q2, q3\n" + + "vtrn.32 q0, q2\n" + "vtrn.32 q1, q3\n" + + "vzip.32 q0, q2\n" + "vzip.32 q1, q3\n" + + "vmov q8, q0\n" + "vmov q9, q1\n" + "vmov q10, q2\n" + "vmov q11, q3\n" + + "mov r1, #32\n" + +#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ + "cmp r2, #" #ROW "\n" \ + "beq 4f\n" \ + "vst1.32 {" #REGISTER "}, [%[packed_ptr]]\n" \ + "add %[packed_ptr], %[packed_ptr], %[stride]\n" + + // Store q8 + RUY_STORE_ONE_ROW(0, q8) + // Store q10 + RUY_STORE_ONE_ROW(1, q10) + // Store q9 + RUY_STORE_ONE_ROW(2, q9) + // Store q11 + RUY_STORE_ONE_ROW(3, q11) + +#undef RUY_STORE_ONE_ROW + + "4:\n" + + // clang-format on + : [ src_ptr0 ] "+r"(src_ptr0), [ src_ptr1 ] "+r"(src_ptr1), + [ src_ptr2 ] "+r"(src_ptr2), [ src_ptr3 ] "+r"(src_ptr3), + [ packed_ptr ] "+r"(packed_ptr) + : [ src_inc ] "r"(static_cast(src_inc)), + [ rows ] "r"(src_rows), [ stride ] "r"(output_stride) + : "cc", "memory", "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3", + "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11"); +} + +#endif // (RUY_PLATFORM(NEON_32) + +#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) +void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + float* packed_ptr, int start_col, int end_col) { + profiler::ScopeLabel label("Pack (kNeon, optimized for in-order cores)"); + + asm volatile( + // clang-format off + "mov w1, #0\n" + + "and w2, %w[rows], #-4\n" + "cmp w1, w2\n" + "beq 3f\n" + "ld1 {v0.4s}, [%[src_ptr0]], %[src_inc0]\n" + "ld1 {v1.4s}, [%[src_ptr1]], %[src_inc1]\n" + "ld1 {v2.4s}, [%[src_ptr2]], %[src_inc2]\n" + "ld1 {v3.4s}, [%[src_ptr3]], %[src_inc3]\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #64]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #128]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, 
[%[src_ptr1], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #192]\n") + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #192]\n") + "add w1, w1, #4\n" + "cmp w1, w2\n" + + "beq 2f\n" + + "1:\n" + "add w1, w1, #4\n" + + "ldr x10, [%[src_ptr0], #8]\n" + "trn1 v16.4s, v0.4s, v1.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr0], #240]\n") + "ldr x11, [%[src_ptr1], #8]\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr1], #240]\n") + "ldr x12, [%[src_ptr2], #8]\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr2], #240]\n") + "ldr x13, [%[src_ptr3], #8]\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + RUY_PREFETCH_LOAD("prfm pldl1strm, [%[src_ptr3], #240]\n") + + "ld1 {v0.2s}, [%[src_ptr0]], %[src_inc0]\n" + "trn1 v20.2d, v16.2d, v18.2d\n" + "ld1 {v1.2s}, [%[src_ptr1]], %[src_inc1]\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "ld1 {v2.2s}, [%[src_ptr2]], %[src_inc2]\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "ld1 {v3.2s}, [%[src_ptr3]], %[src_inc3]\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + "cmp w1, w2\n" + + "ins v0.d[1], x10\n" + "str q20, [%[packed_ptr], #0]\n" + "ins v1.d[1], x11\n" + "str q21, [%[packed_ptr], #32]\n" + "ins v2.d[1], x12\n" + "str q22, [%[packed_ptr], #64]\n" + "ins v3.d[1], x13\n" + "str q23, [%[packed_ptr], #96]\n" + + "add %[packed_ptr], %[packed_ptr], #128\n" + + "bne 1b\n" + + "2:\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + "str q20, [%[packed_ptr], #0]\n" + "str q21, [%[packed_ptr], #32]\n" + "str q22, [%[packed_ptr], #64]\n" + "str q23, [%[packed_ptr], #96]\n" + "add %[packed_ptr], %[packed_ptr], #128\n" + + "3:\n" + + "ands w2, %w[rows], #3\n" + "beq 4f\n" + "dup v0.16b, wzr\n" + "dup v1.16b, wzr\n" + "dup v2.16b, wzr\n" + "dup v3.16b, wzr\n" +#define RUY_LOAD_ONE_ROW(R) \ + "cmp w2, #" #R "\n" \ + "beq 5f\n" \ + "ld1 { v0.s }[" #R "], [%[src_ptr0]], #4\n" \ + "ld1 { v1.s }[" #R "], [%[src_ptr1]], #4\n" \ + "ld1 { v2.s }[" #R "], [%[src_ptr2]], #4\n" \ + "ld1 { v3.s }[" #R "], [%[src_ptr3]], #4\n" + + RUY_LOAD_ONE_ROW(0) + RUY_LOAD_ONE_ROW(1) + RUY_LOAD_ONE_ROW(2) + RUY_LOAD_ONE_ROW(3) +#undef RUY_LOAD_ONE_ROW + "5:\n" + + "trn1 v16.4s, v0.4s, v1.4s\n" + "trn2 v17.4s, v0.4s, v1.4s\n" + "trn1 v18.4s, v2.4s, v3.4s\n" + "trn2 v19.4s, v2.4s, v3.4s\n" + + "trn1 v20.2d, v16.2d, v18.2d\n" + "trn2 v22.2d, v16.2d, v18.2d\n" + "trn1 v21.2d, v17.2d, v19.2d\n" + "trn2 v23.2d, v17.2d, v19.2d\n" + + "mov x1, #32\n" + +#define RUY_STORE_ONE_ROW(ROW, REGISTER) \ + "cmp w2, #" #ROW "\n" \ + "beq 4f\n" \ + "st1 {" #REGISTER ".4s}, [%[packed_ptr]], x1\n" + + RUY_STORE_ONE_ROW(0, v20) + RUY_STORE_ONE_ROW(1, v21) + RUY_STORE_ONE_ROW(2, v22) + RUY_STORE_ONE_ROW(3, v23) + +#undef RUY_STORE_ONE_ROW + + "4:\n" + + // clang-format on + + : [ src_ptr0 ] "+r"(src_ptr0), [src_ptr1] "+r"(src_ptr1), [src_ptr2] "+r"(src_ptr2), + [src_ptr3] "+r"(src_ptr3), [packed_ptr] "+r"(packed_ptr) + : [ src_inc0 ] "r"(static_cast(src_inc0)), [src_inc1] "r"(static_cast(src_inc1)), [src_inc2] "r"(static_cast(src_inc2)), + [src_inc3] "r"(static_cast(src_inc3)), [rows] "r"(src_rows) + : "cc", "memory", "x1", "x2", "x10", "x11", "x12", "x13", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", + "v22", "v23", "v24", "v25", "v26", "v27", 
"v28", "v29", "v30", "v31"); +} +#endif // RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy diff --git a/ruy/pack_arm.h b/ruy/pack_arm.h new file mode 100644 index 0000000..8e7f619 --- /dev/null +++ b/ruy/pack_arm.h @@ -0,0 +1,497 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// # What is "packing"? +// +// Before feeding data to the gemm kernels (the parts of Ruy that do lots +// of multiply-add operations), Ruy first performs a data transformation (which +// we call "packing") on the input matrices. This transformation has two main +// goals: +// - rearrange data into blocks that are a convenient size/layout for the gemm +// kernels to consume. This helps make the memory access pattern of the gemm +// kernel simpler and more contiguous, and puts the data in a layout most +// convenient for specific arithmetic instructions in the gemm kernel. +// - compute row/column sums needed for handling quantization with non-symmetric +// zero points. +// +// # Simplified algorithmic analysis of packing +// +// Packing is a relatively simple transformation which does a small constant +// amount of work on each element of an input matrix, and hence for an NxM +// matrix performs O(N*M) work. If N and M are of the same order, then this is +// O(N^2) work. +// +// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. +// Note that if N, K, and M are all the same order, then the number of +// multiply-accumulate operations is O(N^3). +// +// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the +// case of all dimensions being roughly the same order. +// +// # Packing cost can be significant +// +// When matrix * matrix multiplications begin to look more like matrix * vector +// multiplications, packing cost can become significant. We sometimes call these +// cases "gemv-like". +// +// Continuing the algorithmic analysis above, if we consider a case where an +// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the +// situation is different. In this case, the multiply-accumulate work is only +// quadratic, so the quadratic cost of packing can be come significant. +// +// Another way to say this is that the cost of packing an input matrix (either +// the LHS or RHS) is amortized across the non-depth dimension of the opposite +// input matrix. Thus, when the LHS has very few rows or the RHS has very few +// columns, the cost of packing the opposite input matrix can become +// significant. +// +// As a rough rule of thumb, the cost of packing starts to become significant +// when either N or M is below 32 (and other dimensions are hundreds), with very +// significant packing costs at 8 or below. This varies by data type, Path, and +// tuning, so these numbers are only rough guides. 
+// +// One practical use case that is affected by this is inference of +// fully connected neural network layers with a low batch size. The weight +// matrix (which is a constant for inference) is the one affected by significant +// packing cost. +// +// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack +// input matrices that are affected by significant packing costs. +// +// # Implementation notes +// +// Ruy's packing routines always operate on a range of columns and can be +// applied to either the LHS or RHS. This is possible because Ruy internally +// implements a TrMul, so the accumulation along depth is done along columns of +// both the LHS and RHS (whereas for a normal Mul the accumulation along depth +// for the LHS is along rows). As another example, we are always computing +// column sums for quantization (and never row sums, since the LHS is +// transposed). + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/common.h" +#include "ruy/internal_matrix.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/pack_common.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/tune.h" + +namespace ruy { + +#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) +void Pack8bitNeonOutOfOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + std::int8_t* packed_ptr, int start_col, int end_col, + std::int32_t* sums_ptr, int input_xor); +void Pack8bitNeonInOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, int src_inc3, + int src_rows, int src_zero_point, + std::int8_t* packed_ptr, int start_col, int end_col, + std::int32_t* sums_ptr, int input_xor); +void Pack8bitNeonDotprodOutOfOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, + int src_zero_point, std::int8_t* packed_ptr, + int start_col, int end_col, + std::int32_t* sums_ptr, int input_xor); +void Pack8bitNeonDotprodInOrder(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + std::int8_t* packed_ptr, int start_col, + int end_col, std::int32_t* sums_ptr, + int input_xor); + +#elif RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) +void Pack8bitNeonOutOfOrder4Cols(const PackParams8bit& params); +void Pack8bitNeonOutOfOrder2Cols(const PackParams8bit& params); +#endif // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && \ + RUY_OPT_ENABLED(RUY_OPT_ASM) + +template +struct PackImpl, Scalar, + std::int8_t, std::int32_t> { + static_assert(std::is_same::value || + std::is_same::value, + ""); + static constexpr int kInputXor = + std::is_same::value ? 
0 : 0x80; + + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ(start_col % 4, 0); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[16]; + memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); + for (int block_col = start_col; block_col < end_col; block_col += 4) { + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; + const Scalar* src_ptr1 = src_ptr0 + src_stride; + const Scalar* src_ptr2 = src_ptr1 + src_stride; + const Scalar* src_ptr3 = src_ptr2 + src_stride; + int src_inc0 = 16; + int src_inc1 = 16; + int src_inc2 = 16; + int src_inc3 = 16; + if (block_col >= src_matrix.layout.cols - 3) { + if (block_col >= src_matrix.layout.cols - 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (block_col >= src_matrix.layout.cols - 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (block_col >= src_matrix.layout.cols - 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (block_col >= src_matrix.layout.cols - 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + } + std::int8_t* packed_ptr = + packed_matrix->data + packed_matrix->layout.stride * block_col; + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; +#if RUY_PLATFORM(NEON_64) + if (__builtin_expect(tuning == Tuning::kInOrder, true)) { + Pack8bitNeonInOrder( + src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, + src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col, sums_ptr, kInputXor); + } else { + Pack8bitNeonOutOfOrder( + src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, + src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col, sums_ptr, kInputXor); + } +#else + // We have a more limited set of general purpose registers in ARMv7, so + // we use the "params" struct technique from the kernel code to save + // registers. + PackParams8bit params; + MakePackParams8bit(src_ptr0, src_ptr1, src_ptr2, src_ptr3, sums_ptr, + packed_ptr, src_inc0, src_inc1, src_inc2, src_inc3, + src_matrix.layout.rows, src_matrix.zero_point, + kInputXor, ¶ms); + Pack8bitNeonOutOfOrder4Cols(params); +#endif // RUY_PLATFORM(NEON_64) + } + } +}; + +#endif // (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && + // RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) +// The 32-bit float kernel is 4 rows X 2 columns, so we need an additional +// partial specialization for the RHS, which has a FixedKernelLayout with 2 +// columns. +template +struct PackImpl, Scalar, + std::int8_t, std::int32_t> { + static_assert(std::is_same::value || + std::is_same::value, + ""); + static constexpr int kInputXor = + std::is_same::value ? 
0 : 0x80; + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ(start_col % 2, 0); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[16]; + memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); + for (int block_col = start_col; block_col < end_col; block_col += 2) { + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; + const Scalar* src_ptr1 = src_ptr0 + src_stride; + int src_inc0 = 16; + int src_inc1 = 16; + if (block_col >= src_matrix.layout.cols - 2) { + if (block_col >= src_matrix.layout.cols - 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (block_col >= src_matrix.layout.cols - 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + } + std::int8_t* packed_ptr = + packed_matrix->data + packed_matrix->layout.stride * block_col; + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; + PackParams8bit params; + MakePackParams8bit(src_ptr0, src_ptr1, nullptr, nullptr, sums_ptr, + packed_ptr, src_inc0, src_inc1, -1, -1, + src_matrix.layout.rows, src_matrix.zero_point, + kInputXor, ¶ms); + Pack8bitNeonOutOfOrder2Cols(params); + } + } +}; +#endif // (RUY_PLATFORM(NEON_32)) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) +template +struct PackImpl, + Scalar, std::int8_t, std::int32_t> { + static_assert(std::is_same::value || + std::is_same::value, + ""); + static constexpr int kInputXor = + std::is_same::value ? 0 : 0x80; + + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ(start_col % 8, 0); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[16]; + memset(zerobuf, src_matrix.zero_point, sizeof(zerobuf)); + for (int block_col = start_col; block_col < end_col; block_col += 4) { + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr0 = src_matrix.data.get() + src_stride * block_col; + const Scalar* src_ptr1 = src_ptr0 + src_stride; + const Scalar* src_ptr2 = src_ptr1 + src_stride; + const Scalar* src_ptr3 = src_ptr2 + src_stride; + std::int64_t src_inc0 = 16; + std::int64_t src_inc1 = 16; + std::int64_t src_inc2 = 16; + std::int64_t src_inc3 = 16; + if (block_col >= src_matrix.layout.cols - 3) { + if (block_col >= src_matrix.layout.cols - 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (block_col >= src_matrix.layout.cols - 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (block_col >= src_matrix.layout.cols - 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (block_col >= src_matrix.layout.cols - 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + } + std::int8_t* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & ~7) + + ((block_col & 4) * 4); + std::int32_t* sums_ptr = sums ? 
sums + block_col : nullptr; + if (__builtin_expect(tuning == Tuning::kInOrder, true)) { + Pack8bitNeonDotprodInOrder( + src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, + src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col, sums_ptr, kInputXor); + } else { + Pack8bitNeonDotprodOutOfOrder( + src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, src_inc1, + src_inc2, src_inc3, src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col, sums_ptr, kInputXor); + } + } + } +}; +#endif // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if RUY_PLATFORM(NEON_64) && RUY_OPT_ENABLED(RUY_OPT_ASM) +void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + float* packed_ptr, int start_col, int end_col); +void PackFloatNeonInOrder(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc0, int src_inc1, int src_inc2, + int src_inc3, int src_rows, int src_zero_point, + float* packed_ptr, int start_col, int end_col); + +#elif RUY_PLATFORM(NEON_32) && RUY_OPT_ENABLED(RUY_OPT_ASM) +void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1, + const float* src_ptr2, const float* src_ptr3, + int src_inc, int src_rows, int src_zero_point, + float* packed_ptr, int start_col, int end_col, + int stride); +#endif // (RUY_PLATFORM(NEON_64)&& RUY_OPT_ENABLED(RUY_OPT_ASM) + +#if (RUY_PLATFORM(NEON_32) || RUY_PLATFORM(NEON_64)) && \ + RUY_OPT_ENABLED(RUY_OPT_ASM) + +template <> +struct PackImpl, float, + float, float> { + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ(start_col % 8, 0); + const float zerobuf[4] = {0}; + for (int block_col = start_col; block_col < end_col; block_col += 4) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col; + const float* src_ptr1 = src_ptr0 + src_stride; + const float* src_ptr2 = src_ptr1 + src_stride; + const float* src_ptr3 = src_ptr2 + src_stride; + std::int64_t src_inc0 = 16; + std::int64_t src_inc1 = 16; + std::int64_t src_inc2 = 16; + std::int64_t src_inc3 = 16; + if (block_col >= src_matrix.layout.cols - 3) { + if (block_col >= src_matrix.layout.cols - 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (block_col >= src_matrix.layout.cols - 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (block_col >= src_matrix.layout.cols - 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (block_col >= src_matrix.layout.cols - 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + } + float* packed_ptr = packed_matrix->data + + packed_matrix->layout.stride * (block_col & ~7) + + ((block_col & 4)); +#if RUY_PLATFORM(NEON_64) + if (__builtin_expect(tuning == Tuning::kInOrder, true)) { + PackFloatNeonInOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc0, + src_inc1, src_inc2, src_inc3, + src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col); + } else { + PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, + src_inc0, src_inc1, src_inc2, src_inc3, + src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col); + } +#else + // Encode each of src_inc0, ..., src_inc3 in lowest 4 bits 
of src_inc + // to save on registers (we have fewer general purpose registers in + // 32-bit ARM than in 64-bit ARM). For the 64-bit case, we pass four + // values that are each either 16 or 0 and use them directly. For the + // 32-bit case, bits 0, 1, 2, and 3 are used to determine if we should + // use the value 16 (bit is set) or 0 (bit is not set) for the + // respective increment value. + std::int64_t src_inc = 0; + src_inc += src_inc0 == 16 ? 1 : 0; + src_inc += src_inc1 == 16 ? 2 : 0; + src_inc += src_inc2 == 16 ? 4 : 0; + src_inc += src_inc3 == 16 ? 8 : 0; + const int kOutputStride = 32; + PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc, + src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col, kOutputStride); +#endif // RUY_PLATFORM(NEON_64) + } + } +}; + +#if RUY_PLATFORM(NEON_32) +// The 32-bit float kernel is 8 rows X 4 columns, so we need an additional +// specialization for a FixedKernelLayout with 4 columns. +template <> +struct PackImpl, float, + float, float> { + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ(start_col % 4, 0); + const float zerobuf[4] = {0}; + for (int block_col = start_col; block_col < end_col; block_col += 4) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr0 = src_matrix.data.get() + src_stride * block_col; + const float* src_ptr1 = src_ptr0 + src_stride; + const float* src_ptr2 = src_ptr1 + src_stride; + const float* src_ptr3 = src_ptr2 + src_stride; + std::int64_t src_inc0 = 16; + std::int64_t src_inc1 = 16; + std::int64_t src_inc2 = 16; + std::int64_t src_inc3 = 16; + if (block_col >= src_matrix.layout.cols - 3) { + if (block_col >= src_matrix.layout.cols - 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (block_col >= src_matrix.layout.cols - 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (block_col >= src_matrix.layout.cols - 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (block_col >= src_matrix.layout.cols - 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + } + float* packed_ptr = + packed_matrix->data + packed_matrix->layout.stride * (block_col); + // Encode each of src_inc0, ..., src_inc1 in lowest 4 bits of scrc_inc + // to save registers. + std::int64_t src_inc = 0; + src_inc += src_inc0 == 16 ? 1 : 0; + src_inc += src_inc1 == 16 ? 2 : 0; + src_inc += src_inc2 == 16 ? 4 : 0; + src_inc += src_inc3 == 16 ? 8 : 0; + const int kOutputStride = 16; + PackFloatNeonOutOfOrder(src_ptr0, src_ptr1, src_ptr2, src_ptr3, src_inc, + src_matrix.layout.rows, src_matrix.zero_point, + packed_ptr, start_col, end_col, kOutputStride); + } + } +}; +#endif // (RUY_PLATFORM(NEON_32)) +#endif // (RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \ + // RUY_OPT_ENABLED(RUY_OPT_ASM) + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_ARM_H_ diff --git a/ruy/pack_avx2.cc b/ruy/pack_avx2.cc new file mode 100644 index 0000000..013a8c0 --- /dev/null +++ b/ruy/pack_avx2.cc @@ -0,0 +1,816 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/pack.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) +#include // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, + std::int32_t* sums_ptr) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, float* packed_ptr) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +#else // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +// The first int8_t template parameter is arbitrary: this routine is common to +// all 8-bit source matrix types. +using PackImpl8bitAvx2 = + PackImpl, + std::int8_t, std::int8_t, std::int32_t>; + +using PackImplFloatAvx2 = + PackImpl, float, + float, float>; + +namespace { + +inline __m256i MaskLoadu(int available_src_rows, std::int8_t zero_point, + const std::int8_t* addr) { + RUY_DCHECK_LT(available_src_rows, 32); + __m256i padded_data; + + if (available_src_rows >= 16) { + __m128i load_hi = _mm_set1_epi8(zero_point); + __m128i load_lo = _mm_loadu_si128(reinterpret_cast(addr)); + memcpy(&load_hi, addr + 16, available_src_rows - 16); + padded_data = _mm256_set_m128i(load_hi, load_lo); + } else { + __m128i load_hi = _mm_set1_epi8(zero_point); + __m128i load_lo = load_hi; + memcpy(&load_lo, addr, available_src_rows); + padded_data = _mm256_set_m128i(load_hi, load_lo); + } + return padded_data; +} + +inline void Pack8bitAvx2Packer(const std::int8_t* src_ptr, + std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr, + std::int8_t* trailing_buf) { + using Layout = PackImpl8bitAvx2::Layout; + RUY_DCHECK_EQ(Layout::kCols, 8); + RUY_DCHECK_EQ(Layout::kRows, 4); + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. 
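+  // With kNumRowChunks = 8 chunks of Layout::kRows = 4 rows each, every
+  // iteration of the packing loop below consumes kNumChunkedSrcRows = 32
+  // source rows per column, matching the 32-byte (256-bit) loads used
+  // further down.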
+ constexpr int kNumRowChunks = 8; + constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; + + const std::int8_t* src_ptr0 = src_ptr; + const std::int8_t* src_ptr1 = src_ptr0 + src_stride; + const std::int8_t* src_ptr2 = src_ptr1 + src_stride; + const std::int8_t* src_ptr3 = src_ptr2 + src_stride; + const std::int8_t* src_ptr4 = src_ptr3 + src_stride; + const std::int8_t* src_ptr5 = src_ptr4 + src_stride; + const std::int8_t* src_ptr6 = src_ptr5 + src_stride; + const std::int8_t* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = kNumChunkedSrcRows; + std::int64_t src_inc1 = kNumChunkedSrcRows; + std::int64_t src_inc2 = kNumChunkedSrcRows; + std::int64_t src_inc3 = kNumChunkedSrcRows; + std::int64_t src_inc4 = kNumChunkedSrcRows; + std::int64_t src_inc5 = kNumChunkedSrcRows; + std::int64_t src_inc6 = kNumChunkedSrcRows; + std::int64_t src_inc7 = kNumChunkedSrcRows; + // Handle cases where source does not have Layout::kCols (8) columns. + if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + const std::int8_t zero_point = zerobuf[0]; + + if (sums_ptr) { + // i: Layout::kCols. + for (int i = 0; i < 8; ++i) { + sums_ptr[i] = 0; + } + } + std::int32_t sums_adjustment = 0; + const __m256i ones_16bit = _mm256_set1_epi16(1); + __m256i sums_4x2_32bit_lo = _mm256_set1_epi32(0); + __m256i sums_4x2_32bit_hi = _mm256_set1_epi32(0); + + // The overall packing effectively pads the source rows to + // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we + // only pack for (src_rows + 31) & ~31. When there is an incomplete + // destination block, this is stored into trailing_buf instead of packed_ptr. + for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) { + // Available source rows. + // If this is less than 0 (for m=1), we skip, having filled trailing + // buffer for m=0. Also, if source rows is zero on m=1, then we filled + // exactly to the end of the column in the packed buffer. + const int available_src_rows = src_rows - k; + // Effectively, + // available rows = std::max(0, std::min(8, src_rows - k)); + // treat each case separately. 
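+    // The full-chunk branch below packs 32 rows straight through (with or
+    // without column sums); the partial branch pads the tail with the
+    // zero_point via MaskLoadu and compensates the sums through
+    // sums_adjustment.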
+ if (available_src_rows >= kNumChunkedSrcRows) { + if (sums_ptr) { + __m256i t0, t1, t2, t3, t4, t5, t6, t7; + __m256i r0, r1, r2, r3, r4, r5, r6, r7; + const __m256i input_xor_v = _mm256_set1_epi8(input_xor); + + t0 = _mm256_loadu_si256(reinterpret_cast(src_ptr0)); + t4 = _mm256_loadu_si256(reinterpret_cast(src_ptr4)); + t1 = _mm256_loadu_si256(reinterpret_cast(src_ptr1)); + t5 = _mm256_loadu_si256(reinterpret_cast(src_ptr5)); + t2 = _mm256_loadu_si256(reinterpret_cast(src_ptr2)); + t6 = _mm256_loadu_si256(reinterpret_cast(src_ptr6)); + t3 = _mm256_loadu_si256(reinterpret_cast(src_ptr3)); + t7 = _mm256_loadu_si256(reinterpret_cast(src_ptr7)); + + r0 = _mm256_unpacklo_epi32(t0, t1); + r4 = _mm256_unpacklo_epi32(t4, t5); + r2 = _mm256_unpackhi_epi32(t0, t1); + r6 = _mm256_unpackhi_epi32(t4, t5); + r1 = _mm256_unpacklo_epi32(t2, t3); + r5 = _mm256_unpacklo_epi32(t6, t7); + r3 = _mm256_unpackhi_epi32(t2, t3); + r7 = _mm256_unpackhi_epi32(t6, t7); + + t0 = _mm256_unpacklo_epi64(r0, r1); + t4 = _mm256_unpacklo_epi64(r4, r5); + t2 = _mm256_unpackhi_epi64(r0, r1); + t6 = _mm256_unpackhi_epi64(r4, r5); + t1 = _mm256_unpacklo_epi64(r2, r3); + t5 = _mm256_unpacklo_epi64(r6, r7); + t3 = _mm256_unpackhi_epi64(r2, r3); + t7 = _mm256_unpackhi_epi64(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by + // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, + // t4) are interleaved to create (r0, r1). This complexity follows from + // the way that AVX is centered around MM 128-bit lanes. + r0 = _mm256_permute2x128_si256(t0, t4, 0x20); + r4 = _mm256_permute2x128_si256(t1, t5, 0x20); + r1 = _mm256_permute2x128_si256(t0, t4, 0x31); + r5 = _mm256_permute2x128_si256(t1, t5, 0x31); + r2 = _mm256_permute2x128_si256(t2, t6, 0x20); + r6 = _mm256_permute2x128_si256(t3, t7, 0x20); + r3 = _mm256_permute2x128_si256(t2, t6, 0x31); + r7 = _mm256_permute2x128_si256(t3, t7, 0x31); + + r0 = _mm256_xor_si256(r0, input_xor_v); + r1 = _mm256_xor_si256(r1, input_xor_v); + r2 = _mm256_xor_si256(r2, input_xor_v); + r3 = _mm256_xor_si256(r3, input_xor_v); + r4 = _mm256_xor_si256(r4, input_xor_v); + r5 = _mm256_xor_si256(r5, input_xor_v); + r6 = _mm256_xor_si256(r6, input_xor_v); + r7 = _mm256_xor_si256(r7, input_xor_v); + + __m256i sums_4x4_16bit_lo; + sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); + sums_4x4_16bit_lo = + _mm256_add_epi16(sums_4x4_16bit_lo, + _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); + + // The sums have been performed across columns, and now we have 4x16-bit + // sums packed together. We use madd for pairwise 32-bit sums. 
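+        // _mm256_madd_epi16 against a vector of 16-bit ones multiplies each
+        // 16-bit lane by 1 and adds adjacent pairs into 32-bit lanes, i.e.
+        // it performs a widening pairwise horizontal add of the 16-bit
+        // partial sums.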
+ const __m256i sums_4x2_32bit_lo_new = + _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit); + sums_4x2_32bit_lo = + _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new); + + __m256i sums_4x4_16bit_hi; + sums_4x4_16bit_hi = + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1)); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1))); + sums_4x4_16bit_hi = _mm256_add_epi16( + sums_4x4_16bit_hi, + _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1))); + + const __m256i sums_4x2_32bit_hi_new = + _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit); + sums_4x2_32bit_hi = + _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), + r0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), + r4); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), + r1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), + r5); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), + r2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), + r6); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), + r3); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), + r7); + } else { + __m256i t0, t1, t2, t3, t4, t5, t6, t7; + __m256i r0, r1, r2, r3, r4, r5, r6, r7; + const __m256i input_xor_v = _mm256_set1_epi8(input_xor); + + t0 = _mm256_loadu_si256(reinterpret_cast(src_ptr0)); + t4 = _mm256_loadu_si256(reinterpret_cast(src_ptr4)); + t1 = _mm256_loadu_si256(reinterpret_cast(src_ptr1)); + t5 = _mm256_loadu_si256(reinterpret_cast(src_ptr5)); + t2 = _mm256_loadu_si256(reinterpret_cast(src_ptr2)); + t6 = _mm256_loadu_si256(reinterpret_cast(src_ptr6)); + t3 = _mm256_loadu_si256(reinterpret_cast(src_ptr3)); + t7 = _mm256_loadu_si256(reinterpret_cast(src_ptr7)); + + r0 = _mm256_unpacklo_epi32(t0, t1); + r4 = _mm256_unpacklo_epi32(t4, t5); + r2 = _mm256_unpackhi_epi32(t0, t1); + r6 = _mm256_unpackhi_epi32(t4, t5); + r1 = _mm256_unpacklo_epi32(t2, t3); + r5 = _mm256_unpacklo_epi32(t6, t7); + r3 = _mm256_unpackhi_epi32(t2, t3); + r7 = _mm256_unpackhi_epi32(t6, t7); + + t0 = _mm256_unpacklo_epi64(r0, r1); + t4 = _mm256_unpacklo_epi64(r4, r5); + t2 = _mm256_unpackhi_epi64(r0, r1); + t6 = _mm256_unpackhi_epi64(r4, r5); + t1 = _mm256_unpacklo_epi64(r2, r3); + t5 = _mm256_unpacklo_epi64(r6, r7); + t3 = _mm256_unpackhi_epi64(r2, r3); + t7 = _mm256_unpackhi_epi64(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by + // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, + // t4) are interleaved to create (r0, r1). This complexity follows from + // the way that AVX is centered around MM 128-bit lanes. 
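+        // In _mm256_permute2x128_si256, immediate 0x20 concatenates the low
+        // 128-bit lanes of its two operands and 0x31 concatenates the high
+        // 128-bit lanes.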
+ r0 = _mm256_permute2x128_si256(t0, t4, 0x20); + r4 = _mm256_permute2x128_si256(t1, t5, 0x20); + r1 = _mm256_permute2x128_si256(t0, t4, 0x31); + r5 = _mm256_permute2x128_si256(t1, t5, 0x31); + r2 = _mm256_permute2x128_si256(t2, t6, 0x20); + r6 = _mm256_permute2x128_si256(t3, t7, 0x20); + r3 = _mm256_permute2x128_si256(t2, t6, 0x31); + r7 = _mm256_permute2x128_si256(t3, t7, 0x31); + + r0 = _mm256_xor_si256(r0, input_xor_v); + r1 = _mm256_xor_si256(r1, input_xor_v); + r2 = _mm256_xor_si256(r2, input_xor_v); + r3 = _mm256_xor_si256(r3, input_xor_v); + r4 = _mm256_xor_si256(r4, input_xor_v); + r5 = _mm256_xor_si256(r5, input_xor_v); + r6 = _mm256_xor_si256(r6, input_xor_v); + r7 = _mm256_xor_si256(r7, input_xor_v); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 0 * 8 * 4), + r0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 2 * 8 * 4), + r4); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 4 * 8 * 4), + r1); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 6 * 8 * 4), + r5); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 1 * 8 * 4), + r2); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 3 * 8 * 4), + r6); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 5 * 8 * 4), + r3); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(packed_ptr + 7 * 8 * 4), + r7); + } + } else if (available_src_rows > 0) { + RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows); + // We do not care what goes into the trailing buffer, but we want + // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. + // + // We compensate for padding-with-zero_point by initializing the + // summations with the compensating offset, effectively + // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * + // 4 * (8 - ((available_src_rows + 3) >> 2)). + // + // Note that (zero_point ^ input_xor) is performed in 8-bits and then + // cast. + sums_adjustment += + -(zero_point ^ input_xor) * 4 * (8 - ((available_src_rows + 3) >> 2)); + + __m256i t0, t1, t2, t3, t4, t5, t6, t7; + __m256i r0, r1, r2, r3, r4, r5, r6, r7; + const __m256i input_xor_v = _mm256_set1_epi8(input_xor); + + t0 = MaskLoadu(available_src_rows, zero_point, src_ptr0); + t4 = MaskLoadu(available_src_rows, zero_point, src_ptr4); + t1 = MaskLoadu(available_src_rows, zero_point, src_ptr1); + t5 = MaskLoadu(available_src_rows, zero_point, src_ptr5); + t2 = MaskLoadu(available_src_rows, zero_point, src_ptr2); + t6 = MaskLoadu(available_src_rows, zero_point, src_ptr6); + t3 = MaskLoadu(available_src_rows, zero_point, src_ptr3); + t7 = MaskLoadu(available_src_rows, zero_point, src_ptr7); + + r0 = _mm256_unpacklo_epi32(t0, t1); + r4 = _mm256_unpacklo_epi32(t4, t5); + r2 = _mm256_unpackhi_epi32(t0, t1); + r6 = _mm256_unpackhi_epi32(t4, t5); + r1 = _mm256_unpacklo_epi32(t2, t3); + r5 = _mm256_unpacklo_epi32(t6, t7); + r3 = _mm256_unpackhi_epi32(t2, t3); + r7 = _mm256_unpackhi_epi32(t6, t7); + + t0 = _mm256_unpacklo_epi64(r0, r1); + t4 = _mm256_unpacklo_epi64(r4, r5); + t2 = _mm256_unpackhi_epi64(r0, r1); + t6 = _mm256_unpackhi_epi64(r4, r5); + t1 = _mm256_unpacklo_epi64(r2, r3); + t5 = _mm256_unpacklo_epi64(r6, r7); + t3 = _mm256_unpackhi_epi64(r2, r3); + t7 = _mm256_unpackhi_epi64(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by + // 16 bytes (128-bit), operating *between* AVX lanes. For instance (t0, + // t4) are interleaved to create (r0, r1). 
This complexity follows from + // the way that AVX is centered around MM 128-bit lanes. + r0 = _mm256_permute2x128_si256(t0, t4, 0x20); + r4 = _mm256_permute2x128_si256(t1, t5, 0x20); + r1 = _mm256_permute2x128_si256(t0, t4, 0x31); + r5 = _mm256_permute2x128_si256(t1, t5, 0x31); + r2 = _mm256_permute2x128_si256(t2, t6, 0x20); + r6 = _mm256_permute2x128_si256(t3, t7, 0x20); + r3 = _mm256_permute2x128_si256(t2, t6, 0x31); + r7 = _mm256_permute2x128_si256(t3, t7, 0x31); + + r0 = _mm256_xor_si256(r0, input_xor_v); + r1 = _mm256_xor_si256(r1, input_xor_v); + r2 = _mm256_xor_si256(r2, input_xor_v); + r3 = _mm256_xor_si256(r3, input_xor_v); + r4 = _mm256_xor_si256(r4, input_xor_v); + r5 = _mm256_xor_si256(r5, input_xor_v); + r6 = _mm256_xor_si256(r6, input_xor_v); + r7 = _mm256_xor_si256(r7, input_xor_v); + + __m256i sums_4x4_16bit_lo; + sums_4x4_16bit_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r0)); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r1))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r2))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r3))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r4))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r5))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r6))); + sums_4x4_16bit_lo = _mm256_add_epi16( + sums_4x4_16bit_lo, _mm256_cvtepi8_epi16(_mm256_castsi256_si128(r7))); + + // The sums have been performed across columns, and now we have 4x16-bit + // sums packed together. We use madd for pairwise 32-bit sums. 
+      const __m256i sums_4x2_32bit_lo_new =
+          _mm256_madd_epi16(sums_4x4_16bit_lo, ones_16bit);
+      sums_4x2_32bit_lo =
+          _mm256_add_epi32(sums_4x2_32bit_lo, sums_4x2_32bit_lo_new);
+
+      __m256i sums_4x4_16bit_hi;
+      sums_4x4_16bit_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r0, 1));
+      sums_4x4_16bit_hi = _mm256_add_epi16(
+          sums_4x4_16bit_hi,
+          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r1, 1)));
+      sums_4x4_16bit_hi = _mm256_add_epi16(
+          sums_4x4_16bit_hi,
+          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r2, 1)));
+      sums_4x4_16bit_hi = _mm256_add_epi16(
+          sums_4x4_16bit_hi,
+          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r3, 1)));
+      sums_4x4_16bit_hi = _mm256_add_epi16(
+          sums_4x4_16bit_hi,
+          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r4, 1)));
+      sums_4x4_16bit_hi = _mm256_add_epi16(
+          sums_4x4_16bit_hi,
+          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r5, 1)));
+      sums_4x4_16bit_hi = _mm256_add_epi16(
+          sums_4x4_16bit_hi,
+          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r6, 1)));
+      sums_4x4_16bit_hi = _mm256_add_epi16(
+          sums_4x4_16bit_hi,
+          _mm256_cvtepi8_epi16(_mm256_extracti128_si256(r7, 1)));
+
+      const __m256i sums_4x2_32bit_hi_new =
+          _mm256_madd_epi16(sums_4x4_16bit_hi, ones_16bit);
+      sums_4x2_32bit_hi =
+          _mm256_add_epi32(sums_4x2_32bit_hi, sums_4x2_32bit_hi_new);
+
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 0 * 8 * 4),
+                          r0);
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 2 * 8 * 4),
+                          r4);
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 4 * 8 * 4),
+                          r1);
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 6 * 8 * 4),
+                          r5);
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 1 * 8 * 4),
+                          r2);
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 3 * 8 * 4),
+                          r6);
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 5 * 8 * 4),
+                          r3);
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(trailing_buf + 7 * 8 * 4),
+                          r7);
+    }
+
+    packed_ptr += 8 * kNumChunkedSrcRows;
+    src_ptr0 += src_inc0;
+    src_ptr1 += src_inc1;
+    src_ptr2 += src_inc2;
+    src_ptr3 += src_inc3;
+    src_ptr4 += src_inc4;
+    src_ptr5 += src_inc5;
+    src_ptr6 += src_inc6;
+    src_ptr7 += src_inc7;
+  }
+
+  if (sums_ptr) {
+    const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment);
+
+    __m256i sums =
+        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(sums_ptr));
+    const __m256i idx = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
+
+    // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the
+    // neighbours, finishing up by adding them to the stored accumulated sums.
+ const __m256i sums_2x4_32bit_lo = + _mm256_permutevar8x32_epi32(sums_4x2_32bit_lo, idx); + const __m256i sums_2x4_32bit_hi = + _mm256_permutevar8x32_epi32(sums_4x2_32bit_hi, idx); + const __m256i sums_2x4_32bit_a = + _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x20); + const __m256i sums_2x4_32bit_b = + _mm256_permute2x128_si256(sums_2x4_32bit_lo, sums_2x4_32bit_hi, 0x31); + sums = _mm256_add_epi32(sums, sums_adjustment_v); + sums = _mm256_add_epi32(sums, sums_2x4_32bit_a); + sums = _mm256_add_epi32(sums, sums_2x4_32bit_b); + + _mm256_storeu_si256(reinterpret_cast<__m256i*>(sums_ptr), sums); + } +} + +inline __m256 Mm256UnpackloPsx2(const __m256 a, const __m256 b) { + return _mm256_castpd_ps( + _mm256_unpacklo_pd(_mm256_castps_pd(a), _mm256_castps_pd(b))); +} + +inline __m256 Mm256UnpackhiPsx2(const __m256 a, const __m256 b) { + return _mm256_castpd_ps( + _mm256_unpackhi_pd(_mm256_castps_pd(a), _mm256_castps_pd(b))); +} + +inline void PackFloatAvx2Packer(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, float* packed_ptr, + float* trailing_buf) { + RUY_DCHECK_EQ(PackImplFloatAvx2::Layout::kCols, 8); + RUY_DCHECK_EQ(PackImplFloatAvx2::Layout::kRows, 1); + + // This packing amounts to transposition of 8x8 blocks. + static constexpr int kPackCols = 8; // Source cols packed together. + static constexpr int kPackRows = 8; // Short input is padded. + + const float* src_ptr0 = src_ptr; + const float* src_ptr1 = src_ptr0 + src_stride; + const float* src_ptr2 = src_ptr1 + src_stride; + const float* src_ptr3 = src_ptr2 + src_stride; + const float* src_ptr4 = src_ptr3 + src_stride; + const float* src_ptr5 = src_ptr4 + src_stride; + const float* src_ptr6 = src_ptr5 + src_stride; + const float* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = 8; + std::int64_t src_inc1 = 8; + std::int64_t src_inc2 = 8; + std::int64_t src_inc3 = 8; + std::int64_t src_inc4 = 8; + std::int64_t src_inc5 = 8; + std::int64_t src_inc6 = 8; + std::int64_t src_inc7 = 8; + // Handle cases where source does not have kPackDim (8) columns. + if (remaining_src_cols < kPackCols) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + for (int k = 0; k < src_rows; k += kPackRows) { + const int available_src_rows = src_rows - k; + // Effectively, + // available_src_rows = std::max(0, std::min(kPackDim, src_rows - k)); + // but treat each case separately. 
+ if (available_src_rows >= kPackRows) { + __m256 t0, t1, t2, t3, t4, t5, t6, t7; + __m256 r0, r1, r2, r3, r4, r5, r6, r7; + + t0 = _mm256_loadu_ps(src_ptr0); + t4 = _mm256_loadu_ps(src_ptr4); + t1 = _mm256_loadu_ps(src_ptr1); + t5 = _mm256_loadu_ps(src_ptr5); + t2 = _mm256_loadu_ps(src_ptr2); + t6 = _mm256_loadu_ps(src_ptr6); + t3 = _mm256_loadu_ps(src_ptr3); + t7 = _mm256_loadu_ps(src_ptr7); + + r0 = _mm256_unpacklo_ps(t0, t1); + r4 = _mm256_unpacklo_ps(t4, t5); + r2 = _mm256_unpackhi_ps(t0, t1); + r6 = _mm256_unpackhi_ps(t4, t5); + r1 = _mm256_unpacklo_ps(t2, t3); + r5 = _mm256_unpacklo_ps(t6, t7); + r3 = _mm256_unpackhi_ps(t2, t3); + r7 = _mm256_unpackhi_ps(t6, t7); + + t0 = Mm256UnpackloPsx2(r0, r1); + t4 = Mm256UnpackloPsx2(r4, r5); + t2 = Mm256UnpackhiPsx2(r0, r1); + t6 = Mm256UnpackhiPsx2(r4, r5); + t1 = Mm256UnpackloPsx2(r2, r3); + t5 = Mm256UnpackloPsx2(r6, r7); + t3 = Mm256UnpackhiPsx2(r2, r3); + t7 = Mm256UnpackhiPsx2(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by 16 + // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4) + // are interleaved to create (r0, r1). This complexity follows from the + // way that AVX is centered around MM 128-bit lanes. + r0 = _mm256_permute2f128_ps(t0, t4, 0x20); + r4 = _mm256_permute2f128_ps(t1, t5, 0x20); + r1 = _mm256_permute2f128_ps(t0, t4, 0x31); + r5 = _mm256_permute2f128_ps(t1, t5, 0x31); + r2 = _mm256_permute2f128_ps(t2, t6, 0x20); + r6 = _mm256_permute2f128_ps(t3, t7, 0x20); + r3 = _mm256_permute2f128_ps(t2, t6, 0x31); + r7 = _mm256_permute2f128_ps(t3, t7, 0x31); + + _mm256_storeu_ps(packed_ptr + 0 * 8, r0); + _mm256_storeu_ps(packed_ptr + 2 * 8, r4); + _mm256_storeu_ps(packed_ptr + 4 * 8, r1); + _mm256_storeu_ps(packed_ptr + 6 * 8, r5); + _mm256_storeu_ps(packed_ptr + 1 * 8, r2); + _mm256_storeu_ps(packed_ptr + 3 * 8, r6); + _mm256_storeu_ps(packed_ptr + 5 * 8, r3); + _mm256_storeu_ps(packed_ptr + 7 * 8, r7); + } else if (available_src_rows > 0) { + const __m256i series = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + const __m256i row_mask_v = + _mm256_cmpgt_epi32(_mm256_set1_epi32(available_src_rows), series); + + __m256 t0, t1, t2, t3, t4, t5, t6, t7; + __m256 r0, r1, r2, r3, r4, r5, r6, r7; + + t0 = _mm256_maskload_ps(src_ptr0, row_mask_v); + t4 = _mm256_maskload_ps(src_ptr4, row_mask_v); + t1 = _mm256_maskload_ps(src_ptr1, row_mask_v); + t5 = _mm256_maskload_ps(src_ptr5, row_mask_v); + t2 = _mm256_maskload_ps(src_ptr2, row_mask_v); + t6 = _mm256_maskload_ps(src_ptr6, row_mask_v); + t3 = _mm256_maskload_ps(src_ptr3, row_mask_v); + t7 = _mm256_maskload_ps(src_ptr7, row_mask_v); + + r0 = _mm256_unpacklo_ps(t0, t1); + r4 = _mm256_unpacklo_ps(t4, t5); + r2 = _mm256_unpackhi_ps(t0, t1); + r6 = _mm256_unpackhi_ps(t4, t5); + r1 = _mm256_unpacklo_ps(t2, t3); + r5 = _mm256_unpacklo_ps(t6, t7); + r3 = _mm256_unpackhi_ps(t2, t3); + r7 = _mm256_unpackhi_ps(t6, t7); + + t0 = Mm256UnpackloPsx2(r0, r1); + t4 = Mm256UnpackloPsx2(r4, r5); + t2 = Mm256UnpackhiPsx2(r0, r1); + t6 = Mm256UnpackhiPsx2(r4, r5); + t1 = Mm256UnpackloPsx2(r2, r3); + t5 = Mm256UnpackloPsx2(r6, r7); + t3 = Mm256UnpackhiPsx2(r2, r3); + t7 = Mm256UnpackhiPsx2(r6, r7); + + // The preceding sets of rearrangement operations interleaved by 4 bytes + // and then by 8 bytes *within* lanes. The following set interleave by 16 + // bytes (128-bit), operating *between* AVX lanes. For instance (t0, t4) + // are interleaved to create (r0, r1). 
This complexity follows from the + // way that AVX is centered around MM 128-bit lanes. + r0 = _mm256_permute2f128_ps(t0, t4, 0x20); + r4 = _mm256_permute2f128_ps(t1, t5, 0x20); + r1 = _mm256_permute2f128_ps(t0, t4, 0x31); + r5 = _mm256_permute2f128_ps(t1, t5, 0x31); + r2 = _mm256_permute2f128_ps(t2, t6, 0x20); + r6 = _mm256_permute2f128_ps(t3, t7, 0x20); + r3 = _mm256_permute2f128_ps(t2, t6, 0x31); + // r7 no longer needed. + + _mm256_storeu_ps(trailing_buf + 0 * 8, r0); + _mm256_storeu_ps(trailing_buf + 2 * 8, r4); + _mm256_storeu_ps(trailing_buf + 4 * 8, r1); + _mm256_storeu_ps(trailing_buf + 6 * 8, r5); + _mm256_storeu_ps(trailing_buf + 1 * 8, r2); + _mm256_storeu_ps(trailing_buf + 3 * 8, r6); + _mm256_storeu_ps(trailing_buf + 5 * 8, r3); + // No store to (trailing_buf + 7 * 8), space not allocated. + } + + packed_ptr += kPackRows * kPackCols; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } +} + +} // namespace. + +void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, + std::int32_t* sums_ptr) { + profiler::ScopeLabel label("Pack kAvx2 8bit"); + + using Layout = PackImpl8bitAvx2::Layout; + RUY_DCHECK_EQ(Layout::kCols, 8); + RUY_DCHECK_EQ(Layout::kRows, 4); + + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. + static constexpr int kNumRowChunks = 8; // Short input is padded. + + // Each packed block is 4*8, and there are normally 8. The trailing block is + // only slightly shorter. + constexpr int kTrailingBufSize = + kNumRowChunks * Layout::kCols * Layout::kRows; + std::int8_t trailing_buf[kTrailingBufSize]; + memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); + + Pack8bitAvx2Packer(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, sums_ptr, + trailing_buf); + + constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; + const bool trailing_data = (src_rows & kChunkedRowMask) > 0; + // If the number of source rows is not a multiple of kChunkedRowMask, there + // will be data in the trailing buffer, + if (trailing_data > 0) { + const int non_trailing_rows = src_rows & ~kChunkedRowMask; + // Destination "rows" are padded to next highest multiple of Layout::kRows. + const int dst_rows = (src_rows + 3) & ~3; + const int trailing_rows = dst_rows - non_trailing_rows; + memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, + Layout::kCols * trailing_rows * sizeof(std::int8_t)); + } +} + +void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, float* packed_ptr) { + profiler::ScopeLabel label("Pack kAvx2 float"); + static constexpr int kPackCols = 8; // Source cols packed together. + static constexpr int kPackRows = 8; // Short input is padded. 
+  float trailing_buf[(kPackRows - 1) * kPackCols];
+  if (remaining_src_cols < 8) {
+    memset(trailing_buf, 0, sizeof(trailing_buf));
+  }
+  PackFloatAvx2Packer(src_ptr, zerobuf, src_stride, remaining_src_cols,
+                      src_rows, packed_ptr, trailing_buf);
+
+  const int trailing_rows = src_rows & (kPackRows - 1);
+  if (trailing_rows > 0) {
+    const int non_trailing_rows = src_rows & ~(kPackRows - 1);
+    memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf,
+           kPackCols * trailing_rows * sizeof(float));
+  }
+}
+
+#endif  // RUY_PLATFORM(AVX2) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS)
+
+}  // namespace ruy
diff --git a/ruy/pack_avx512.cc b/ruy/pack_avx512.cc
new file mode 100644
index 0000000..ecad3a2
--- /dev/null
+++ b/ruy/pack_avx512.cc
@@ -0,0 +1,693 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <cstring>
+
+#include "ruy/check_macros.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS)
+#include <immintrin.h>  // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor,
+                    const std::int8_t* zerobuf, int src_stride,
+                    int remaining_src_cols, int src_rows,
+                    std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
+  // CPU-ID-based checks should disable the path that would reach this point.
+  RUY_DCHECK(false);
+}
+
+void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride,
+                     int remaining_src_cols, int src_rows, float* packed_ptr) {
+  // CPU-ID-based checks should disable the path that would reach this point.
+  RUY_DCHECK(false);
+}
+
+#else  // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+// The first int8_t template parameter is arbitrary: this routine is common to
+// all 8-bit source matrix types.
+using PackImpl8bitAvx512 =
+    PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
+             std::int8_t, std::int8_t, std::int32_t>;
+
+namespace {
+
+inline void ZeroHalf8bitAvx512(int src_rows, std::int8_t packed_zero_point,
+                               std::int8_t* packed_ptr) {
+  using Layout = PackImpl8bitAvx512::Layout;
+  static constexpr int kHalfLayoutCols =
+      PackImpl8bitAvx512::kHalfLayoutCols;  // Half the number of cols in a
+                                            // block.
+  RUY_DCHECK_EQ(kHalfLayoutCols, 8);
+  RUY_DCHECK_EQ(Layout::kCols, 16);
+  RUY_DCHECK_EQ(Layout::kRows, 4);
+
+  const int non_trailing_blocks = (src_rows & ~31) >> 2;
+  // This routine fills half blocks, and typically fills the second halves.
+  // Thus packed_ptr is already offset by 8 * 4.
+ for (int k = 0; k < non_trailing_blocks; ++k) { + for (int j = 0; j < (kHalfLayoutCols * Layout::kRows); ++j) { + packed_ptr[Layout::kCols * Layout::kRows * k + j] = packed_zero_point; + } + } +} + +inline __m512i LoaduTwo(const std::int8_t* addr_lo, + const std::int8_t* addr_hi) { + __m512i lower_filled = _mm512_castsi256_si512(_mm256_loadu_epi8(addr_lo)); + return _mm512_inserti32x8(lower_filled, _mm256_loadu_epi8(addr_hi), 1); +} + +inline __m512i MaskLoaduTwo(__mmask32 row_mask, const __m256i default_value_v, + const std::int8_t* addr_lo, + const std::int8_t* addr_hi) { + const __m512i lower_filled = _mm512_castsi256_si512( + _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_lo)); + return _mm512_inserti32x8( + lower_filled, _mm256_mask_loadu_epi8(default_value_v, row_mask, addr_hi), + 1); +} + +inline void HalfPack8bitAvx512(const std::int8_t* src_ptr, + std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr, + std::int8_t* trailing_buf) { + using Layout = PackImpl8bitAvx512::Layout; + RUY_DCHECK_EQ(Layout::kCols, 16); + RUY_DCHECK_EQ(Layout::kRows, 4); + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. + constexpr int kNumRowChunks = 8; + constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; + + const std::int8_t* src_ptr0 = src_ptr; + const std::int8_t* src_ptr1 = src_ptr0 + src_stride; + const std::int8_t* src_ptr2 = src_ptr1 + src_stride; + const std::int8_t* src_ptr3 = src_ptr2 + src_stride; + const std::int8_t* src_ptr4 = src_ptr3 + src_stride; + const std::int8_t* src_ptr5 = src_ptr4 + src_stride; + const std::int8_t* src_ptr6 = src_ptr5 + src_stride; + const std::int8_t* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = kNumChunkedSrcRows; + std::int64_t src_inc1 = kNumChunkedSrcRows; + std::int64_t src_inc2 = kNumChunkedSrcRows; + std::int64_t src_inc3 = kNumChunkedSrcRows; + std::int64_t src_inc4 = kNumChunkedSrcRows; + std::int64_t src_inc5 = kNumChunkedSrcRows; + std::int64_t src_inc6 = kNumChunkedSrcRows; + std::int64_t src_inc7 = kNumChunkedSrcRows; + // Handle cases where source does not have kHalfLayoutCols (8) columns. + if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + const std::int8_t zero_point = zerobuf[0]; + + if (sums_ptr) { + // i: kHalfLayoutCols. + for (int i = 0; i < 8; ++i) { + sums_ptr[i] = 0; + } + } + std::int32_t sums_adjustment = 0; + const __m512i ones_16bit = _mm512_set1_epi16(1); + __m512i sums_8x2_32bit = _mm512_set1_epi32(0); + + // The overall packing effectively pads the source rows to + // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we + // only pack for (src_rows + 31) & ~31. When there is an incomplete + // destination block, this is stored into trailing_buf instead of packed_ptr. 
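+  // For example (illustrative numbers only): with src_rows = 70, k takes the
+  // values 0 and 64. At k = 64, m = 0 sees 6 available source rows, which go
+  // into trailing_buf, and m = 1 sees a negative count and is skipped.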
+ for (int k = 0; k < src_rows; k += 2 * kNumChunkedSrcRows) { + // m: {0, 1} for 2 chunks of rows. + for (int m = 0; m < 2; ++m) { + // Available source rows. + // If this is less than 0 (for m=1), we skip, having filled trailing + // buffer for m=0. Also, if source rows is zero on m=1, then we filled + // exactly to the end of the column in the packed buffer. + const int available_src_rows = src_rows - k - m * kNumChunkedSrcRows; + // Effectively, + // available rows = std::max(0, std::min(8, src_rows - k - 8 * 4 * m)); + // treat each case separately. + if (available_src_rows >= kNumChunkedSrcRows) { + // i: chunks, s: Layout::Rows. + if (sums_ptr) { + __m512i t0, t1, t2, t3; + __m512i r0, r1, r2, r3; + const __m512i input_xor_v = _mm512_set1_epi8(input_xor); + + t0 = LoaduTwo(src_ptr0, src_ptr4); + t1 = LoaduTwo(src_ptr1, src_ptr5); + t2 = LoaduTwo(src_ptr2, src_ptr6); + t3 = LoaduTwo(src_ptr3, src_ptr7); + + r0 = _mm512_unpacklo_epi32(t0, t1); + r2 = _mm512_unpackhi_epi32(t0, t1); + r1 = _mm512_unpacklo_epi32(t2, t3); + r3 = _mm512_unpackhi_epi32(t2, t3); + + t0 = _mm512_unpacklo_epi64(r0, r1); + t2 = _mm512_unpackhi_epi64(r0, r1); + t1 = _mm512_unpacklo_epi64(r2, r3); + t3 = _mm512_unpackhi_epi64(r2, r3); + + r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); + r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); + r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); + + r0 = _mm512_xor_si512(r0, input_xor_v); + r1 = _mm512_xor_si512(r1, input_xor_v); + r2 = _mm512_xor_si512(r2, input_xor_v); + r3 = _mm512_xor_si512(r3, input_xor_v); + + const __m256i r0_0 = _mm512_castsi512_si256(r0); + const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); + const __m256i r1_0 = _mm512_castsi512_si256(r1); + const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); + const __m256i r2_0 = _mm512_castsi512_si256(r2); + const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); + const __m256i r3_0 = _mm512_castsi512_si256(r3); + const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); + + __m512i sums_8x4_16bit; + sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1)); + // The sums have been performed across columns, and now we have + // 4x16-bit sums packed together. We use madd for pairwise 32-bit + // sums. 
+ const __m512i sums_8x2_32bit_new = + _mm512_madd_epi16(sums_8x4_16bit, ones_16bit); + sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new); + + _mm256_storeu_epi8(packed_ptr + 0 * 16 * 4, r0_0); + _mm256_storeu_epi8(packed_ptr + 2 * 16 * 4, r0_1); + _mm256_storeu_epi8(packed_ptr + 4 * 16 * 4, r1_0); + _mm256_storeu_epi8(packed_ptr + 6 * 16 * 4, r1_1); + _mm256_storeu_epi8(packed_ptr + 1 * 16 * 4, r2_0); + _mm256_storeu_epi8(packed_ptr + 3 * 16 * 4, r2_1); + _mm256_storeu_epi8(packed_ptr + 5 * 16 * 4, r3_0); + _mm256_storeu_epi8(packed_ptr + 7 * 16 * 4, r3_1); + } else { + __m512i t0, t1, t2, t3; + __m512i r0, r1, r2, r3; + const __m512i input_xor_v = _mm512_set1_epi8(input_xor); + + t0 = LoaduTwo(src_ptr0, src_ptr4); + t1 = LoaduTwo(src_ptr1, src_ptr5); + t2 = LoaduTwo(src_ptr2, src_ptr6); + t3 = LoaduTwo(src_ptr3, src_ptr7); + + r0 = _mm512_unpacklo_epi32(t0, t1); + r2 = _mm512_unpackhi_epi32(t0, t1); + r1 = _mm512_unpacklo_epi32(t2, t3); + r3 = _mm512_unpackhi_epi32(t2, t3); + + t0 = _mm512_unpacklo_epi64(r0, r1); + t2 = _mm512_unpackhi_epi64(r0, r1); + t1 = _mm512_unpacklo_epi64(r2, r3); + t3 = _mm512_unpackhi_epi64(r2, r3); + + r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); + r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); + r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); + + r0 = _mm512_xor_si512(r0, input_xor_v); + r1 = _mm512_xor_si512(r1, input_xor_v); + r2 = _mm512_xor_si512(r2, input_xor_v); + r3 = _mm512_xor_si512(r3, input_xor_v); + + const __m256i r0_0 = _mm512_castsi512_si256(r0); + const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); + const __m256i r1_0 = _mm512_castsi512_si256(r1); + const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); + const __m256i r2_0 = _mm512_castsi512_si256(r2); + const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); + const __m256i r3_0 = _mm512_castsi512_si256(r3); + const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); + _mm256_storeu_epi8(packed_ptr + 0 * 16 * 4, r0_0); + _mm256_storeu_epi8(packed_ptr + 2 * 16 * 4, r0_1); + _mm256_storeu_epi8(packed_ptr + 4 * 16 * 4, r1_0); + _mm256_storeu_epi8(packed_ptr + 6 * 16 * 4, r1_1); + _mm256_storeu_epi8(packed_ptr + 1 * 16 * 4, r2_0); + _mm256_storeu_epi8(packed_ptr + 3 * 16 * 4, r2_1); + _mm256_storeu_epi8(packed_ptr + 5 * 16 * 4, r3_0); + _mm256_storeu_epi8(packed_ptr + 7 * 16 * 4, r3_1); + } + } else if (available_src_rows > 0) { + RUY_DCHECK_LT(available_src_rows >> 2, kNumChunkedSrcRows); + const __mmask32 row_mask = + (static_cast(1) << available_src_rows) - 1; + + // We do not care what goes into the trailing buffer, but we want + // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. + // + // We compensate for padding-with-zero_point by initializing the + // summations with the compensating offset, effectively + // ((input_xor ^ input_xor) - (zero_point ^ input_xor)) * + // 4 * (8 - ((available_src_rows + 3) >> 2)). + // + // Note that (zero_point ^ input_xor) is performed in 8-bits and then + // cast. 
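+        // For example (illustrative values only): with zero_point = 0,
+        // input_xor = -128 and available_src_rows = 5, there are
+        // 8 - ((5 + 3) >> 2) = 6 fully padded 4-row groups per column, so the
+        // adjustment is -(0 ^ -128) * 4 * 6 = 3072.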
+ sums_adjustment += -(zero_point ^ input_xor) * 4 * + (8 - ((available_src_rows + 3) >> 2)); + + __m512i t0, t1, t2, t3; + __m512i r0, r1, r2, r3; + const __m512i input_xor_v = _mm512_set1_epi8(input_xor); + const __m256i zero_point_v = _mm256_set1_epi8(zero_point); + + t0 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr0, src_ptr4); + t1 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr1, src_ptr5); + t2 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr2, src_ptr6); + t3 = MaskLoaduTwo(row_mask, zero_point_v, src_ptr3, src_ptr7); + + r0 = _mm512_unpacklo_epi32(t0, t1); + r2 = _mm512_unpackhi_epi32(t0, t1); + r1 = _mm512_unpacklo_epi32(t2, t3); + r3 = _mm512_unpackhi_epi32(t2, t3); + + t0 = _mm512_unpacklo_epi64(r0, r1); + t2 = _mm512_unpackhi_epi64(r0, r1); + t1 = _mm512_unpacklo_epi64(r2, r3); + t3 = _mm512_unpackhi_epi64(r2, r3); + + r0 = _mm512_shuffle_i32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_i32x4(t0, t1, 0xdd); + r2 = _mm512_shuffle_i32x4(t2, t3, 0x88); + r3 = _mm512_shuffle_i32x4(t2, t3, 0xdd); + + r0 = _mm512_xor_si512(r0, input_xor_v); + r1 = _mm512_xor_si512(r1, input_xor_v); + r2 = _mm512_xor_si512(r2, input_xor_v); + r3 = _mm512_xor_si512(r3, input_xor_v); + + const __m256i r0_0 = _mm512_castsi512_si256(r0); + const __m256i r0_1 = _mm512_extracti32x8_epi32(r0, 1); + const __m256i r1_0 = _mm512_castsi512_si256(r1); + const __m256i r1_1 = _mm512_extracti32x8_epi32(r1, 1); + const __m256i r2_0 = _mm512_castsi512_si256(r2); + const __m256i r2_1 = _mm512_extracti32x8_epi32(r2, 1); + const __m256i r3_0 = _mm512_castsi512_si256(r3); + const __m256i r3_1 = _mm512_extracti32x8_epi32(r3, 1); + + __m512i sums_8x4_16bit; + sums_8x4_16bit = _mm512_cvtepi8_epi16(r0_0); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r0_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r1_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r2_1)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_0)); + sums_8x4_16bit = + _mm512_add_epi16(sums_8x4_16bit, _mm512_cvtepi8_epi16(r3_1)); + // The sums have been performed across columns, and now we have + // 4x16-bit sums packed together. We use madd for pairwise 32-bit + // sums. 
+ const __m512i sums_8x2_32bit_new = + _mm512_madd_epi16(sums_8x4_16bit, ones_16bit); + sums_8x2_32bit = _mm512_add_epi32(sums_8x2_32bit, sums_8x2_32bit_new); + + _mm256_storeu_epi8(trailing_buf + 0 * 16 * 4, r0_0); + _mm256_storeu_epi8(trailing_buf + 2 * 16 * 4, r0_1); + _mm256_storeu_epi8(trailing_buf + 4 * 16 * 4, r1_0); + _mm256_storeu_epi8(trailing_buf + 6 * 16 * 4, r1_1); + _mm256_storeu_epi8(trailing_buf + 1 * 16 * 4, r2_0); + _mm256_storeu_epi8(trailing_buf + 3 * 16 * 4, r2_1); + _mm256_storeu_epi8(trailing_buf + 5 * 16 * 4, r3_0); + _mm256_storeu_epi8(trailing_buf + 7 * 16 * 4, r3_1); + } + + packed_ptr += 16 * kNumChunkedSrcRows; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } + } + + if (sums_ptr) { + const __m256i sums_adjustment_v = _mm256_set1_epi32(sums_adjustment); + + __m256i sums = _mm256_loadu_epi32(sums_ptr); + const __m512i idx = + _mm512_set_epi32(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0); + + // We earlier used madd for pairwise 32-bit sums, and now we deinterlace the + // neighbours, finshing up by adding them to the stored accumulated sums. + const __m512i sums_2x8_32bit = + _mm512_permutexvar_epi32(idx, sums_8x2_32bit); + sums = _mm256_add_epi32(sums, sums_adjustment_v); + sums = _mm256_add_epi32(sums, _mm512_castsi512_si256(sums_2x8_32bit)); + sums = _mm256_add_epi32(sums, _mm512_extracti32x8_epi32(sums_2x8_32bit, 1)); + + _mm256_storeu_epi32(sums_ptr, sums); + } +} + +inline __m512 LoaduTwo(const float* addr_lo, const float* addr_hi) { + const __m512 lower_filled = _mm512_castps256_ps512(_mm256_loadu_ps(addr_lo)); + return _mm512_insertf32x8(lower_filled, _mm256_loadu_ps(addr_hi), 1); +} + +inline __m512 MaskLoaduTwo(__mmask8 row_mask, const float* addr_lo, + const float* addr_hi) { + const __m512 lower_filled = + _mm512_castps256_ps512(_mm256_maskz_loadu_ps(row_mask, addr_lo)); + return _mm512_insertf32x8(lower_filled, + _mm256_maskz_loadu_ps(row_mask, addr_hi), 1); +} + +inline __m512 Mm512UnpackloPsx2(const __m512 a, const __m512 b) { + return _mm512_castpd_ps( + _mm512_unpacklo_pd(_mm512_castps_pd(a), _mm512_castps_pd(b))); +} + +inline __m512 Mm512UnpackhiPsx2(const __m512 a, const __m512 b) { + return _mm512_castpd_ps( + _mm512_unpackhi_pd(_mm512_castps_pd(a), _mm512_castps_pd(b))); +} + +inline void HalfPackFloatAvx512(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, float* packed_ptr, + float* trailing_buf) { + const float* src_ptr0 = src_ptr; + const float* src_ptr1 = src_ptr0 + src_stride; + const float* src_ptr2 = src_ptr1 + src_stride; + const float* src_ptr3 = src_ptr2 + src_stride; + const float* src_ptr4 = src_ptr3 + src_stride; + const float* src_ptr5 = src_ptr4 + src_stride; + const float* src_ptr6 = src_ptr5 + src_stride; + const float* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = 8; + std::int64_t src_inc1 = 8; + std::int64_t src_inc2 = 8; + std::int64_t src_inc3 = 8; + std::int64_t src_inc4 = 8; + std::int64_t src_inc5 = 8; + std::int64_t src_inc6 = 8; + std::int64_t src_inc7 = 8; + if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if 
(remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + for (int k = 0; k < src_rows; k += 16) { + for (int m = 0; m < 2; ++m) { + const int available_src_rows = src_rows - k - 8 * m; + // Effectively, + // available_src_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); + // but treat each case separately. + if (available_src_rows > 7) { + __m512 t0, t1, t2, t3; + __m512 r0, r1, r2, r3; + + t0 = LoaduTwo(src_ptr0, src_ptr4); + t1 = LoaduTwo(src_ptr1, src_ptr5); + t2 = LoaduTwo(src_ptr2, src_ptr6); + t3 = LoaduTwo(src_ptr3, src_ptr7); + + r0 = _mm512_unpacklo_ps(t0, t1); + r2 = _mm512_unpackhi_ps(t0, t1); + r1 = _mm512_unpacklo_ps(t2, t3); + r3 = _mm512_unpackhi_ps(t2, t3); + + t0 = Mm512UnpackloPsx2(r0, r1); + t2 = Mm512UnpackhiPsx2(r0, r1); + t1 = Mm512UnpackloPsx2(r2, r3); + t3 = Mm512UnpackhiPsx2(r2, r3); + + r0 = _mm512_shuffle_f32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd); + r2 = _mm512_shuffle_f32x4(t2, t3, 0x88); + r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd); + + _mm256_storeu_ps(packed_ptr + 0 * 16, _mm512_castps512_ps256(r0)); + _mm256_storeu_ps(packed_ptr + 2 * 16, _mm512_extractf32x8_ps(r0, 1)); + _mm256_storeu_ps(packed_ptr + 4 * 16, _mm512_castps512_ps256(r1)); + _mm256_storeu_ps(packed_ptr + 6 * 16, _mm512_extractf32x8_ps(r1, 1)); + _mm256_storeu_ps(packed_ptr + 1 * 16, _mm512_castps512_ps256(r2)); + _mm256_storeu_ps(packed_ptr + 3 * 16, _mm512_extractf32x8_ps(r2, 1)); + _mm256_storeu_ps(packed_ptr + 5 * 16, _mm512_castps512_ps256(r3)); + _mm256_storeu_ps(packed_ptr + 7 * 16, _mm512_extractf32x8_ps(r3, 1)); + } else if (available_src_rows > 0) { + const __mmask8 row_mask = + (static_cast(1) << available_src_rows) - 1; + + __m512 t0, t1, t2, t3; + __m512 r0, r1, r2, r3; + + t0 = MaskLoaduTwo(row_mask, src_ptr0, src_ptr4); + t1 = MaskLoaduTwo(row_mask, src_ptr1, src_ptr5); + t2 = MaskLoaduTwo(row_mask, src_ptr2, src_ptr6); + t3 = MaskLoaduTwo(row_mask, src_ptr3, src_ptr7); + + r0 = _mm512_unpacklo_ps(t0, t1); + r2 = _mm512_unpackhi_ps(t0, t1); + r1 = _mm512_unpacklo_ps(t2, t3); + r3 = _mm512_unpackhi_ps(t2, t3); + + t0 = Mm512UnpackloPsx2(r0, r1); + t2 = Mm512UnpackhiPsx2(r0, r1); + t1 = Mm512UnpackloPsx2(r2, r3); + t3 = Mm512UnpackhiPsx2(r2, r3); + + r0 = _mm512_shuffle_f32x4(t0, t1, 0x88); + r1 = _mm512_shuffle_f32x4(t0, t1, 0xdd); + r2 = _mm512_shuffle_f32x4(t2, t3, 0x88); + r3 = _mm512_shuffle_f32x4(t2, t3, 0xdd); + + _mm256_storeu_ps(trailing_buf + 0 * 16, _mm512_castps512_ps256(r0)); + _mm256_storeu_ps(trailing_buf + 2 * 16, _mm512_extractf32x8_ps(r0, 1)); + _mm256_storeu_ps(trailing_buf + 4 * 16, _mm512_castps512_ps256(r1)); + _mm256_storeu_ps(trailing_buf + 6 * 16, _mm512_extractf32x8_ps(r1, 1)); + _mm256_storeu_ps(trailing_buf + 1 * 16, _mm512_castps512_ps256(r2)); + _mm256_storeu_ps(trailing_buf + 3 * 16, _mm512_extractf32x8_ps(r2, 1)); + _mm256_storeu_ps(trailing_buf + 5 * 16, _mm512_castps512_ps256(r3)); + // Do not store _mm512_extractf32x8_ps(r3, 1). 
+ } + + packed_ptr += 16 * 8; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } + } +} + +inline void ZeroHalfFloatAvx512(int src_rows, float* packed_ptr) { + const int non_trailing_rows = src_rows & ~7; + for (int k = 0; k < non_trailing_rows; ++k) { + for (int j = 0; j < 8; ++j) { + packed_ptr[j] = 0.0f; + } + packed_ptr += 16; + } +} + +} // namespace. + +void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr) { + profiler::ScopeLabel label("Pack kAvx512 8bit"); + + using Layout = PackImpl8bitAvx512::Layout; + constexpr int kHalfBlockOffset = 32; + RUY_DCHECK_EQ(kHalfBlockOffset * 2, Layout::kRows * Layout::kCols); + static constexpr int kHalfLayoutCols = + PackImpl8bitAvx512::kHalfLayoutCols; // Half the number of cols in a + // block. + RUY_DCHECK_EQ(kHalfLayoutCols, 8); + RUY_DCHECK_EQ(Layout::kCols, 16); + RUY_DCHECK_EQ(Layout::kRows, 4); + + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. + constexpr int kNumRowChunks = 8; + + // Each packed block is 4*16, and there are normally 8. The trailing block is + // only slightly shorter. + constexpr int kTrailingBufSize = + kNumRowChunks * Layout::kCols * Layout::kRows; + std::int8_t trailing_buf[kTrailingBufSize]; + memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); + + std::int32_t* second_sums_ptr = + sums_ptr ? sums_ptr + kHalfLayoutCols : nullptr; + if (remaining_src_cols > kHalfLayoutCols) { + HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, sums_ptr, + trailing_buf); + HalfPack8bitAvx512(src_ptr + src_stride * kHalfLayoutCols, input_xor, + zerobuf, src_stride, + remaining_src_cols - kHalfLayoutCols, src_rows, + packed_ptr + kHalfBlockOffset, second_sums_ptr, + trailing_buf + kHalfBlockOffset); + } else { + HalfPack8bitAvx512(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, sums_ptr, + trailing_buf); + ZeroHalf8bitAvx512(src_rows, zerobuf[0] ^ input_xor, + packed_ptr + kHalfBlockOffset); + // The kernel may not need the second half-blocks sums to be set. + if (second_sums_ptr) { + for (int i = 0; i < kHalfLayoutCols; ++i) { + second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3); + } + } + } + constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; + const bool trailing_data = (src_rows & kChunkedRowMask) > 0; + // If the number of source rows is not a multiple of kChunkedRowMask, there + // will be data in the trailing buffer, + if (trailing_data > 0) { + const int non_trailing_rows = src_rows & ~kChunkedRowMask; + // Destination "rows" are padded to next highest multiple of Layout::kRows. 
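+    // For example, src_rows = 70 gives non_trailing_rows = 64 and
+    // dst_rows = 72, so rows [64, 72) are copied from trailing_buf, the last
+    // two of them being padding.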
+    const int dst_rows = (src_rows + 3) & ~3;
+    const int trailing_rows = dst_rows - non_trailing_rows;
+    memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
+           Layout::kCols * trailing_rows * sizeof(std::int8_t));
+  }
+}
+
+void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride,
+                     int remaining_src_cols, int src_rows, float* packed_ptr) {
+  profiler::ScopeLabel label("Pack kAvx512 float");
+  float trailing_buf[7 * 16];
+  if (remaining_src_cols > 8) {
+    HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
+                        src_rows, packed_ptr, trailing_buf);
+    HalfPackFloatAvx512(src_ptr + src_stride * 8, zerobuf, src_stride,
+                        remaining_src_cols - 8, src_rows, packed_ptr + 8,
+                        trailing_buf + 8);
+  } else {
+    memset(trailing_buf, 0, sizeof(trailing_buf));
+    HalfPackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols,
+                        src_rows, packed_ptr, trailing_buf);
+    ZeroHalfFloatAvx512(src_rows, packed_ptr + 8);
+  }
+  const int trailing_rows = src_rows & 7;
+  if (trailing_rows > 0) {
+    const int non_trailing_rows = src_rows & ~7;
+    memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf,
+           16 * trailing_rows * sizeof(float));
+  }
+}
+
+#endif  // RUY_PLATFORM(AVX512) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS)
+
+}  // namespace ruy
diff --git a/ruy/pack_avxvnni.cc b/ruy/pack_avxvnni.cc
new file mode 100644
index 0000000..bb9a730
--- /dev/null
+++ b/ruy/pack_avxvnni.cc
@@ -0,0 +1,478 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <cstring>
+
+#include "ruy/check_macros.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/pack.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS)
+#include <immintrin.h>  // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor,
+                     const std::int8_t* zerobuf, int src_stride,
+                     int remaining_src_cols, int src_rows,
+                     std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
+  // CPU-ID-based checks should disable the path that would reach this point.
+  RUY_DCHECK(false);
+}
+
+void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf,
+                      int src_stride, int remaining_src_cols, int src_rows,
+                      float* packed_ptr) {
+  // CPU-ID-based checks should disable the path that would reach this point.
+  RUY_DCHECK(false);
+}
+
+#else  // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+// The first int8_t template parameter is arbitrary: this routine is common to
+// all 8-bit source matrix types.
+using PackImpl8bitAvxVnni = + PackImpl, + std::int8_t, std::int8_t, std::int32_t>; + +namespace { + +inline void ZeroHalf8bitAvxVnni(int src_rows, std::int8_t packed_zero_point, + std::int8_t* packed_ptr) { + const int non_trailing_blocks = (src_rows & ~31) >> 2; + // This routine fills half blocks, and typically fills the second halves. Thus + // packed_ptr is already offset by 8*4. + for (int k = 0; k < non_trailing_blocks; ++k) { + for (int j = 0; j < (8 * 4); ++j) { + packed_ptr[16 * 4 * k + j] = packed_zero_point; + } + } +} + +inline void HalfPack8bitAvxVnni(const std::int8_t* src_ptr, + std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr, + std::int8_t* trailing_buf) { + std::int8_t in_data[8][8][4]; + + const std::int8_t* src_ptr0 = src_ptr; + const std::int8_t* src_ptr1 = src_ptr0 + src_stride; + const std::int8_t* src_ptr2 = src_ptr1 + src_stride; + const std::int8_t* src_ptr3 = src_ptr2 + src_stride; + const std::int8_t* src_ptr4 = src_ptr3 + src_stride; + const std::int8_t* src_ptr5 = src_ptr4 + src_stride; + const std::int8_t* src_ptr6 = src_ptr5 + src_stride; + const std::int8_t* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = 8 * 4; + std::int64_t src_inc1 = 8 * 4; + std::int64_t src_inc2 = 8 * 4; + std::int64_t src_inc3 = 8 * 4; + std::int64_t src_inc4 = 8 * 4; + std::int64_t src_inc5 = 8 * 4; + std::int64_t src_inc6 = 8 * 4; + std::int64_t src_inc7 = 8 * 4; + if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + const std::int8_t zero_point = zerobuf[0]; + + if (sums_ptr) { + for (int i = 0; i < 8; ++i) { + sums_ptr[i] = 0; + } + } + + // The overall packing effectively pads the source rows to + // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we + // only pack for (src_rows + 31) & ~31. When there is an incomplete + // destination block, this is stored into trailing_buf instead of packed_ptr. + for (int k = 0; k < src_rows; k += 16 * 4) { + for (int m = 0; m < 2; ++m) { + // Available source rows. + // If this is less than 0 (for m=1), we skip, having filled trailing + // buffer for m=0. Also, if source rows is zero on m=1, then we filled + // exactly to the end of the column in the packed buffer. + const int packed_rows = src_rows - k - 8 * m * 4; + // Effectively, + // packed_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); + // but treat each case separately. 
+ if (packed_rows >= (8 * 4)) { + for (int i = 0; i < 8; ++i) { + for (int s = 0; s < 4; ++s) { + in_data[0][i][s] = src_ptr0[i * 4 + s]; + in_data[1][i][s] = src_ptr1[i * 4 + s]; + in_data[2][i][s] = src_ptr2[i * 4 + s]; + in_data[3][i][s] = src_ptr3[i * 4 + s]; + in_data[4][i][s] = src_ptr4[i * 4 + s]; + in_data[5][i][s] = src_ptr5[i * 4 + s]; + in_data[6][i][s] = src_ptr6[i * 4 + s]; + in_data[7][i][s] = src_ptr7[i * 4 + s]; + } + } + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + for (int s = 0; s < 4; ++s) { + packed_ptr[(16 * i + j) * 4 + s] = + static_cast(in_data[j][i][s] ^ input_xor); + } + if (sums_ptr) { + for (int s = 0; s < 4; ++s) { + sums_ptr[j] += in_data[j][i][s] ^ input_xor; + } + } + } + } + } else if (packed_rows > 0) { + RUY_DCHECK_LT(packed_rows >> 2, 8); + int i = 0; + for (; i < (packed_rows >> 2); ++i) { + for (int s = 0; s < 4; ++s) { + in_data[0][i][s] = src_ptr0[i * 4 + s]; + in_data[1][i][s] = src_ptr1[i * 4 + s]; + in_data[2][i][s] = src_ptr2[i * 4 + s]; + in_data[3][i][s] = src_ptr3[i * 4 + s]; + in_data[4][i][s] = src_ptr4[i * 4 + s]; + in_data[5][i][s] = src_ptr5[i * 4 + s]; + in_data[6][i][s] = src_ptr6[i * 4 + s]; + in_data[7][i][s] = src_ptr7[i * 4 + s]; + } + } + if (i < ((packed_rows + 3) >> 2)) { + int s = 0; + for (; s < (packed_rows & 3); ++s) { + in_data[0][i][s] = src_ptr0[i * 4 + s]; + in_data[1][i][s] = src_ptr1[i * 4 + s]; + in_data[2][i][s] = src_ptr2[i * 4 + s]; + in_data[3][i][s] = src_ptr3[i * 4 + s]; + in_data[4][i][s] = src_ptr4[i * 4 + s]; + in_data[5][i][s] = src_ptr5[i * 4 + s]; + in_data[6][i][s] = src_ptr6[i * 4 + s]; + in_data[7][i][s] = src_ptr7[i * 4 + s]; + } + RUY_DCHECK_LE(s, 4); + for (; s < 4; ++s) { + for (int j = 0; j < 8; ++j) { + in_data[j][i][s] = zero_point; + } + } + ++i; + } + // We do not care what goes into the trailing buffer, but we want + // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. + // + // It might prove better in optimized code to pad uniformly with + // zero_point, and compensate by initializing the summations with the + // compensating offset, effectively + // ((input_xor - zero_point) ^ input_xor) * + // 4 * (8 - ((packed_rows + 3) >> 2)). + for (; i < 8; ++i) { + for (int s = 0; s < 4; ++s) { + for (int j = 0; j < 8; ++j) { + in_data[j][i][s] = input_xor; + } + } + } + // We loop through [0, 8) rather than [0, (packed_rows + 3) >> 2), since + // that emulates what we might do in fully-optimized code. 
+ if (sums_ptr) { + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + for (int s = 0; s < 4; ++s) { + trailing_buf[(16 * i + j) * 4 + s] = + static_cast(in_data[j][i][s] ^ input_xor); + sums_ptr[j] += in_data[j][i][s] ^ input_xor; + } + } + } + } else { + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + for (int s = 0; s < 4; ++s) { + trailing_buf[(16 * i + j) * 4 + s] = + static_cast(in_data[j][i][s] ^ input_xor); + } + } + } + } + } + + packed_ptr += 16 * 8 * 4; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } + } +} + +inline void HalfPackFloatAvxVnni(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, float* packed_ptr, + float* trailing_buf) { + float in_data[8][8]; + + const float* src_ptr0 = src_ptr; + const float* src_ptr1 = src_ptr0 + src_stride; + const float* src_ptr2 = src_ptr1 + src_stride; + const float* src_ptr3 = src_ptr2 + src_stride; + const float* src_ptr4 = src_ptr3 + src_stride; + const float* src_ptr5 = src_ptr4 + src_stride; + const float* src_ptr6 = src_ptr5 + src_stride; + const float* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = 8; + std::int64_t src_inc1 = 8; + std::int64_t src_inc2 = 8; + std::int64_t src_inc3 = 8; + std::int64_t src_inc4 = 8; + std::int64_t src_inc5 = 8; + std::int64_t src_inc6 = 8; + std::int64_t src_inc7 = 8; + if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + for (int k = 0; k < src_rows; k += 16) { + for (int m = 0; m < 2; ++m) { + const int packed_rows = src_rows - k - 8 * m; + // Effectively, + // packed_rows = std::max(0, std::min(8, src_rows - k - 8 * m)); + // but treat each case separately. + if (packed_rows > 7) { + for (int i = 0; i < 8; ++i) { + in_data[0][i] = src_ptr0[i]; + in_data[1][i] = src_ptr1[i]; + in_data[2][i] = src_ptr2[i]; + in_data[3][i] = src_ptr3[i]; + in_data[4][i] = src_ptr4[i]; + in_data[5][i] = src_ptr5[i]; + in_data[6][i] = src_ptr6[i]; + in_data[7][i] = src_ptr7[i]; + } + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + packed_ptr[16 * i + j] = in_data[j][i]; + } + } + } else if (packed_rows > 0) { + for (int i = 0; i < packed_rows; ++i) { + in_data[0][i] = src_ptr0[i]; + in_data[1][i] = src_ptr1[i]; + in_data[2][i] = src_ptr2[i]; + in_data[3][i] = src_ptr3[i]; + in_data[4][i] = src_ptr4[i]; + in_data[5][i] = src_ptr5[i]; + in_data[6][i] = src_ptr6[i]; + in_data[7][i] = src_ptr7[i]; + } + for (int i = packed_rows; i < 8; ++i) { + in_data[0][i] = 0.0f; + in_data[1][i] = 0.0f; + in_data[2][i] = 0.0f; + in_data[3][i] = 0.0f; + in_data[4][i] = 0.0f; + in_data[5][i] = 0.0f; + in_data[6][i] = 0.0f; + in_data[7][i] = 0.0f; + } + // We loop through [0, 7) rather than [0, packed_rows), since that + // emulates what we might do in fully-optimized code. 
+ for (int i = 0; i < 7; ++i) { + for (int j = 0; j < 8; ++j) { + trailing_buf[16 * i + j] = in_data[j][i]; + } + } + } + + packed_ptr += 16 * 8; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } + } +} + +inline void ZeroHalfFloatAvxVnni(int src_rows, float* packed_ptr) { + const int non_trailing_rows = src_rows & ~7; + for (int k = 0; k < non_trailing_rows; ++k) { + for (int j = 0; j < 8; ++j) { + packed_ptr[j] = 0.0f; + } + packed_ptr += 16; + } +} + +} // namespace. + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. +void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr) { + profiler::ScopeLabel label("Pack kAvxVnni 8bit (UNFINISHED)"); + + // Each packed block is 4*16, and there are normally 8. The trailing block is + // only slightly shorter. + std::int8_t trailing_buf[8 * 16 * 4]; + memset(trailing_buf, 0, 8 * 16 * 4 * sizeof(std::int8_t)); + + std::int32_t* second_sums_ptr = sums_ptr ? sums_ptr + 8 : nullptr; + if (remaining_src_cols > 8) { + HalfPack8bitAvxVnni(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, sums_ptr, + trailing_buf); + HalfPack8bitAvxVnni(src_ptr + src_stride * 8, input_xor, zerobuf, + src_stride, remaining_src_cols - 8, src_rows, + packed_ptr + 8 * 4, second_sums_ptr, + trailing_buf + 8 * 4); + } else { + HalfPack8bitAvxVnni(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, sums_ptr, + trailing_buf); + ZeroHalf8bitAvxVnni(src_rows, zerobuf[0] ^ input_xor, packed_ptr + 8 * 4); + // The kernel may not need the second half-blocks sums to be set. + if (second_sums_ptr) { + for (int i = 0; i < 8; ++i) { + second_sums_ptr[i] = (zerobuf[0] ^ input_xor) * ((src_rows + 3) & ~3); + } + } + } + const bool trailing_data = (src_rows & 31) > 0; + // If the number of source rows is not a multiple of 32, there will be data in + // the trailing buffer, + if (trailing_data > 0) { + const int non_trailing_rows = src_rows & ~31; + // Destination "rows" are padded to next highest multiple of 4. + const int dst_rows = (src_rows + 3) & ~3; + const int trailing_rows = dst_rows - non_trailing_rows; + memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, + 16 * trailing_rows * sizeof(std::int8_t)); + } +} + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. 
+void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, int src_rows, + float* packed_ptr) { + profiler::ScopeLabel label("Pack kAvxVnni float (UNFINISHED)"); + float trailing_buf[7 * 16]; + if (remaining_src_cols > 8) { + HalfPackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_rows, packed_ptr, trailing_buf); + HalfPackFloatAvxVnni(src_ptr + src_stride * 8, zerobuf, src_stride, + remaining_src_cols - 8, src_rows, packed_ptr + 8, + trailing_buf + 8); + } else { + memset(trailing_buf, 0, sizeof(trailing_buf)); + HalfPackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_rows, packed_ptr, trailing_buf); + ZeroHalfFloatAvxVnni(src_rows, packed_ptr + 8); + } + const int trailing_rows = src_rows & 7; + if (trailing_rows > 0) { + const int non_trailing_rows = src_rows & ~7; + memcpy(packed_ptr + 16 * non_trailing_rows, trailing_buf, + 16 * trailing_rows * sizeof(float)); + } +} + +#endif // RUY_PLATFORM(AVX_VNNI) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) + +} // namespace ruy diff --git a/ruy/pack_common.h b/ruy/pack_common.h new file mode 100644 index 0000000..5c03afd --- /dev/null +++ b/ruy/pack_common.h @@ -0,0 +1,246 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// # What is "packing"? +// +// Before feeding data to the gemm kernels (the parts of Ruy that do lots +// of multiply-add operations), Ruy first performs a data transformation (which +// we call "packing") on the input matrices. This transformation has two main +// goals: +// - rearrange data into blocks that are a convenient size/layout for the gemm +// kernels to consume. This helps make the memory access pattern of the gemm +// kernel simpler and more contiguous, and puts the data in a layout most +// convenient for specific arithmetic instructions in the gemm kernel. +// - compute row/column sums needed for handling quantization with non-symmetric +// zero points. +// +// # Simplified algorithmic analysis of packing +// +// Packing is a relatively simple transformation which does a small constant +// amount of work on each element of an input matrix, and hence for an NxM +// matrix performs O(N*M) work. If N and M are of the same order, then this is +// O(N^2) work. +// +// A NxKxM matrix multiplication requires N*K*M multiply-accumulate operations. +// Note that if N, K, and M are all the same order, then the number of +// multiply-accumulate operations is O(N^3). +// +// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the +// case of all dimensions being roughly the same order. +// +// # Packing cost can be significant +// +// When matrix * matrix multiplications begin to look more like matrix * vector +// multiplications, packing cost can become significant. We sometimes call these +// cases "gemv-like". 
+//
+// Continuing the algorithmic analysis above, if we consider a case where an
+// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the
+// situation is different. In this case, the multiply-accumulate work is only
+// quadratic, so the quadratic cost of packing can become significant.
+//
+// Another way to say this is that the cost of packing an input matrix (either
+// the LHS or RHS) is amortized across the non-depth dimension of the opposite
+// input matrix. Thus, when the LHS has very few rows or the RHS has very few
+// columns, the cost of packing the opposite input matrix can become
+// significant.
+//
+// As a rough rule of thumb, the cost of packing starts to become significant
+// when either N or M is below 32 (and other dimensions are hundreds), with very
+// significant packing costs at 8 or below. This varies by data type, Path, and
+// tuning, so these numbers are only rough guides.
+//
+// One practical use case that is affected by this is inference of
+// fully connected neural network layers with a low batch size. The weight
+// matrix (which is a constant for inference) is the one affected by significant
+// packing cost.
+//
+// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack
+// input matrices that are affected by significant packing costs.
+//
+// # Implementation notes
+//
+// Ruy's packing routines always operate on a range of columns and can be
+// applied to either the LHS or RHS. This is possible because Ruy internally
+// implements a TrMul, so the accumulation along depth is done along columns of
+// both the LHS and RHS (whereas for a normal Mul the accumulation along depth
+// for the LHS is along rows). As another example, we are always computing
+// column sums for quantization (and never row sums, since the LHS is
+// transposed).
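To make the gemv-like analysis above concrete, here is a minimal standalone sketch; the dimensions are assumed purely for illustration and are not taken from ruy:

// Standalone illustration of the packing-vs-kernel work comparison above.
// The dimension values are hypothetical; only the ratio argument matters.
#include <cstdio>

int main() {
  const long long N = 1000, K = 1000, M = 8;  // Small M makes this gemv-like.
  const long long packing_work = N * K;       // One pass over the LHS elements.
  const long long kernel_work = N * K * M;    // Multiply-accumulate count.
  // Prints 0.125 here: packing is an eighth of the arithmetic, not negligible.
  std::printf("packing/kernel work ratio: %g\n",
              static_cast<double>(packing_work) / kernel_work);
  return 0;
}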
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_ + +#include + +#include "ruy/check_macros.h" +#include "ruy/common.h" +#include "ruy/internal_matrix.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/tune.h" + +namespace ruy { + +template +struct PackedTypeImpl { + using Type = Scalar; +}; + +#if RUY_PLATFORM(NEON_32) +struct PackParams8bit { + const void* src_ptr0; + const void* src_ptr1; + const void* src_ptr2; + const void* src_ptr3; + const std::int32_t* sums_ptr; + const std::int8_t* packed_ptr; + int src_inc0; + int src_inc1; + int src_inc2; + int src_inc3; + int src_rows; + int src_zero_point; + int input_xor; +}; + +inline void MakePackParams8bit(const void* src_ptr0, const void* src_ptr1, + const void* src_ptr2, const void* src_ptr3, + const std::int32_t* sums_ptr, + const std::int8_t* packed_ptr, int src_inc0, + int src_inc1, int src_inc2, int src_inc3, + int src_rows, int src_zero_point, int input_xor, + PackParams8bit* params) { + params->src_ptr0 = src_ptr0; + params->src_ptr1 = src_ptr1; + params->src_ptr2 = src_ptr2; + params->src_ptr3 = src_ptr3; + params->sums_ptr = sums_ptr; + params->packed_ptr = packed_ptr; + params->src_inc0 = src_inc0; + params->src_inc1 = src_inc1; + params->src_inc2 = src_inc2; + params->src_inc3 = src_inc3; + params->src_rows = src_rows; + params->src_zero_point = src_zero_point; + params->input_xor = input_xor; +} +#endif + +#if RUY_PLATFORM(NEON) +template <> +struct PackedTypeImpl { + using Type = std::int8_t; +}; +template <> +struct PackedTypeImpl { + using Type = std::int8_t; +}; +#elif RUY_PLATFORM(X86) +template <> +struct PackedTypeImpl { + using Type = std::int8_t; +}; +template <> +struct PackedTypeImpl { + using Type = std::int8_t; +}; +template <> +struct PackedTypeImpl { + using Type = std::int8_t; +}; +template <> +struct PackedTypeImpl { + using Type = std::int8_t; +}; +#endif + +template +using PackedType = typename PackedTypeImpl::Type; + +template +PackedScalar Pack(Scalar x) { + return x - SymmetricZeroPoint() + SymmetricZeroPoint(); +} + +template +struct PackImpl {}; + +#define RUY_INHERIT_PACK(PARENT, CHILD) \ + template \ + struct PackImpl \ + : PackImpl { \ + }; + +template +struct PackImpl { + static void Run(Tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (generic)"); + RUY_DCHECK_EQ((end_col - start_col) % FixedKernelLayout::kCols, 0); + SumsType* sums = packed_matrix->sums; + for (int col = start_col; col < end_col; col++) { + SumsType accum = 0; + for (int row = 0; row < packed_matrix->layout.rows; row++) { + PackedScalar packed_val; + if (col < src_matrix.layout.cols && row < src_matrix.layout.rows) { + packed_val = Pack(Element(src_matrix, row, col)); + } else { + packed_val = packed_matrix->zero_point; + } + accum += packed_val; + *ElementPtr(packed_matrix, row, col) = packed_val; + } + if (sums) { + sums[col] = accum; + } + } + } +}; + +#if RUY_PLATFORM(NEON) +RUY_INHERIT_PACK(Path::kStandardCpp, Path::kNeon) +RUY_INHERIT_PACK(Path::kNeon, Path::kNeonDotprod) +#elif RUY_PLATFORM(X86) +RUY_INHERIT_PACK(Path::kStandardCpp, Path::kSse42) +RUY_INHERIT_PACK(Path::kSse42, Path::kAvx2) +RUY_INHERIT_PACK(Path::kAvx2, Path::kAvx512) +RUY_INHERIT_PACK(Path::kAvx512, Path::kAvxVnni) +#endif + +// Main entry point for packing. 
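+// RunPack converts the type-erased DMatrix/PMatrix arguments back into typed
+// Matrix/PackedMatrix views and forwards to the PackImpl specialization
+// selected by the Path template parameter.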
+template +void RunPack(Tuning tuning, const DMatrix& src_matrix, PMatrix* packed_matrix, + int start_col, int end_col) { + using SumsType = typename PackedMatrix::SumsType; + Matrix src = ToMatrix(src_matrix); + PackedMatrix packed = + ToPackedMatrix(*packed_matrix); + PackImpl::Run( + tuning, src, &packed, start_col, end_col); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_COMMON_H_ diff --git a/ruy/pack_sse42.cc b/ruy/pack_sse42.cc new file mode 100644 index 0000000..90c7250 --- /dev/null +++ b/ruy/pack_sse42.cc @@ -0,0 +1,471 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/pack.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" + +#if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) +#include // IWYU pragma: keep +#endif + +namespace ruy { + +#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)) + +void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, float* packed_ptr) { + // CPU-ID-based checks should disable the path that would reach this point. + RUY_DCHECK(false); +} + +#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM) + +// The first int8_t template parameter is arbitrary: this routine is common to +// all 8-bit source matrix types. +using PackImpl8bitSse42 = + PackImpl, + std::int8_t, std::int8_t, std::int32_t>; + +using PackImplFloatSse42 = + PackImpl, float, + float, float>; + +namespace { + +inline void Pack8bitSse42Packer(const std::int8_t* src_ptr, + std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr, + std::int8_t* trailing_buf) { + using Layout = PackImpl8bitSse42::Layout; + RUY_DCHECK_EQ(Layout::kCols, 8); + RUY_DCHECK_EQ(Layout::kRows, 4); + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. 
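+  // A single packing block therefore covers Layout::kCols (8) source columns
+  // by kNumRowChunks * Layout::kRows (32) source rows.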
+ constexpr int kNumRowChunks = 8; + constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows; + + std::int8_t in_data[Layout::kCols][kNumRowChunks][Layout::kRows]; + + const std::int8_t* src_ptr0 = src_ptr; + const std::int8_t* src_ptr1 = src_ptr0 + src_stride; + const std::int8_t* src_ptr2 = src_ptr1 + src_stride; + const std::int8_t* src_ptr3 = src_ptr2 + src_stride; + const std::int8_t* src_ptr4 = src_ptr3 + src_stride; + const std::int8_t* src_ptr5 = src_ptr4 + src_stride; + const std::int8_t* src_ptr6 = src_ptr5 + src_stride; + const std::int8_t* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = kNumChunkedSrcRows; + std::int64_t src_inc1 = kNumChunkedSrcRows; + std::int64_t src_inc2 = kNumChunkedSrcRows; + std::int64_t src_inc3 = kNumChunkedSrcRows; + std::int64_t src_inc4 = kNumChunkedSrcRows; + std::int64_t src_inc5 = kNumChunkedSrcRows; + std::int64_t src_inc6 = kNumChunkedSrcRows; + std::int64_t src_inc7 = kNumChunkedSrcRows; + // Handle cases where source does not have Layout::kCols (8) columns. + if (remaining_src_cols < 8) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + const std::int8_t zero_point = zerobuf[0]; + + if (sums_ptr) { + // i: Layout::kCols. + for (int i = 0; i < 8; ++i) { + sums_ptr[i] = 0; + } + } + + // The overall packing effectively pads the source rows to + // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we + // only pack for (src_rows + 31) & ~31. When there is an incomplete + // destination block, this is stored into trailing_buf instead of packed_ptr. + for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) { + // Available source rows. + // If this is less than 0 (for m=1), we skip, having filled trailing + // buffer for m=0. Also, if source rows is zero on m=1, then we filled + // exactly to the end of the column in the packed buffer. + const int available_src_rows = src_rows - k; + // Effectively, + // available rows = std::max(0, std::min(8, src_rows - k)); + // treat each case separately. + if (available_src_rows >= kNumChunkedSrcRows) { + // i: chunks, s: Layout::Rows. + for (int i = 0; i < 8; ++i) { + for (int s = 0; s < 4; ++s) { + in_data[0][i][s] = src_ptr0[i * 4 + s]; + in_data[1][i][s] = src_ptr1[i * 4 + s]; + in_data[2][i][s] = src_ptr2[i * 4 + s]; + in_data[3][i][s] = src_ptr3[i * 4 + s]; + in_data[4][i][s] = src_ptr4[i * 4 + s]; + in_data[5][i][s] = src_ptr5[i * 4 + s]; + in_data[6][i][s] = src_ptr6[i * 4 + s]; + in_data[7][i][s] = src_ptr7[i * 4 + s]; + } + } + // i: chunks, j: Layout::kCols, s: Layout::Rows. 
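+      // Concretely, the destination offset (8 * i + j) * 4 + s makes each
+      // chunk i a contiguous run of 8 columns x 4 rows; e.g. chunk i=1,
+      // column j=2, row s=3 lands at (8 * 1 + 2) * 4 + 3 = 43.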
+ for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + for (int s = 0; s < 4; ++s) { + // 8 * 4 * i is offset for each block, that is + // (Layout::kCols * Layout::kRows * i) + packed_ptr[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor; + } + if (sums_ptr) { + for (int s = 0; s < 4; ++s) { + sums_ptr[j] += in_data[j][i][s] ^ input_xor; + } + } + } + } + } else if (available_src_rows > 0) { + RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows); + int i = 0; + // Consume chunks of 4 rows that are complete. + for (; i < (available_src_rows >> 2); ++i) { + for (int s = 0; s < 4; ++s) { + in_data[0][i][s] = src_ptr0[i * 4 + s]; + in_data[1][i][s] = src_ptr1[i * 4 + s]; + in_data[2][i][s] = src_ptr2[i * 4 + s]; + in_data[3][i][s] = src_ptr3[i * 4 + s]; + in_data[4][i][s] = src_ptr4[i * 4 + s]; + in_data[5][i][s] = src_ptr5[i * 4 + s]; + in_data[6][i][s] = src_ptr6[i * 4 + s]; + in_data[7][i][s] = src_ptr7[i * 4 + s]; + } + } + // Consume any incomplete chunk. + if (i < ((available_src_rows + 3) >> 2)) { + int s = 0; + for (; s < (available_src_rows & 3); ++s) { + in_data[0][i][s] = src_ptr0[i * 4 + s]; + in_data[1][i][s] = src_ptr1[i * 4 + s]; + in_data[2][i][s] = src_ptr2[i * 4 + s]; + in_data[3][i][s] = src_ptr3[i * 4 + s]; + in_data[4][i][s] = src_ptr4[i * 4 + s]; + in_data[5][i][s] = src_ptr5[i * 4 + s]; + in_data[6][i][s] = src_ptr6[i * 4 + s]; + in_data[7][i][s] = src_ptr7[i * 4 + s]; + } + RUY_DCHECK_LE(s, 4); + for (; s < 4; ++s) { + // j: Layout::kCols. + for (int j = 0; j < 8; ++j) { + in_data[j][i][s] = zero_point; + } + } + ++i; + } + // We do not care what goes into the trailing buffer, but we want + // in_data[...] ^ input_xor == 0 for irrelevant values in the summation. + // + // It might prove better in optimized code to pad uniformly with + // zero_point, and compensate by initializing the summations with the + // compensating offset, effectively + // ((input_xor - zero_point) ^ input_xor) * + // 4 * (8 - ((available_src_rows + 3) >> 2)). + for (; i < 8; ++i) { + for (int s = 0; s < 4; ++s) { + for (int j = 0; j < 8; ++j) { + in_data[j][i][s] = input_xor; + } + } + } + // We loop through [0, 8) rather than + // [0, (available_src_rows + 3) >> 2), since that emulates what we might + // do in fully-optimized code. + // + // i: chunks, j: Layout::kCols, s: Layout::Rows. + if (sums_ptr) { + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + for (int s = 0; s < 4; ++s) { + trailing_buf[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor; + sums_ptr[j] = sums_ptr[j] + (in_data[j][i][s] ^ input_xor); + } + } + } + } else { + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + for (int s = 0; s < 4; ++s) { + trailing_buf[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor; + } + } + } + } + } + + packed_ptr += 8 * kNumChunkedSrcRows; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } +} + +inline void PackFloatSse42Packer(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, + int src_rows, float* packed_ptr, + float* trailing_buf) { + using Layout = PackImplFloatSse42::Layout; + RUY_DCHECK_EQ(Layout::kCols, 8); + RUY_DCHECK_EQ(Layout::kRows, 1); + + // This packing amounts to tranposition of 8x8 blocks. + static constexpr int kPackCols = 8; // Source cols packed together. + static constexpr int kPackRows = 8; // Short input is padded. 
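+  // The loops below gather an 8x8 tile into in_data, one row of in_data per
+  // source column, then store the tile transposed into packed_ptr (or into
+  // trailing_buf for a short final block).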
+ + float in_data[kPackCols][kPackRows]; + + const float* src_ptr0 = src_ptr; + const float* src_ptr1 = src_ptr0 + src_stride; + const float* src_ptr2 = src_ptr1 + src_stride; + const float* src_ptr3 = src_ptr2 + src_stride; + const float* src_ptr4 = src_ptr3 + src_stride; + const float* src_ptr5 = src_ptr4 + src_stride; + const float* src_ptr6 = src_ptr5 + src_stride; + const float* src_ptr7 = src_ptr6 + src_stride; + std::int64_t src_inc0 = 8; + std::int64_t src_inc1 = 8; + std::int64_t src_inc2 = 8; + std::int64_t src_inc3 = 8; + std::int64_t src_inc4 = 8; + std::int64_t src_inc5 = 8; + std::int64_t src_inc6 = 8; + std::int64_t src_inc7 = 8; + // Handle cases where source does not have kPackDim (8) columns. + if (remaining_src_cols < kPackCols) { + if (remaining_src_cols <= 0) { + src_ptr0 = zerobuf; + src_inc0 = 0; + } + if (remaining_src_cols <= 1) { + src_ptr1 = zerobuf; + src_inc1 = 0; + } + if (remaining_src_cols <= 2) { + src_ptr2 = zerobuf; + src_inc2 = 0; + } + if (remaining_src_cols <= 3) { + src_ptr3 = zerobuf; + src_inc3 = 0; + } + if (remaining_src_cols <= 4) { + src_ptr4 = zerobuf; + src_inc4 = 0; + } + if (remaining_src_cols <= 5) { + src_ptr5 = zerobuf; + src_inc5 = 0; + } + if (remaining_src_cols <= 6) { + src_ptr6 = zerobuf; + src_inc6 = 0; + } + src_ptr7 = zerobuf; + src_inc7 = 0; + } + + for (int k = 0; k < src_rows; k += kPackRows) { + const int available_src_rows = src_rows - k; + // Effectively, + // available_src_rows = std::max(0, std::min(kPackDim, src_rows - k)); + // but treat each case separately. + if (available_src_rows >= kPackRows) { + for (int i = 0; i < 8; ++i) { + in_data[0][i] = src_ptr0[i]; + in_data[1][i] = src_ptr1[i]; + in_data[2][i] = src_ptr2[i]; + in_data[3][i] = src_ptr3[i]; + in_data[4][i] = src_ptr4[i]; + in_data[5][i] = src_ptr5[i]; + in_data[6][i] = src_ptr6[i]; + in_data[7][i] = src_ptr7[i]; + } + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + packed_ptr[8 * i + j] = in_data[j][i]; + } + } + } else if (available_src_rows > 0) { + for (int i = 0; i < available_src_rows; ++i) { + in_data[0][i] = src_ptr0[i]; + in_data[1][i] = src_ptr1[i]; + in_data[2][i] = src_ptr2[i]; + in_data[3][i] = src_ptr3[i]; + in_data[4][i] = src_ptr4[i]; + in_data[5][i] = src_ptr5[i]; + in_data[6][i] = src_ptr6[i]; + in_data[7][i] = src_ptr7[i]; + } + for (int i = available_src_rows; i < kPackRows; ++i) { + in_data[0][i] = 0.0f; + in_data[1][i] = 0.0f; + in_data[2][i] = 0.0f; + in_data[3][i] = 0.0f; + in_data[4][i] = 0.0f; + in_data[5][i] = 0.0f; + in_data[6][i] = 0.0f; + in_data[7][i] = 0.0f; + } + // We loop through [0, 7) rather than [0, packed_rows), since that + // emulates what we might do in fully-optimized code. + // i: (kPackRows - 1), j: kPackCols. + for (int i = 0; i < 7; ++i) { + for (int j = 0; j < 8; ++j) { + trailing_buf[kPackRows * i + j] = in_data[j][i]; + } + } + } + + packed_ptr += kPackRows * kPackCols; + src_ptr0 += src_inc0; + src_ptr1 += src_inc1; + src_ptr2 += src_inc2; + src_ptr3 += src_inc3; + src_ptr4 += src_inc4; + src_ptr5 += src_inc5; + src_ptr6 += src_inc6; + src_ptr7 += src_inc7; + } +} + +} // namespace. + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. 
+void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr) { + profiler::ScopeLabel label("Pack kSse42 8bit (UNFINISHED)"); + + using Layout = PackImpl8bitSse42::Layout; + RUY_DCHECK_EQ(Layout::kCols, 8); + RUY_DCHECK_EQ(Layout::kRows, 4); + + // Each Layout::Rows is 4 contiguous input, contiguous packed elements. + // We process 8 of these chunks at a time, padding short input chunks. + static constexpr int kNumRowChunks = 8; // Short input is padded. + + // Each packed block is 4*8, and there are normally 8. The trailing block is + // only slightly shorter. + constexpr int kTrailingBufSize = + kNumRowChunks * Layout::kCols * Layout::kRows; + std::int8_t trailing_buf[kTrailingBufSize]; + memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t)); + + Pack8bitSse42Packer(src_ptr, input_xor, zerobuf, src_stride, + remaining_src_cols, src_rows, packed_ptr, sums_ptr, + trailing_buf); + + constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1; + const bool trailing_data = (src_rows & kChunkedRowMask) > 0; + // If the number of source rows is not a multiple of kChunkedRowMask, there + // will be data in the trailing buffer, + if (trailing_data > 0) { + const int non_trailing_rows = src_rows & ~kChunkedRowMask; + // Destination "rows" are padded to next highest multiple of Layout::kRows. + const int dst_rows = (src_rows + 3) & ~3; + const int trailing_rows = dst_rows - non_trailing_rows; + memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf, + Layout::kCols * trailing_rows * sizeof(std::int8_t)); + } +} + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// When removing this comment, update profiling label below. +void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, float* packed_ptr) { + profiler::ScopeLabel label("Pack kSse42 float (UNFINISHED)"); + static constexpr int kPackCols = 8; // Source cols packed together. + static constexpr int kPackRows = 8; // Short input is padded. + float trailing_buf[(kPackRows - 1) * kPackCols]; + if (remaining_src_cols < 8) { + memset(trailing_buf, 0, sizeof(trailing_buf)); + } + PackFloatSse42Packer(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_rows, packed_ptr, trailing_buf); + + const int trailing_rows = src_rows & (kPackRows - 1); + if (trailing_rows > 0) { + const int non_trailing_rows = src_rows & ~(kPackRows - 1); + memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf, + kPackCols * trailing_rows * sizeof(float)); + } +} + +#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS) + +} // namespace ruy diff --git a/ruy/pack_x86.h b/ruy/pack_x86.h new file mode 100644 index 0000000..b777cc1 --- /dev/null +++ b/ruy/pack_x86.h @@ -0,0 +1,461 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// # What is "packing"?
+//
+// Before feeding data to the gemm kernels (the parts of Ruy that do lots
+// of multiply-add operations), Ruy first performs a data transformation (which
+// we call "packing") on the input matrices. This transformation has two main
+// goals:
+// - rearrange data into blocks that are a convenient size/layout for the gemm
+// kernels to consume. This helps make the memory access pattern of the gemm
+// kernel simpler and more contiguous, and puts the data in a layout most
+// convenient for specific arithmetic instructions in the gemm kernel.
+// - compute row/column sums needed for handling quantization with non-symmetric
+// zero points.
+//
+// # Simplified algorithmic analysis of packing
+//
+// Packing is a relatively simple transformation which does a small constant
+// amount of work on each element of an input matrix, and hence for an NxM
+// matrix performs O(N*M) work. If N and M are of the same order, then this is
+// O(N^2) work.
+//
+// An NxKxM matrix multiplication requires N*K*M multiply-accumulate operations.
+// Note that if N, K, and M are all the same order, then the number of
+// multiply-accumulate operations is O(N^3).
+//
+// Thus, the O(N^2) cost of packing is small compared to the O(N^3) work, in the
+// case of all dimensions being roughly the same order.
+//
+// # Packing cost can be significant
+//
+// When matrix * matrix multiplications begin to look more like matrix * vector
+// multiplications, packing cost can become significant. We sometimes call these
+// cases "gemv-like".
+//
+// Continuing the algorithmic analysis above, if we consider a case where an
+// NxKxM matrix multiplication has either N = O(1) or M = O(1), then the
+// situation is different. In this case, the multiply-accumulate work is only
+// quadratic, so the quadratic cost of packing can become significant.
+//
+// Another way to say this is that the cost of packing an input matrix (either
+// the LHS or RHS) is amortized across the non-depth dimension of the opposite
+// input matrix. Thus, when the LHS has very few rows or the RHS has very few
+// columns, the cost of packing the opposite input matrix can become
+// significant.
+//
+// As a rough rule of thumb, the cost of packing starts to become significant
+// when either N or M is below 32 (and other dimensions are hundreds), with very
+// significant packing costs at 8 or below. This varies by data type, Path, and
+// tuning, so these numbers are only rough guides.
+//
+// One practical use case that is affected by this is inference of
+// fully connected neural network layers with a low batch size. The weight
+// matrix (which is a constant for inference) is the one affected by significant
+// packing cost.
+//
+// Ruy provides an API in ruy_advanced.h for advanced users to pre-pack
+// input matrices that are affected by significant packing costs.
+//
+// # Implementation notes
+//
+// Ruy's packing routines always operate on a range of columns and can be
+// applied to either the LHS or RHS. This is possible because Ruy internally
+// implements a TrMul, so the accumulation along depth is done along columns of
+// both the LHS and RHS (whereas for a normal Mul the accumulation along depth
+// for the LHS is along rows).
As another example, we are always computing +// column sums for quantization (and never row sums, since the LHS is +// transposed). + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_ + +#include +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/common.h" +#include "ruy/internal_matrix.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/pack_common.h" +#include "ruy/path.h" +#include "ruy/platform.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/tune.h" + +namespace ruy { + +#if RUY_PLATFORM(X86) +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// Note that source and zero buffers can be uint8 type, but in the packing +// function are reinterpreted as int8, and are XOR-ed with input_xor. +void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr); + +template +struct PackImpl, Scalar, + std::int8_t, std::int32_t> { + static_assert(std::is_same::value || + std::is_same::value, + ""); + using Layout = FixedKernelLayout; + static constexpr std::int8_t kInputXor = + std::is_same::value ? 0 : 0x80; + + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (SSE 4.2 8-bit)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[Layout::kCols * Layout::kRows]; + memset(zerobuf, packed_matrix->zero_point ^ kInputXor, + Layout::kCols * Layout::kRows * sizeof(Scalar)); + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. + std::int8_t* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + Pack8bitSse42(reinterpret_cast(src_ptr), kInputXor, + reinterpret_cast(zerobuf), src_stride, + remaining_src_cols, src_matrix.layout.rows, packed_ptr, + sums_ptr); + } + } +}; + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. 
+// +void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, float* packed_ptr); + +template <> +struct PackImpl, float, + float, float> { + using Layout = FixedKernelLayout; + static void Run(Tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (SSE 4.2 float)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + const float zerobuf[Layout::kCols] = { + 0.0f}; // Remainder default inits to 0.0f. + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. + float* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + PackFloatSse42(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_matrix.layout.rows, packed_ptr); + } + } +}; + +// Note that source and zero buffers can be uint8 type, but in the packing +// function are reinterpreted as int8, and are XOR-ed with input_xor. +void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, + std::int32_t* sums_ptr); + +template +struct PackImpl, Scalar, + std::int8_t, std::int32_t> { + static_assert(std::is_same::value || + std::is_same::value, + ""); + using Layout = FixedKernelLayout; + static constexpr std::int8_t kInputXor = + std::is_same::value ? 0 : 0x80; + + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX2 8-bit)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[Layout::kCols * Layout::kRows]; + memset(zerobuf, packed_matrix->zero_point ^ kInputXor, + Layout::kCols * Layout::kRows * sizeof(Scalar)); + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. 
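+      // start_col is a multiple of Layout::kCols (DCHECKed above) and
+      // block_col advances by Layout::kCols, so masking with block_col_mask
+      // only makes the round-down to a block boundary explicit.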
+ std::int8_t* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + Pack8bitAvx2(reinterpret_cast(src_ptr), kInputXor, + reinterpret_cast(zerobuf), src_stride, + remaining_src_cols, src_matrix.layout.rows, packed_ptr, + sums_ptr); + } + } +}; + +void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, float* packed_ptr); + +template <> +struct PackImpl, float, + float, float> { + using Layout = FixedKernelLayout; + static void Run(Tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX2 float)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + const float zerobuf[Layout::kCols] = { + 0.0f}; // Remainder default inits to 0.0f. + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. + float* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + PackFloatAvx2(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_matrix.layout.rows, packed_ptr); + } + } +}; + +// Note that source and zero buffers can be uint8 type, but in the packing +// function are reinterpreted as int8, and are XOR-ed with input_xor. +void Pack8bitAvx512(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr); + +template +struct PackImpl, + Scalar, std::int8_t, std::int32_t> { + static_assert(std::is_same::value || + std::is_same::value, + ""); + using Layout = FixedKernelLayout; + static constexpr int kHalfLayoutCols = + 8; // Half the number of cols in a block. + static constexpr std::int8_t kInputXor = + std::is_same::value ? 0 : 0x80; + + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX-512 8-bit)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[kHalfLayoutCols * Layout::kRows]; + memset(zerobuf, packed_matrix->zero_point ^ kInputXor, + kHalfLayoutCols * Layout::kRows * sizeof(Scalar)); + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. 
+ std::int8_t* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + Pack8bitAvx512(reinterpret_cast(src_ptr), kInputXor, + reinterpret_cast(zerobuf), src_stride, + remaining_src_cols, src_matrix.layout.rows, packed_ptr, + sums_ptr); + } + } +}; + +void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, float* packed_ptr); + +template <> +struct PackImpl, + float, float, float> { + static void Run(Tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX-512 float)"); + using Layout = FixedKernelLayout; + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + const float zerobuf[Layout::kCols] = { + 0.0f}; // Remainder default inits to 0.0f. + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. + float* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + PackFloatAvx512(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_matrix.layout.rows, packed_ptr); + } + } +}; + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// Note that source and zero buffers can be uint8 type, but in the packing +// function are reinterpreted as int8, and are XOR-ed with input_xor. +void Pack8bitAvxVnni(const std::int8_t* src_ptr, std::int8_t input_xor, + const std::int8_t* zerobuf, int src_stride, + int remaining_src_cols, int src_rows, + std::int8_t* packed_ptr, std::int32_t* sums_ptr); + +template +struct PackImpl, + Scalar, std::int8_t, std::int32_t> { + static_assert(std::is_same::value || + std::is_same::value, + ""); + using Layout = FixedKernelLayout; + static constexpr int kHalfLayoutCols = + 8; // Half the number of cols in a block. + static constexpr std::int8_t kInputXor = + std::is_same::value ? 0 : 0x80; + + static void Run(Tuning tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX-512 8-bit)"); + + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + RUY_DCHECK_EQ(kHalfLayoutCols * 2, Layout::kCols); + std::int32_t* sums = packed_matrix->sums; + Scalar zerobuf[kHalfLayoutCols * Layout::kRows]; + memset(zerobuf, packed_matrix->zero_point ^ kInputXor, + kHalfLayoutCols * Layout::kRows * sizeof(Scalar)); + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + std::int32_t* sums_ptr = sums ? sums + block_col : nullptr; + int src_stride = src_matrix.layout.stride; + const Scalar* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. 
+ std::int8_t* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + Pack8bitAvxVnni(reinterpret_cast(src_ptr), kInputXor, + reinterpret_cast(zerobuf), src_stride, + remaining_src_cols, src_matrix.layout.rows, packed_ptr, + sums_ptr); + } + } +}; + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf, + int src_stride, int remaining_src_cols, int src_rows, + float* packed_ptr); + +template <> +struct PackImpl, + float, float, float> { + static void Run(Tuning, const Matrix& src_matrix, + PackedMatrix* packed_matrix, int start_col, + int end_col) { + profiler::ScopeLabel label("Pack (AVX-512 float)"); + + using Layout = FixedKernelLayout; + RUY_DCHECK(IsColMajor(src_matrix.layout)); + RUY_DCHECK(IsColMajor(packed_matrix->layout)); + RUY_DCHECK_EQ((end_col - start_col) % Layout::kCols, 0); + RUY_DCHECK_EQ(start_col % Layout::kCols, 0); + const float zerobuf[Layout::kCols] = { + 0.0f}; // Remainder default inits to 0.0f. + for (int block_col = start_col; block_col < end_col; + block_col += Layout::kCols) { + int src_stride = src_matrix.layout.stride; + const float* src_ptr = src_matrix.data.get() + src_stride * block_col; + int remaining_src_cols = src_matrix.layout.cols - block_col; + + static constexpr int block_col_mask = ~(Layout::kCols - 1); // High bits. + float* packed_ptr = + packed_matrix->data + + packed_matrix->layout.stride * (block_col & block_col_mask); + PackFloatAvxVnni(src_ptr, zerobuf, src_stride, remaining_src_cols, + src_matrix.layout.rows, packed_ptr); + } + } +}; +#endif // RUY_PLATFORM(X86) + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PACK_X86_H_ diff --git a/ruy/path.h b/ruy/path.h new file mode 100644 index 0000000..7141b16 --- /dev/null +++ b/ruy/path.h @@ -0,0 +1,162 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_ + +#include + +#include "ruy/platform.h" +#include "ruy/size_util.h" + +namespace ruy { + +// A Path is a choice of implementation path, e.g. between reference code +// and optimized code, or between different optimized code paths using different +// instruction sets. +// +// It's important that any symbol that depends on such implementation +// details, is somehow templatized in such a Path, so that different Path values +// yield different symbols, so we never have the situation where a symbols has +// multiple inequivalent definitions based on which code paths are compiled. +// That would be a violation of the ODR (One Definition Rule) which is Undefined +// Behavior, and one of the most serious issues plaguing both Eigen and +// gemmlowp. 
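+// For example, the PackImpl specializations in pack_x86.h are templatized on
+// Path, so the kSse42 and kAvx512 variants of packing are distinct template
+// instantiations rather than two conflicting definitions of one symbol.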
+// +// This enum is actually a bit-field: aside from kNone, all other values are +// powers of two, thus are one bit each. We define bit-wise operators below +// for this enum. Some places in Ruy accept a Path bit-field where multiple +// Paths may be selected, while some other places require a single Path (i.e. +// just one of the enum values here). Typically, user-facing parts of Ruy +// accept arbitrary bit-fields, allowing the user to compile support for +// multiple paths and to inform Ruy of all the paths that are to be enabled +// at runtime; then, typically in dispatch.h, we internally pick one +// specific path and from there on, internal Ruy code deals with only one +// path. +// +// When a user selects a set of compiled paths, Ruy internally dispatches to the +// "best" one, which typically means the newest optimized instructions for a +// given base architecture (such as ARM). Higher values of this enum correspond +// to "better" code paths within a given base architecture for which Ruy has +// optimized code paths. +// +// Values are reused across architectures. +// Rationale: Scale better to N architectures, it is good to have small values +// both for the compile-time logic to select paths, and when manually spelling +// out Path values, such as when invoking a test or benchmark. +enum class Path : std::uint8_t { + // This is a special null value, representing the absence of any path. + kNone = 0, + // Reference multiplication code. + // The main purpose of this path is to have a very simple standalone Mul + // implementation to check against. + // This path bypasses almost all of Ruy's internal implementation details. + // + // This is intended for testing/development. + kReference = 0x1, + // Standard C++ implementation of Ruy's architecture-specific parts. + // Unlike Path::kReference, this path exercises most of Ruy's internal logic. + // + // This is intended for testing/development. + kStandardCpp = 0x2, + +#if RUY_PLATFORM(ARM) + // ARM architectures. + // + // Optimized path using a widely available subset of ARM NEON instructions. + kNeon = 0x4, + // Optimized path making use of ARM NEON dot product instructions that are + // available on newer ARM cores. + kNeonDotprod = 0x8, +#endif // RUY_PLATFORM(ARM) + +#if RUY_PLATFORM(X86) + // x86 architectures. + // + // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / + // placeholder. + // Optimization is not finished. In particular the dimensions of the kernel + // blocks can be changed as desired. + // + // Optimized for SSE 4.2. + kSse42 = 0x4, + // Optimized for AVX2. + kAvx2 = 0x8, + // Optimized for AVX-512. + kAvx512 = 0x10, + // TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / + // placeholder. + // Optimization is not finished. In particular the dimensions of the kernel + // blocks can be changed as desired. + // + // Optimized for AVX-VNNI. 
+ kAvxVnni = 0x20, +#endif // RUY_PLATFORM(X86) +}; + +inline constexpr Path operator|(Path p, Path q) { + return static_cast(static_cast(p) | + static_cast(q)); +} + +inline constexpr Path operator&(Path p, Path q) { + return static_cast(static_cast(p) & + static_cast(q)); +} + +inline constexpr Path operator^(Path p, Path q) { + return static_cast(static_cast(p) ^ + static_cast(q)); +} + +inline constexpr Path operator~(Path p) { + return static_cast(~static_cast(p)); +} + +inline Path GetMostSignificantPath(Path path_mask) { + return static_cast(round_down_pot(static_cast(path_mask))); +} + +// ruy::kAllPaths represents all Path's that make sense to on a given +// base architecture. +#ifdef __linux__ +#if RUY_PLATFORM(NEON_64) +constexpr Path kAllPaths = + Path::kReference | Path::kStandardCpp | Path::kNeon | Path::kNeonDotprod; +#elif RUY_PLATFORM(NEON_32) +constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon; +#elif RUY_PLATFORM(X86) +constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | + Path::kSse42 | Path::kAvx2 | Path::kAvx512 | + Path::kAvxVnni; +#else +constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp; +#endif +#else // __linux__ +// We don't know how to do runtime dotprod detection outside of linux for now. +#if RUY_PLATFORM(NEON) +constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | Path::kNeon; +#elif RUY_PLATFORM(X86) +constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp | + Path::kSse42 | Path::kAvx2 | Path::kAvx512 | + Path::kAvxVnni; +#else +constexpr Path kAllPaths = Path::kReference | Path::kStandardCpp; +#endif +#endif // __linux__ + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PATH_H_ diff --git a/ruy/platform.h b/ruy/platform.h new file mode 100644 index 0000000..d6e86e6 --- /dev/null +++ b/ruy/platform.h @@ -0,0 +1,156 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_ + +#ifdef __ANDROID_NDK__ +#include +#endif + +#define RUY_PLATFORM(X) ((RUY_DONOTUSEDIRECTLY_##X) != 0) + +// Architecture-level platform detection. +// +// Ruy requires these to be mutually exclusive. + +// Detect x86. +#if defined(__x86_64__) || defined(__i386__) || defined(__i386) || \ + defined(__x86__) || defined(__X86__) || defined(_X86_) || \ + defined(_M_IX86) || defined(_M_X64) +#define RUY_DONOTUSEDIRECTLY_X86 1 +#else +#define RUY_DONOTUSEDIRECTLY_X86 0 +#endif + +// Detect ARM 32-bit. +#ifdef __arm__ +#define RUY_DONOTUSEDIRECTLY_ARM_32 1 +#else +#define RUY_DONOTUSEDIRECTLY_ARM_32 0 +#endif + +// Detect ARM 64-bit. +#ifdef __aarch64__ +#define RUY_DONOTUSEDIRECTLY_ARM_64 1 +#else +#define RUY_DONOTUSEDIRECTLY_ARM_64 0 +#endif + +// Combined ARM. 
+#define RUY_DONOTUSEDIRECTLY_ARM \ + (RUY_DONOTUSEDIRECTLY_ARM_64 || RUY_DONOTUSEDIRECTLY_ARM_32) + +// Feature and capability platform detection. +// +// These are mostly sub-selections of architectures. + +// Detect NEON. Explicitly avoid emulation, or anything like it, on x86. +#if (defined(__ARM_NEON) || defined(__ARM_NEON__)) && !RUY_PLATFORM(X86) +#define RUY_DONOTUSEDIRECTLY_NEON 1 +#else +#define RUY_DONOTUSEDIRECTLY_NEON 0 +#endif + +// Define ARM 32-bit NEON. +#define RUY_DONOTUSEDIRECTLY_NEON_32 \ + (RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_32) + +// Define ARM 64-bit NEON. +// Note: NEON is implied by ARM64, so this define is redundant. +// It still allows some conveyance of intent. +#define RUY_DONOTUSEDIRECTLY_NEON_64 \ + (RUY_DONOTUSEDIRECTLY_NEON && RUY_DONOTUSEDIRECTLY_ARM_64) + +// Disable X86 enhancements on __APPLE__ because b/138922878, see comment #8, we +// may only need to disable this on XCode <= 10.2. +// +// Disable when not using Clang-Linux, because too many user issues arise from +// compilation variations. +// +// NOTE: Consider guarding by !defined(__APPLE__) when removing Linux-only +// restriction. +// +// __EMSCRIPTEN__ is checked because the runtime Path resolution can use asm. +// +// The Android NDK logic excludes earlier and very broken versions of intrinsics +// headers. +#if defined(RUY_FORCE_ENABLE_X86_ENHANCEMENTS) || \ + (defined(__clang__) && (__clang_major__ >= 8) && defined(__linux__) && \ + !defined(__EMSCRIPTEN__) && \ + (!defined(__ANDROID_NDK__) || \ + (defined(__NDK_MAJOR__) && (__NDK_MAJOR__ >= 20)))) +#define RUY_DONOTUSEDIRECTLY_X86_ENHANCEMENTS 1 +#else +#define RUY_DONOTUSEDIRECTLY_X86_ENHANCEMENTS 0 +#endif + +// These CPU capabilities will all be true when Skylake, etc, are enabled during +// compilation. +#if RUY_PLATFORM(X86_ENHANCEMENTS) && RUY_PLATFORM(X86) && \ + defined(__AVX512F__) && defined(__AVX512DQ__) && defined(__AVX512CD__) && \ + defined(__AVX512BW__) && defined(__AVX512VL__) +#define RUY_DONOTUSEDIRECTLY_AVX512 1 +#else +#define RUY_DONOTUSEDIRECTLY_AVX512 0 +#endif + +#if RUY_PLATFORM(X86_ENHANCEMENTS) && RUY_PLATFORM(X86) && defined(__AVX2__) +#define RUY_DONOTUSEDIRECTLY_AVX2 1 +#else +#define RUY_DONOTUSEDIRECTLY_AVX2 0 +#endif + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// Note does not check for LZCNT or POPCNT. +#if defined(RUY_ENABLE_SSE_ENHANCEMENTS) && RUY_PLATFORM(X86_ENHANCEMENTS) && \ + RUY_PLATFORM(X86) && defined(__SSE4_2__) && defined(__FMA__) +#define RUY_DONOTUSEDIRECTLY_SSE42 1 +#else +#define RUY_DONOTUSEDIRECTLY_SSE42 0 +#endif + +// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. +// Optimization is not finished. In particular the dimensions of the kernel +// blocks can be changed as desired. +// +// Note that defined(__AVX512VBMI2__) can be false for compilation with +// -march=cascadelake. +// TODO(b/146646451) Check if we should also gate on defined(__AVX512VBMI2__). +#if defined(RUY_ENABLE_VNNI_ENHANCEMENTS) && RUY_PLATFORM(AVX512) && \ + defined(__AVX512VNNI__) +#define RUY_DONOTUSEDIRECTLY_AVX_VNNI 1 +#else +#define RUY_DONOTUSEDIRECTLY_AVX_VNNI 0 +#endif + +// Detect APPLE. +#ifdef __APPLE__ +#define RUY_DONOTUSEDIRECTLY_APPLE 1 +#else +#define RUY_DONOTUSEDIRECTLY_APPLE 0 +#endif + +// Detect Emscripten, typically Wasm. 
+#ifdef __EMSCRIPTEN__ +#define RUY_DONOTUSEDIRECTLY_EMSCRIPTEN 1 +#else +#define RUY_DONOTUSEDIRECTLY_EMSCRIPTEN 0 +#endif + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PLATFORM_H_ diff --git a/ruy/pmu.cc b/ruy/pmu.cc new file mode 100644 index 0000000..1d87b1f --- /dev/null +++ b/ruy/pmu.cc @@ -0,0 +1,281 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/pmu.h" + +#include "ruy/check_macros.h" + +#ifdef __linux__ +#include +#include +#include +#include +#include + +#include +#endif + +#include +#include +#include +#include + +namespace ruy { + +// Linux-specific. Not ARM-specific. +#ifdef __linux__ +class PerfEvent { + public: + PerfEvent(std::uint32_t type, std::uint64_t config) { + perf_event_attr pe; + memset(&pe, 0, sizeof(pe)); + pe.size = sizeof(pe); + pe.type = type; + pe.config = config; + pe.disabled = 1; + pe.exclude_kernel = 1; + pe.exclude_hv = 1; + pe.inherit = 1; + fd_ = syscall(__NR_perf_event_open, &pe, 0, -1, -1, 0); + if (fd_ == -1) { + fprintf(stderr, "perf_event_open failed for config 0x%lx\n", + static_cast(config)); + // abort(); + } + } + + ~PerfEvent() { + RUY_CHECK(!started_); + close(fd_); + } + + void Start() { + RUY_CHECK(!started_); + started_ = true; + ioctl(fd_, PERF_EVENT_IOC_RESET, 0); + ioctl(fd_, PERF_EVENT_IOC_ENABLE, 0); + count_at_start_ = Read(); + } + + void Stop() { + RUY_CHECK(started_); + started_ = false; + ioctl(fd_, PERF_EVENT_IOC_DISABLE, 0); + count_at_stop_ = Read(); + } + + std::int64_t Count() const { + RUY_CHECK(!started_); + return count_at_stop_ - count_at_start_; + } + + private: + std::int64_t Read() const { + std::int64_t count; + RUY_CHECK_NE(read(fd_, &count, sizeof(count)), -1); + return count; + } + std::int64_t count_at_start_ = -1; + std::int64_t count_at_stop_ = -1; + bool started_ = false; + int fd_ = -1; +}; +#else +// Placeholder implementation to at least compile outside of linux. +#define PERF_TYPE_RAW 0 +class PerfEvent { + public: + PerfEvent(std::uint32_t, std::uint64_t) {} + ~PerfEvent() {} + void Start() {} + void Stop() {} + std::int64_t Count() const { return 0; } +}; +#endif + +// ARM-specific. Query ARM PMU counters as Linux perf events using +// PERF_TYPE_RAW. +namespace arm_pmuv3 { + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-const-variable" + +// These event numbers are listed in the ARMv8 architecture reference manual. 
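+// They are passed as the perf_event_attr 'config' value together with
+// PERF_TYPE_RAW (see PmuEventsPrivate below).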
+constexpr std::uint16_t L1I_CACHE_REFILL = 0x01; +constexpr std::uint16_t L1I_TLB_REFILL = 0x02; +constexpr std::uint16_t L1D_CACHE_REFILL = 0x03; +constexpr std::uint16_t L1D_CACHE = 0x04; +constexpr std::uint16_t L1D_TLB_REFILL = 0x05; +constexpr std::uint16_t LD_RETIRED = 0x06; +constexpr std::uint16_t ST_RETIRED = 0x07; +constexpr std::uint16_t INST_RETIRED = 0x08; +constexpr std::uint16_t EXC_TAKEN = 0x09; +constexpr std::uint16_t EXC_RETURN = 0x0A; +constexpr std::uint16_t CID_WRITE_RETIRED = 0x0B; +constexpr std::uint16_t PC_WRITE_RETIRED = 0x0C; +constexpr std::uint16_t BR_IMMED_RETIRED = 0x0D; +constexpr std::uint16_t BR_RETURN_RETIRED = 0x0E; +constexpr std::uint16_t UNALIGNED_LDST_RETIRED = 0x0F; +constexpr std::uint16_t BR_MIS_PRED = 0x10; +constexpr std::uint16_t CPU_CYCLES = 0x11; +constexpr std::uint16_t BR_PRED = 0x12; +constexpr std::uint16_t MEM_ACCESS = 0x13; +constexpr std::uint16_t L1I_CACHE = 0x14; +constexpr std::uint16_t L1D_CACHE_WB = 0x15; +constexpr std::uint16_t L2D_CACHE = 0x16; +constexpr std::uint16_t L2D_CACHE_REFILL = 0x17; +constexpr std::uint16_t L2D_CACHE_WB = 0x18; +constexpr std::uint16_t BUS_ACCESS = 0x19; +constexpr std::uint16_t MEMORY_ERROR = 0x1A; +constexpr std::uint16_t INST_SPEC = 0x1B; +constexpr std::uint16_t TTBR_WRITE_RETIRED = 0x1C; +constexpr std::uint16_t BUS_CYCLES = 0x1D; +constexpr std::uint16_t CHAIN = 0x1E; +constexpr std::uint16_t L1D_CACHE_ALLOCATE = 0x1F; +constexpr std::uint16_t L2D_CACHE_ALLOCATE = 0x20; +constexpr std::uint16_t BR_RETIRED = 0x21; +constexpr std::uint16_t BR_MIS_PRED_RETIRED = 0x22; +constexpr std::uint16_t STALL_FRONTEND = 0x23; +constexpr std::uint16_t STALL_BACKEND = 0x24; +constexpr std::uint16_t L1D_TLB = 0x25; +constexpr std::uint16_t L1I_TLB = 0x26; +constexpr std::uint16_t L2I_CACHE = 0x27; +constexpr std::uint16_t L2I_CACHE_REFILL = 0x28; +constexpr std::uint16_t L3D_CACHE_ALLOCATE = 0x29; +constexpr std::uint16_t L3D_CACHE_REFILL = 0x2A; +constexpr std::uint16_t L3D_CACHE = 0x2B; +constexpr std::uint16_t L3D_CACHE_WB = 0x2C; +constexpr std::uint16_t L2D_TLB_REFILL = 0x2D; +constexpr std::uint16_t L2I_TLB_REFILL = 0x2E; +constexpr std::uint16_t L2D_TLB = 0x2F; +constexpr std::uint16_t L2I_TLB = 0x30; +constexpr std::uint16_t LL_CACHE = 0x32; +constexpr std::uint16_t LL_CACHE_MISS = 0x33; +constexpr std::uint16_t DTLB_WALK = 0x34; +constexpr std::uint16_t LL_CACHE_RD = 0x36; +constexpr std::uint16_t LL_CACHE_MISS_RD = 0x37; + +// Additional implementation-defined events found by googling around. 
+constexpr std::uint16_t L1D_CACHE_RD = 0x40; +constexpr std::uint16_t L1D_CACHE_REFILL_RD = 0x42; +constexpr std::uint16_t L1D_TLB_REFILL_RD = 0x4C; +constexpr std::uint16_t L1D_TLB_RD = 0x4E; +constexpr std::uint16_t L2D_CACHE_RD = 0x50; +constexpr std::uint16_t L2D_CACHE_REFILL_RD = 0x52; +constexpr std::uint16_t BUS_ACCESS_RD = 0x60; +constexpr std::uint16_t MEM_ACCESS_RD = 0x66; +constexpr std::uint16_t L3D_CACHE_RD = 0xA0; +constexpr std::uint16_t L3D_CACHE_REFILL_RD = 0xA2; + +#pragma GCC diagnostic pop + +}; // namespace arm_pmuv3 + +class PmuEventsPrivate { + public: + PmuEventsPrivate() + : l1d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_CACHE_REFILL), + l2d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_CACHE_REFILL), + l3d_cache_refill(PERF_TYPE_RAW, arm_pmuv3::L3D_CACHE_REFILL), + ll_cache_miss(PERF_TYPE_RAW, arm_pmuv3::LL_CACHE_MISS), + l1d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L1D_TLB_REFILL), + l2d_tlb_refill(PERF_TYPE_RAW, arm_pmuv3::L2D_TLB_REFILL), + stall_frontend(PERF_TYPE_RAW, arm_pmuv3::STALL_FRONTEND), + stall_backend(PERF_TYPE_RAW, arm_pmuv3::STALL_BACKEND), + br_mis_pred(PERF_TYPE_RAW, arm_pmuv3::BR_MIS_PRED) {} + + private: + friend class PmuEvents; + PerfEvent l1d_cache_refill; + PerfEvent l2d_cache_refill; + PerfEvent l3d_cache_refill; + PerfEvent ll_cache_miss; + PerfEvent l1d_tlb_refill; + PerfEvent l2d_tlb_refill; + PerfEvent stall_frontend; + PerfEvent stall_backend; + PerfEvent br_mis_pred; +}; + +PmuEvents::PmuEvents() : priv(new PmuEventsPrivate) {} +PmuEvents::~PmuEvents() { delete priv; } + +void PmuEvents::StartRecording() { + priv->l1d_cache_refill.Start(); + priv->l2d_cache_refill.Start(); + priv->l3d_cache_refill.Start(); + priv->ll_cache_miss.Start(); + priv->l1d_tlb_refill.Start(); + priv->l2d_tlb_refill.Start(); + priv->stall_frontend.Start(); + priv->stall_backend.Start(); + priv->br_mis_pred.Start(); +} + +void PmuEvents::StopRecording() { + priv->l1d_cache_refill.Stop(); + priv->l2d_cache_refill.Stop(); + priv->l3d_cache_refill.Stop(); + priv->ll_cache_miss.Stop(); + priv->l1d_tlb_refill.Stop(); + priv->l2d_tlb_refill.Stop(); + priv->stall_frontend.Stop(); + priv->stall_backend.Stop(); + priv->br_mis_pred.Stop(); +} + +float PmuEvents::BranchMispredictionCount() const { + return static_cast(priv->br_mis_pred.Count()); +} + +float PmuEvents::FrontendStallCount() const { + return static_cast(priv->stall_frontend.Count()); +} + +float PmuEvents::BackendStallCount() const { + return static_cast(priv->stall_backend.Count()); +} + +float PmuEvents::L1RefillCount() const { + return static_cast(priv->l1d_cache_refill.Count()); +} + +float PmuEvents::L2RefillCount() const { + return static_cast(priv->l2d_cache_refill.Count()); +} + +float PmuEvents::L3RefillCount() const { + // Important: this was discovered in the context of the above experiments, + // which also tested the _RD variants of these counters. So it's possible that + // it's just not needed here with the default (non _RD) counters. + // + // Some CPUs implement LL_CACHE_MISS[_RD], some implement + // L3D_CACHE_REFILL[_RD]. It seems that either one of these two counters is + // zero, or they roughly both agree with each other. Therefore, taking the max + // of them is a reasonable way to get something more portable across various + // CPUs. 
+ return static_cast( + std::max(priv->l3d_cache_refill.Count(), priv->ll_cache_miss.Count())); +} + +float PmuEvents::L1TLBRefillCount() const { + return static_cast(priv->l1d_tlb_refill.Count()); +} + +float PmuEvents::L2TLBRefillCount() const { + return static_cast(priv->l2d_tlb_refill.Count()); +} + +} // namespace ruy diff --git a/ruy/pmu.h b/ruy/pmu.h new file mode 100644 index 0000000..721c1d5 --- /dev/null +++ b/ruy/pmu.h @@ -0,0 +1,44 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_ + +namespace ruy { + +class PmuEventsPrivate; + +class PmuEvents { + public: + PmuEvents(); + ~PmuEvents(); + void StartRecording(); + void StopRecording(); + float L1RefillCount() const; + float L2RefillCount() const; + float L3RefillCount() const; + float BranchMispredictionCount() const; + float FrontendStallCount() const; + float BackendStallCount() const; + float L1TLBRefillCount() const; + float L2TLBRefillCount() const; + + private: + PmuEventsPrivate* priv = nullptr; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PMU_H_ diff --git a/ruy/prepack.h b/ruy/prepack.h new file mode 100644 index 0000000..4bfc9ed --- /dev/null +++ b/ruy/prepack.h @@ -0,0 +1,108 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implementation of low-level pre-packing API. 
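Stepping back to the `PmuEvents` interface declared in `ruy/pmu.h` above: it wraps the per-counter `PerfEvent` objects behind a simple start/stop API. The following is a minimal sketch of how a benchmark might drive it; `Workload()` is a hypothetical stand-in for the code being measured, and on non-Linux builds the placeholder implementation simply reports zero counts.

```c++
#include <cstdio>

#include "ruy/pmu.h"

// Hypothetical stand-in for the code being measured.
void Workload() {
  volatile long long sink = 0;
  for (int i = 0; i < (1 << 20); ++i) {
    sink += i;
  }
}

int main() {
  ruy::PmuEvents pmu;    // Opens the underlying perf event counters.
  pmu.StartRecording();  // Resets and enables the counters.
  Workload();
  pmu.StopRecording();   // Disables the counters and latches the counts.
  std::printf("L1 refills:        %g\n", pmu.L1RefillCount());
  std::printf("L2 refills:        %g\n", pmu.L2RefillCount());
  std::printf("L3 refills:        %g\n", pmu.L3RefillCount());
  std::printf("L1 TLB refills:    %g\n", pmu.L1TLBRefillCount());
  std::printf("L2 TLB refills:    %g\n", pmu.L2TLBRefillCount());
  std::printf("Branch mispredict: %g\n", pmu.BranchMispredictionCount());
  std::printf("Frontend stalls:   %g\n", pmu.FrontendStallCount());
  std::printf("Backend stalls:    %g\n", pmu.BackendStallCount());
  return 0;
}
```

Each count is the delta between `StartRecording()` and `StopRecording()`, so the same `PmuEvents` object can be reused across successive measurement windows.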
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ + +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/context.h" +#include "ruy/dispatch.h" +#include "ruy/internal_matrix.h" +#include "ruy/matrix.h" +#include "ruy/path.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/side_pair.h" +#include "ruy/spec.h" +#include "ruy/trmul.h" +#include "ruy/trmul_params.h" +#include "ruy/tune.h" + +namespace ruy { + +template +void PrePackForMulInternal(const Matrix& lhs, + const Matrix& rhs, const Spec& spec, + Context* context, Matrix* dst, + SidePair prepacked, + std::function alloc_fn) { + profiler::ScopeLabel label("PrePackForMul"); + Path the_path = context->GetPathToTake(); + RUY_CHECK_NE(the_path, Path::kReference); + constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; + Matrix transposed_lhs(lhs); + Transpose(&transposed_lhs); + TrMulParams params; + CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, + the_path, ¶ms); + + const SidePair origin{0, 0}; + const SidePair rounded_dims{params.packed[Side::kLhs].layout.cols, + params.packed[Side::kRhs].layout.cols}; + + Tuning tuning = context->GetMainThreadTuning(); + for (Side side : {Side::kLhs, Side::kRhs}) { + if (prepacked[side]) { + prepacked[side]->data_size = DataSize(params.packed[side]); + prepacked[side]->sums_size = SumsSize(params.packed[side]); + prepacked[side]->data = alloc_fn(prepacked[side]->data_size); + prepacked[side]->sums = alloc_fn(prepacked[side]->sums_size); + params.packed[side].data = prepacked[side]->data; + params.packed[side].sums = prepacked[side]->sums; + params.RunPack(side, tuning, origin[side], rounded_dims[side]); + } + } +} + +template +void MulWithPrepackedInternal(const Matrix& lhs, + const Matrix& rhs, const Spec& spec, + Context* context, Matrix* dst, + SidePair prepacked) { + profiler::ScopeLabel label("MulWithPrepacked"); + + EnforceLayoutSupport(lhs.layout, rhs.layout, dst->layout); + EnforceZeroPointSupport(lhs.zero_point, rhs.zero_point, + dst->zero_point); + + Path the_path = context->GetPathToTake(); + RUY_CHECK_NE(the_path, Path::kReference); + constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; + Matrix transposed_lhs(lhs); + Transpose(&transposed_lhs); + TrMulParams params; + CreateTrMulParams(transposed_lhs, rhs, spec, context, dst, + the_path, ¶ms); + + for (Side side : {Side::kLhs, Side::kRhs}) { + if (prepacked[side]) { + params.packed[side].data = prepacked[side]->data; + params.packed[side].sums = prepacked[side]->sums; + params.is_prepacked[side] = true; + } + } + + TrMul(¶ms, context); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ diff --git a/ruy/prepacked_cache.cc b/ruy/prepacked_cache.cc new file mode 100644 index 0000000..020fdf7 --- /dev/null +++ b/ruy/prepacked_cache.cc @@ -0,0 +1,82 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/prepacked_cache.h" + +#include "ruy/matrix.h" +#include "ruy/profiler/instrumentation.h" + +namespace ruy { + +using CacheIterator = PrepackedCache::CacheIterator; + +// Looks for an entry with `key`. If found, update its time stamp. +CacheIterator PrepackedCache::FindAndUpdate(const CacheKey &key) { + auto itr = cache_.find(key); + // If found, update with new access time for this entry. + if (itr != cache_.end()) { + const TimePoint time = CacheNow(); + itr->second.second = time; + } + return itr; +} + +void PrepackedCache::Insert(const CacheKey &key, + const PrepackedMatrix &matrix) { + // Calculate size of this new item. + const size_t size_bytes = matrix.data_size + matrix.sums_size; + + // While we are above the threshold of ejection, eject the LRU entry. + while (!cache_.empty() && + ((TotalSize() + size_bytes) > ejection_threshold_)) { + EjectOne(); + } + DoInsert(key, matrix); + cache_size_ += matrix.data_size + matrix.sums_size; +} + +void PrepackedCache::EjectOne() { + TimePoint oldest_time = CacheNow(); + auto oldest = cache_.begin(); + { + profiler::ScopeLabel label("PepackedCacheEjection"); + for (auto itr = cache_.begin(); itr != cache_.end(); ++itr) { + if (itr->second.second < oldest_time) { + oldest_time = itr->second.second; + oldest = itr; + } + } + } + PrepackedMatrix &pmatrix = oldest->second.first; + cache_size_ -= pmatrix.data_size; + cache_size_ -= pmatrix.sums_size; + allocator_.Free(pmatrix.data); + allocator_.Free(pmatrix.sums); + cache_.erase(oldest); +} + +void PrepackedCache::AllocatePrepackedMatrix(PrepackedMatrix *pmatrix) { + pmatrix->data = allocator_.Alloc(pmatrix->data_size); + pmatrix->sums = allocator_.Alloc(pmatrix->sums_size); +} + +void PrepackedCache::DoInsert(const CacheKey &key, + const PrepackedMatrix &matrix) { + const TimePoint t = CacheNow(); + const MatrixWithTimeStamp mts({matrix, t}); + cache_.insert({key, mts}); +} + +} // namespace ruy diff --git a/ruy/prepacked_cache.h b/ruy/prepacked_cache.h new file mode 100644 index 0000000..eedd7e4 --- /dev/null +++ b/ruy/prepacked_cache.h @@ -0,0 +1,130 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_ + +#include +#include +#include +#include +#include + +#include "ruy/allocator.h" +#include "ruy/matrix.h" +#include "ruy/time.h" + +namespace ruy { + +namespace detail { + +// Tracks a set of blocks allocated from the underlying system allocator. 
+class SystemBlockAllocator { + public: + void *Alloc(std::ptrdiff_t num_bytes) { + void *p = detail::SystemAlignedAlloc(num_bytes); + blocks_.push_back(p); + return p; + } + + void Free(void *block) { + for (auto it = blocks_.begin(); it != blocks_.end(); ++it) { + if (*it == block) { + detail::SystemAlignedFree(block); + blocks_.erase(it); + return; + } + } + RUY_DCHECK(false); // Trying to free pointer we did not allocate. + } + + ~SystemBlockAllocator() { + for (void *block : blocks_) { + detail::SystemAlignedFree(block); + } + } + + private: + std::vector blocks_; +}; + +} // namespace detail + +enum CachePolicy { kNoCache, kCacheLHSOnNarrowMul }; + +// "Low effort" Least Recently Used Cache for Prepacked Matrices +// A cache mechanism for prepacked matrices that ejects oldest entries. +// The implementation is "low effort" in the following ways: +// - we just linearly search for the oldest entry when doing an ejection +// - the ejection policy is very simple: if the new size would be above the +// . threshold, we will eject entries until the size is below the threshold. +// Current use cases (RNNs with GEMV operations) indicate that ejection is rare +// and memory constraints are tight, so we devote no additional storage to the +// LRU mechanism and accept O(n) search to eject oldest entry. In practice, +// the number of total entries has not been shown to be large. +// This class is not thread safe. In Ruy, memory allocation for packed matrices +// is done in a single threaded context and the actual packing activity may +// be done in a multi-threaded context. +class PrepackedCache { + public: + static constexpr int kDefaultEjectionThresholdBytes = 1 << 28; + + using CacheKey = std::pair; + + using MatrixWithTimeStamp = std::pair; + + using CacheIterator = std::map::const_iterator; + + using AlignedAllocator = detail::AlignedAllocator; + + explicit PrepackedCache( + int32_t ejection_threshold = kDefaultEjectionThresholdBytes) + : ejection_threshold_(ejection_threshold), cache_size_(0) {} + + // Looks for an entry with `key`. If found, update its time stamp. + CacheIterator FindAndUpdate(const CacheKey &key); + + // Returns end iterator for internal cache. The iterator type is appropriate + // to use with `FindAndUpdate`. + CacheIterator cend() const { return cache_.end(); } + + // Returns the total size (in bytes) of data held in this cache. + int TotalSize() const { return cache_size_; } + + // All calls to get current TimePoints go through here. + // TODO(b/145625614) Profile timestamps on relevant models to see if + // this level of granularity is sufficient. CoarseNow is cheap so + // it would be nice to keep it. + TimePoint CacheNow() const { return CoarseNow(); } + + // Performs the memory allocation for the `data` and `sums` members of a + // PrepackedMatrix. + void AllocatePrepackedMatrix(PrepackedMatrix *pmatrix); + + // Adds the PrepackedMatrix to the cache, possibly ejecting other values. 
+ void Insert(const CacheKey &key, const PrepackedMatrix &matrix); + + private: + void EjectOne(); + void DoInsert(const CacheKey &key, const PrepackedMatrix &matrix); + detail::SystemBlockAllocator allocator_; + std::map cache_; + const int32_t ejection_threshold_; + size_t cache_size_; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACKED_CACHE_H_ diff --git a/ruy/prepacked_cache_test.cc b/ruy/prepacked_cache_test.cc new file mode 100644 index 0000000..a65841e --- /dev/null +++ b/ruy/prepacked_cache_test.cc @@ -0,0 +1,210 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/prepacked_cache.h" + +#include // NOLINT(build/c++11) + +#include "testing/base/public/gunit.h" +#include "ruy/ruy.h" +#include "ruy/time.h" + +namespace ruy { +namespace { + +TEST(PrepackedCacheTest, TestCacheEjection) { + // Create the cache. + PrepackedCache prepacked_cache(32); + // Allocate the prepacked matrix. + PrepackedMatrix mat1; + mat1.data_size = 16; + mat1.sums_size = 8; + prepacked_cache.AllocatePrepackedMatrix(&mat1); + auto cache_key1 = std::make_pair(nullptr, mat1.data); + prepacked_cache.Insert(cache_key1, mat1); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + // Get a time point after the insertion into the cache. + TimePoint current = CoarseNow(); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + PrepackedCache::CacheIterator itr = prepacked_cache.FindAndUpdate(cache_key1); + EXPECT_NE(itr, prepacked_cache.cend()); + // By finding mat1, we updated its timestamp. Verify that `current` is older + // than the time stamp now associated with mat1. + EXPECT_LT(current, itr->second.second); + PrepackedMatrix mat2; + mat2.data_size = 8; + mat2.sums_size = 4; + prepacked_cache.AllocatePrepackedMatrix(&mat2); + + auto cache_key2 = std::make_pair(nullptr, mat2.data); + prepacked_cache.Insert(cache_key2, mat2); + // The cache size was exceeded by inserting mat2. Ensure that mat1 was + // ejected. + EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); +} + +TEST(PrepackedCacheTest, TestCacheBasic) { + // Create the cache. + PrepackedCache prepacked_cache(48); + // Allocate the prepacked matrix. + PrepackedMatrix mat1; + mat1.data_size = 16; + mat1.sums_size = 8; + prepacked_cache.AllocatePrepackedMatrix(&mat1); + + auto cache_key1 = std::make_pair(nullptr, mat1.data); + prepacked_cache.Insert(cache_key1, mat1); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); + + PrepackedMatrix mat2; + mat2.data_size = 8; + mat2.sums_size = 4; + prepacked_cache.AllocatePrepackedMatrix(&mat2); + + auto cache_key2 = std::make_pair(nullptr, mat2.data); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + prepacked_cache.Insert(cache_key2, mat2); + // The cache size was not exceeded by inserting mat2. 
Ensure that mat1 was not + // ejected. + EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); +} + +TEST(PrepackedCacheTest, TestCacheEjection2) { + // Create the cache. + PrepackedCache prepacked_cache(73); + // Allocate the prepacked matrix 1. + PrepackedMatrix mat1; + mat1.data_size = 16; + mat1.sums_size = 8; + prepacked_cache.AllocatePrepackedMatrix(&mat1); + auto cache_key1 = std::make_pair(nullptr, mat1.data); + prepacked_cache.Insert(cache_key1, mat1); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + // Allocate the prepacked matrix 2. + PrepackedMatrix mat2; + mat2.data_size = 16; + mat2.sums_size = 8; + prepacked_cache.AllocatePrepackedMatrix(&mat2); + auto cache_key2 = std::make_pair(nullptr, mat2.data); + prepacked_cache.Insert(cache_key2, mat2); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + // Allocate the prepacked matrix 3. + PrepackedMatrix mat31; + mat31.data_size = 16; + mat31.sums_size = 8; + prepacked_cache.AllocatePrepackedMatrix(&mat31); + auto cache_key3 = std::make_pair(nullptr, mat31.data); + prepacked_cache.Insert(cache_key3, mat31); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + // The next insertion will cause the cache size to go over the ejection + // threshold. Touch matrix 1 and matrix 3 to make matrix 2 the oldest + EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); + EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend()); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + // Allocate the prepacked matrix 4. + PrepackedMatrix mat4; + mat4.data_size = 16; + mat4.sums_size = 8; + prepacked_cache.AllocatePrepackedMatrix(&mat4); + auto cache_key4 = std::make_pair(nullptr, mat4.data); + prepacked_cache.Insert(cache_key4, mat4); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + // Ensure that mat2 was ejected, but mat1, mat3, and mat4 were not. + EXPECT_EQ(prepacked_cache.FindAndUpdate(cache_key2), prepacked_cache.cend()); + EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key3), prepacked_cache.cend()); + EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key1), prepacked_cache.cend()); + EXPECT_NE(prepacked_cache.FindAndUpdate(cache_key4), prepacked_cache.cend()); +} + +TEST(PrepackedCacheTest, TestCacheOnCacheable) { + // Create context and set the cache policy + ruy::Context context; + context.cache_policy = ruy::kCacheLHSOnNarrowMul; + PrepackedCache* cache = context.GetPrepackedCache(); + EXPECT_EQ(cache->TotalSize(), 0); + + const float lhs_data[] = {1, 2, 3, 4}; + const float rhs_data[] = {1, 2}; + float dst_data[4]; + + ruy::Matrix lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); + lhs.data = lhs_data; + ruy::Matrix rhs; + ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout); + rhs.data = rhs_data; + ruy::Matrix dst; + ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout); + dst.data = dst_data; + + ruy::BasicSpec spec; + // Perform the multiplication and confirm no caching occurred. + ruy::Mul(lhs, rhs, spec, &context, &dst); + EXPECT_EQ(cache->TotalSize(), 0); + + // Set cacheable for the LHS, repeat the multiplication, and see + // that caching did occur. 
+ lhs.cacheable = true; + ruy::Mul(lhs, rhs, spec, &context, &dst); + EXPECT_NE(cache->TotalSize(), 0); +} + +TEST(PrepackedCacheTest, TestClearCache) { + // Create context and set the cache policy + ruy::Context context; + context.cache_policy = ruy::kCacheLHSOnNarrowMul; + PrepackedCache* cache = context.GetPrepackedCache(); + EXPECT_EQ(cache->TotalSize(), 0); + + const float lhs_data[] = {1, 2, 3, 4}; + const float rhs_data[] = {1, 2}; + float dst_data[4]; + + ruy::Matrix lhs; + ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout); + lhs.data = lhs_data; + ruy::Matrix rhs; + ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &rhs.layout); + rhs.data = rhs_data; + ruy::Matrix dst; + ruy::MakeSimpleLayout(2, 1, ruy::Order::kColMajor, &dst.layout); + dst.data = dst_data; + + ruy::BasicSpec spec; + // Set cacheable for the LHS and see that caching occurs. + lhs.cacheable = true; + ruy::Mul(lhs, rhs, spec, &context, &dst); + EXPECT_NE(cache->TotalSize(), 0); + + // Clear the cache via the Context. + context.ClearPrepackedCache(); + // Verify that the cache is now empty. + cache = context.GetPrepackedCache(); + EXPECT_EQ(cache->TotalSize(), 0); +} + +} // namespace +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/profiler/BUILD b/ruy/profiler/BUILD new file mode 100644 index 0000000..b0af802 --- /dev/null +++ b/ruy/profiler/BUILD @@ -0,0 +1,52 @@ +# A minimalistic profiler sampling pseudo-stacks + +package( + default_visibility = ["//visibility:public"], + licenses = ["notice"], # Apache 2.0 +) + +config_setting( + name = "ruy_profiler", + define_values = {"ruy_profiler": "true"}, +) + +cc_library( + name = "instrumentation", + srcs = ["instrumentation.cc"], + hdrs = ["instrumentation.h"], + defines = select({ + ":ruy_profiler": ["RUY_PROFILER"], + "//conditions:default": [], + }), +) + +cc_library( + name = "profiler", + srcs = [ + "profiler.cc", + "treeview.cc", + ], + hdrs = [ + "profiler.h", + "treeview.h", + ], + deps = [":instrumentation"], +) + +cc_library( + name = "test_instrumented_library", + testonly = True, + srcs = ["test_instrumented_library.cc"], + hdrs = ["test_instrumented_library.h"], + deps = [":instrumentation"], +) + +cc_test( + name = "test", + srcs = ["test.cc"], + deps = [ + ":profiler", + ":test_instrumented_library", + "@com_google_googletest//:gtest", + ], +) diff --git a/ruy/profiler/README.md b/ruy/profiler/README.md new file mode 100644 index 0000000..8d79025 --- /dev/null +++ b/ruy/profiler/README.md @@ -0,0 +1,149 @@ +# A minimalistic profiler sampling pseudo-stacks + +## Overview + +The present directory is the "ruy profiler". As a time profiler, it allows to +measure where code is spending time. + +Contrary to most typical profilers, what it samples is not real call stacks, but +"pseudo-stacks" which are just simple data structures constructed from within +the program being profiled. Using this profiler requires manually instrumenting +code to construct such pseudo-stack information. + +Another unusual characteristic of this profiler is that it uses only the C++11 +standard library. It does not use any non-portable feature, in particular it +does not rely on signal handlers. The sampling is performed by a thread, the +"profiler thread". + +A discussion of pros/cons of this approach is appended below. + +## How to use this profiler + +### How to instrument code + +An example of instrumented code is given in `test_instrumented_library.cc`. 
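For a condensed sketch of that pattern (hypothetical function names, not a verbatim excerpt from that file), instrumentation and profiling combine roughly like this:

```c++
#include "ruy/profiler/instrumentation.h"
#include "ruy/profiler/profiler.h"

// Hypothetical instrumented library function, in the spirit of the merge-sort
// example: a formatted label for the whole function, plus a nested scope.
void ProcessChunk(int chunk_size) {
  ruy::profiler::ScopeLabel function_label("ProcessChunk (size=%d)", chunk_size);
  // ... main work proportional to chunk_size ...
  {
    ruy::profiler::ScopeLabel finalize_label("Finalize chunk");
    // ... some finalization work ...
  }
}

// Hypothetical benchmark driver: while `profile` is alive, the profiler thread
// samples the pseudo-stacks of all instrumented threads; on destruction it
// prints the resulting treeview.
void ProfiledBenchmark() {
  ruy::profiler::ScopeProfile profile;
  for (int i = 0; i < 1000; ++i) {
    ProcessChunk(1 << 10);
  }
}
```

Built with `--define=ruy_profiler=true`, the destructor of `profile` prints a treeview splitting the samples between `ProcessChunk` and its nested `Finalize chunk` scope; without that define, both helpers compile to no-ops.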
+ +Code is instrumented by constructing `ScopeLabel` objects. These are RAII +helpers, ensuring that the thread pseudo-stack contains the label during their +lifetime. In the most common use case, one would construct such an object at the +start of a function, so that its scope is the function scope and it measures +how much time is spent in this function. + +```c++ +#include "ruy/profiler/instrumentation.h" + +... + +void SomeFunction() { + ruy::profiler::ScopeLabel function_label("SomeFunction"); + ... do something ... +} +``` + +A `ScopeLabel` may, however, have any scope, for instance: + +```c++ +if (some_case) { + ruy::profiler::ScopeLabel extra_work_label("Some more work"); + ... do some more work ... +} +``` + +The string passed to the `ScopeLabel` constructor must be just a pointer to a +literal string (a `char*` pointer). The profiler will assume that these pointers +stay valid until the profile is finalized. + +However, that literal string may be a `printf` format string, and labels may +have up to 4 parameters of type `int`. For example: + +```c++ +void SomeFunction(int size) { + ruy::profiler::ScopeLabel function_label("SomeFunction (size=%d)", size); + ... +} +``` + +### How to run the profiler + +Profiling instrumentation is a no-op unless the preprocessor token +`RUY_PROFILER` is defined, so defining it is the first step when actually +profiling. When building with Bazel, the preferred way to enable that is to pass +this flag on the Bazel command line: + +``` +--define=ruy_profiler=true +``` + +To actually profile a code scope, it is enough to construct a `ScopeProfile` +object, also a RAII helper. It will start the profiler on construction, and on +destruction it will terminate the profiler and report the profile treeview on +standard output by default. Example: + +```c++ +void SomeProfiledBenchmark() { + ruy::profiler::ScopeProfile profile; + + CallSomeInstrumentedCode(); +} +``` + +An example is provided by the `:test` target in the present directory. Run it +with `--define=ruy_profiler=true` as explained above: + +``` +bazel run -c opt \ + --define=ruy_profiler=true \ + //tensorflow/lite/experimental/ruy/profiler:test +``` + +The default behavior of dumping the treeview on standard output may be +overridden by passing a pointer to a `TreeView` object to `SetUserTreeView()` on +the `ScopeProfile` object. This causes the treeview to be stored in that +`TreeView` object, where it may +be accessed and manipulated using the functions declared in `treeview.h`. The +aforementioned `:test` provides examples for doing so. + +## Advantages and disadvantages + +Compared to a traditional profiler, e.g. Linux's "perf", the present kind of +profiler has the following disadvantages: + +* Requires manual instrumentation of code being profiled. +* Substantial overhead, modifying the performance characteristics of the code + being measured. +* Questionable accuracy. + +But it also has the following advantages: + +* Profiling can be driven from within a benchmark program, allowing the entire + profiling procedure to be a single command line. +* Not relying on symbol information removes exposure to toolchain + details and means less hassle in some build environments, especially + embedded/mobile (single command line to run and profile, no symbol files + required). +* Fully portable (all of this is standard C++11). +* Fully testable (see `:test`). Profiling becomes just another feature of the + code like any other. 
+* Customized instrumentation can result in easier-to-read treeviews (only + relevant functions, and custom labels may be more readable than function + names). +* Parametrized/formatted labels make it possible to do things that aren't possible with + call-stack-sampling profilers. For example, a profile where much + time is being spent in matrix multiplications can be broken down by the various matrix + multiplication shapes involved. + +The philosophy underlying this profiler is that software performance depends on +software engineers profiling often, and a key factor limiting that in practice +is the difficulty or cumbersomeness of profiling with more serious profilers +such as Linux's "perf", especially in embedded/mobile development: multiple +command lines are involved to copy symbol files to devices, retrieve profile +data from the device, etc. In that context, it is useful to make profiling as +easy as benchmarking, even on embedded targets, even if the price to pay for +that is lower accuracy, higher overhead, and an intrusive instrumentation +requirement. + +Another key aspect determining which profiling approach is suitable for a given +context is whether one already has a priori knowledge of where much of the time +is likely being spent. When one has such a priori knowledge, it is feasible to +instrument the known potentially-critical code, as the present approach does. On the +other hand, in situations where one doesn't have such a priori knowledge, a real +profiler such as Linux's "perf" immediately provides a profile of real +stacks, using just the symbol information generated by the toolchain. diff --git a/ruy/profiler/instrumentation.cc b/ruy/profiler/instrumentation.cc new file mode 100644 index 0000000..f03f667 --- /dev/null +++ b/ruy/profiler/instrumentation.cc @@ -0,0 +1,130 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ruy/profiler/instrumentation.h" + +#ifdef RUY_PROFILER + +namespace ruy { +namespace profiler { + +void Label::operator=(const Label& other) { + format_ = other.format_; + args_count_ = other.args_count_; + for (int i = 0; i < args_count_; i++) { + args_[i] = other.args_[i]; + } +} + +bool Label::operator==(const Label& other) const { + if (std::string(format_) != std::string(other.format_)) { + return false; + } + if (args_count_ != other.args_count_) { + return false; + } + for (int i = 0; i < args_count_; i++) { + if (args_[i] != other.args_[i]) { + return false; + } + } + return true; +} + +std::string Label::Formatted() const { + static constexpr int kBufSize = 256; + char buf[kBufSize]; + if (args_count_ == 0) { + return format_; + } + if (args_count_ == 1) { + snprintf(buf, kBufSize, format_, args_[0]); + } else if (args_count_ == 2) { + snprintf(buf, kBufSize, format_, args_[0], args_[1]); + } else if (args_count_ == 3) { + snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2]); + } else if (args_count_ == 4) { + snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2], args_[3]); + } else { + abort(); + } + return buf; +} + +namespace detail { + +std::mutex* GlobalsMutex() { + static std::mutex mutex; + return &mutex; +} + +bool& GlobalIsProfilerRunning() { + static bool b; + return b; +} + +std::vector* GlobalAllThreadStacks() { + static std::vector all_stacks; + return &all_stacks; +} + +ThreadStack* ThreadLocalThreadStack() { + thread_local static ThreadStack thread_stack; + return &thread_stack; +} + +ThreadStack::ThreadStack() { + std::lock_guard lock(*GlobalsMutex()); + static std::uint32_t global_next_thread_stack_id = 0; + stack_.id = global_next_thread_stack_id++; + GlobalAllThreadStacks()->push_back(this); +} + +ThreadStack::~ThreadStack() { + std::lock_guard lock(*GlobalsMutex()); + std::vector* all_stacks = GlobalAllThreadStacks(); + for (auto it = all_stacks->begin(); it != all_stacks->end(); ++it) { + if (*it == this) { + all_stacks->erase(it); + return; + } + } +} +int GetBufferSize(const Stack& stack) { + return sizeof(stack.id) + sizeof(stack.size) + + stack.size * sizeof(stack.labels[0]); +} + +void CopyToBuffer(const Stack& stack, char* dst) { + memcpy(dst, &stack.id, sizeof(stack.id)); + dst += sizeof(stack.id); + memcpy(dst, &stack.size, sizeof(stack.size)); + dst += sizeof(stack.size); + memcpy(dst, stack.labels, stack.size * sizeof(stack.labels[0])); +} + +void ReadFromBuffer(const char* src, Stack* stack) { + memcpy(&stack->id, src, sizeof(stack->id)); + src += sizeof(stack->id); + memcpy(&stack->size, src, sizeof(stack->size)); + src += sizeof(stack->size); + memcpy(stack->labels, src, stack->size * sizeof(stack->labels[0])); +} + +} // namespace detail +} // namespace profiler +} // namespace ruy + +#endif diff --git a/ruy/profiler/instrumentation.h b/ruy/profiler/instrumentation.h new file mode 100644 index 0000000..a9046d4 --- /dev/null +++ b/ruy/profiler/instrumentation.h @@ -0,0 +1,203 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_ + +#ifdef RUY_PROFILER +#include +#include +#include +#endif + +namespace ruy { +namespace profiler { + +#ifdef RUY_PROFILER + +// A label is how a code scope is annotated to appear in profiles. +// The stacks that are sampled by the profiler are stacks of such labels. +// A label consists of a literal string, plus optional integer arguments. +class Label { + public: + Label() {} + template + explicit Label(Args... args) { + Set(args...); + } + void Set(const char* format) { + format_ = format; + args_count_ = 0; + } + template + void Set(const char* format, Args... args) { + format_ = format; + args_count_ = sizeof...(args); + SetArgs(0, args...); + } + + void operator=(const Label& other); + + bool operator==(const Label& other) const; + + std::string Formatted() const; + const char* format() const { return format_; } + + private: + void SetArgs(int position, int arg0) { args_[position] = arg0; } + + template + void SetArgs(int position, int arg0, Args... args) { + SetArgs(position, arg0); + SetArgs(position + 1, args...); + } + + static constexpr int kMaxArgs = 4; + const char* format_ = nullptr; + int args_count_ = 0; + int args_[kMaxArgs]; +}; + +namespace detail { + +// Forward-declaration, see class ThreadStack below. +class ThreadStack; + +bool& GlobalIsProfilerRunning(); + +// Returns the global vector of pointers to all stacks, there being one stack +// per thread executing instrumented code. +std::vector* GlobalAllThreadStacks(); + +// Returns the mutex to be locked around any access to GlobalAllThreadStacks(). +std::mutex* GlobalsMutex(); + +// Returns the thread-local stack, specific to the current thread. +ThreadStack* ThreadLocalThreadStack(); + +// This 'stack' is what may be more appropriately called a 'pseudostack': +// It contains Label entries that are 'manually' entered by instrumentation +// code. It's unrelated to real call stacks. +struct Stack { + std::uint32_t id = 0; + static constexpr int kMaxSize = 64; + int size = 0; + Label labels[kMaxSize]; +}; + +// Returns the buffer byte size required by CopyToSample. +int GetBufferSize(const Stack& stack); + +// Copies this Stack into a byte buffer, called a 'sample'. +void CopyToBuffer(const Stack& stack, char* dst); + +// Populates this Stack from an existing sample buffer, typically +// produced by CopyToSample. +void ReadFromBuffer(const char* src, Stack* stack); + +// ThreadStack is meant to be used as a thread-local singleton, assigning to +// each thread a Stack object holding its pseudo-stack of profile labels, +// plus a mutex allowing to synchronize accesses to this pseudo-stack between +// this thread and a possible profiler thread sampling it. +class ThreadStack { + public: + ThreadStack(); + ~ThreadStack(); + + const Stack& stack() const { return stack_; } + + // Returns the mutex to lock around any access to this stack. 
Each stack is + // accessed by potentially two threads: the thread that it belongs to + // (which calls Push and Pop) and the profiler thread during profiling + // (which calls CopyToSample). + std::mutex& Mutex() const { return mutex_; } + + // Pushes a new label on the top of this Stack. + template + void Push(Args... args) { + // This mutex locking is needed to guard against race conditions as both + // the current thread and the profiler thread may be concurrently accessing + // this stack. In addition to that, this mutex locking also serves the other + // purpose of acting as a barrier (of compiler code reordering, of runtime + // CPU instruction reordering, and of memory access reordering), which + // gives a measure of correctness to this profiler. The downside is some + // latency. As this lock will be uncontended most of the times, the cost + // should be roughly that of an sequentially-consistent atomic access, + // comparable to an access to the level of CPU data cache that is shared + // among all cores, typically 60 cycles on current ARM CPUs, plus side + // effects from barrier instructions. + std::lock_guard lock(mutex_); + // Avoid overrunning the stack, even in 'release' builds. This profiling + // instrumentation code should not ship in release builds anyway, the + // overhead of this check is negligible, and overrunning a stack array would + // be bad. + if (stack_.size >= Stack::kMaxSize) { + abort(); + } + stack_.labels[stack_.size++].Set(args...); + } + + // Pops the top-most label from this Stack. + void Pop() { + // See the comment in Push about this lock. While it would be tempting to + // try to remove this lock and just atomically decrement size_ with a + // store-release, that would not necessarily be a substitute for all of the + // purposes that this lock serves, or if it was done carefully to serve all + // of the same purposes, then that wouldn't be faster than this (mostly + // uncontended) lock. + std::lock_guard lock(mutex_); + stack_.size--; + } + + private: + mutable std::mutex mutex_; + Stack stack_; +}; + +} // namespace detail + +// RAII user-facing way to construct Labels associated with their life scope +// and get them pushed to / popped from the current thread stack. +class ScopeLabel { + public: + template + ScopeLabel(Args... args) : thread_stack_(detail::ThreadLocalThreadStack()) { + thread_stack_->Push(args...); + } + + ~ScopeLabel() { thread_stack_->Pop(); } + + private: + detail::ThreadStack* thread_stack_; +}; + +#else // no RUY_PROFILER + +class ScopeLabel { + public: + template + explicit ScopeLabel(Args...) {} + + // This destructor is needed to consistently silence clang's -Wunused-variable + // which seems to trigger semi-randomly. + ~ScopeLabel() {} +}; + +#endif + +} // namespace profiler +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_ diff --git a/ruy/profiler/profiler.cc b/ruy/profiler/profiler.cc new file mode 100644 index 0000000..ae3a2e2 --- /dev/null +++ b/ruy/profiler/profiler.cc @@ -0,0 +1,109 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/profiler/profiler.h" + +#ifdef RUY_PROFILER +#include +#include // NOLINT +#include +#include +#include // NOLINT +#include +#endif + +#include "ruy/profiler/instrumentation.h" +#include "ruy/profiler/treeview.h" + +namespace ruy { +namespace profiler { + +#ifdef RUY_PROFILER + +ScopeProfile::ScopeProfile() { Start(); } +ScopeProfile::ScopeProfile(bool enable) { + if (enable) { + Start(); + } +} +ScopeProfile::~ScopeProfile() { + if (!thread_) { + return; + } + finishing_.store(true); + thread_->join(); + Finish(); +} + +void ScopeProfile::Start() { + { + std::lock_guard lock(*detail::GlobalsMutex()); + if (detail::GlobalIsProfilerRunning()) { + fprintf(stderr, "FATAL: profiler already running!\n"); + abort(); + } + detail::GlobalIsProfilerRunning() = true; + } + finishing_ = false; + thread_.reset(new std::thread(&ScopeProfile::ThreadFunc, this)); +} + +void ScopeProfile::ThreadFunc() { + while (!finishing_.load()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + std::lock_guard lock(*detail::GlobalsMutex()); + auto* thread_stacks = detail::GlobalAllThreadStacks(); + for (detail::ThreadStack* thread_stack : *thread_stacks) { + Sample(*thread_stack); + } + } +} + +void ScopeProfile::Sample(const detail::ThreadStack& thread_stack) { + std::lock_guard lock(thread_stack.Mutex()); + // Drop empty stacks. + // This ensures that profiles aren't polluted by uninteresting threads. + if (thread_stack.stack().size == 0) { + return; + } + int sample_size = detail::GetBufferSize(thread_stack.stack()); + int old_buf_size = samples_buf_.size(); + samples_buf_.resize(old_buf_size + sample_size); + detail::CopyToBuffer(thread_stack.stack(), + samples_buf_.data() + old_buf_size); +} + +void ScopeProfile::Finish() { + { + std::lock_guard lock(*detail::GlobalsMutex()); + if (!detail::GlobalIsProfilerRunning()) { + fprintf(stderr, "FATAL: profiler is not running!\n"); + abort(); + } + detail::GlobalIsProfilerRunning() = false; + } + if (user_treeview_) { + user_treeview_->Populate(samples_buf_); + } else { + TreeView treeview; + treeview.Populate(samples_buf_); + Print(treeview); + } +} + +#endif // RUY_PROFILER + +} // namespace profiler +} // namespace ruy diff --git a/ruy/profiler/profiler.h b/ruy/profiler/profiler.h new file mode 100644 index 0000000..b68ca90 --- /dev/null +++ b/ruy/profiler/profiler.h @@ -0,0 +1,106 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_ + +#include + +#ifdef RUY_PROFILER +#include +#include +#include +#include +#endif + +#include "ruy/profiler/instrumentation.h" +#include "ruy/profiler/treeview.h" + +namespace ruy { +namespace profiler { + +#ifdef RUY_PROFILER + +// RAII user-facing way to create a profiler and let it profile a code scope, +// and print out an ASCII/MarkDown treeview upon leaving the scope. +class ScopeProfile { + public: + // Default constructor, unconditionally profiling. + ScopeProfile(); + + // Constructor allowing to choose at runtime whether to profile. + explicit ScopeProfile(bool enable); + + // Destructor. It's where the profile is reported. + ~ScopeProfile(); + + // See treeview_ member. + void SetUserTreeView(TreeView* treeview) { user_treeview_ = treeview; } + + private: + void Start(); + + // Thread entry point function for the profiler thread. This thread is + // created on construction. + void ThreadFunc(); + + // Record a stack as a sample. + void Sample(const detail::ThreadStack& stack); + + // Finalize the profile. Called on destruction. + // If user_treeview_ is non-null, it will receive the treeview. + // Otherwise the treeview will just be printed. + void Finish(); + + // Buffer where samples are recorded during profiling. + std::vector samples_buf_; + + // Used to synchronize thread termination. + std::atomic finishing_; + + // Underlying profiler thread, which will perform the sampling. + // This profiler approach relies on a thread rather than on signals. + std::unique_ptr thread_; + + // TreeView to populate upon destruction. If left null (the default), + // a temporary treeview will be used and dumped on stdout. The user + // may override that by passing their own TreeView object for other + // output options or to directly inspect the TreeView. + TreeView* user_treeview_ = nullptr; +}; + +#else // no RUY_PROFILER + +struct ScopeProfile { + ScopeProfile() { +#ifdef GEMMLOWP_PROFILING + fprintf( + stderr, + "\n\n\n**********\n\nWARNING:\n\nLooks like you defined " + "GEMMLOWP_PROFILING, but this code has been ported to the new ruy " + "profiler replacing the old gemmlowp profiler. You should now be " + "defining RUY_PROFILER and not GEMMLOWP_PROFILING. When building using " + "Bazel, just pass --define=ruy_profiler=true.\n\n**********\n\n\n"); +#endif + } + explicit ScopeProfile(bool) {} +}; + +#endif + +} // namespace profiler +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_ diff --git a/ruy/profiler/test.cc b/ruy/profiler/test.cc new file mode 100644 index 0000000..e94840b --- /dev/null +++ b/ruy/profiler/test.cc @@ -0,0 +1,167 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "testing/base/public/gunit.h" +#include "ruy/profiler/profiler.h" +#include "ruy/profiler/test_instrumented_library.h" +#include "ruy/profiler/treeview.h" + +namespace ruy { +namespace profiler { +namespace { + +void DoSomeMergeSort(int size) { + std::vector data(size); + + std::default_random_engine engine; + for (auto& val : data) { + val = engine(); + } + + MergeSort(size, data.data()); +} + +// The purpose of this basic test is to cover the basic path that will be taken +// by a majority of users, not inspecting treeviews but just implicitly printing +// them on stdout, and to have this test enabled even when RUY_PROFILER is not +// defined, so that we have coverage for the non-RUY_PROFILER case. +TEST(ProfilerTest, MergeSortSingleThreadBasicTestEvenWithoutProfiler) { + { + ScopeProfile profile; + DoSomeMergeSort(1 << 20); + } +} + +#ifdef RUY_PROFILER + +TEST(ProfilerTest, MergeSortSingleThread) { + TreeView treeview; + { + ScopeProfile profile; + profile.SetUserTreeView(&treeview); + DoSomeMergeSort(1 << 20); + } + Print(treeview); + EXPECT_EQ(treeview.thread_roots().size(), 1); + const auto& thread_root = *treeview.thread_roots().begin()->second; + EXPECT_EQ(DepthOfTreeBelow(thread_root), 22); + EXPECT_GE( + WeightBelowNodeMatchingUnformatted(thread_root, "Merging sorted halves"), + 0.1 * thread_root.weight); + EXPECT_GE(WeightBelowNodeMatchingFormatted( + thread_root, "MergeSortRecurse (level=20, size=1)"), + 0.01 * thread_root.weight); + + TreeView treeview_collapsed; + CollapseNodesMatchingUnformatted(treeview, 5, "MergeSort (size=%d)", + &treeview_collapsed); + Print(treeview_collapsed); + const auto& collapsed_thread_root = + *treeview_collapsed.thread_roots().begin()->second; + EXPECT_EQ(DepthOfTreeBelow(collapsed_thread_root), 6); + EXPECT_EQ( + WeightBelowNodeMatchingUnformatted(thread_root, "MergeSort (size=%d)"), + WeightBelowNodeMatchingUnformatted(collapsed_thread_root, + "MergeSort (size=%d)")); +} + +TEST(ProfilerTest, MemcpyFourThreads) { + TreeView treeview; + { + ScopeProfile profile; + profile.SetUserTreeView(&treeview); + std::vector> threads; + for (int i = 0; i < 4; i++) { + threads.emplace_back(new std::thread([i]() { + ScopeLabel thread_label("worker thread #%d", i); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + ScopeLabel some_more_work_label("some more work"); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + })); + } + for (int i = 0; i < 4; i++) { + threads[i]->join(); + } + } + Print(treeview); + // Since we cleared GlobalAllThreadStacks and the current thread hasn't + // created any ScopeLabel, only the 4 worker threads should be recorded. + EXPECT_EQ(treeview.thread_roots().size(), 4); + for (const auto& thread_root : treeview.thread_roots()) { + const TreeView::Node& root_node = *thread_root.second; + // The root node may have 1 or 2 children depending on whether there is + // an "[other]" child. + EXPECT_GE(root_node.children.size(), 1); + EXPECT_LE(root_node.children.size(), 2); + const TreeView::Node& child_node = *root_node.children[0]; + EXPECT_EQ(child_node.label.format(), "worker thread #%d"); + // There must be 2 children, since roughly half the time will be in + // "some more work" leaving the other half in "[other]". 
+ EXPECT_EQ(child_node.children.size(), 2); + const TreeView::Node& child_child_node = *child_node.children[0]; + // Since we sample every millisecond and the threads run for >= 2000 + // milliseconds, the "thread func" label should get roughly 2000 samples. + // Not very rigorous, as we're depending on the profiler thread getting + // scheduled, so to avoid this test being flaky, we use a much more + // conservative value of 500, one quarter of that normal value 2000. + EXPECT_GE(child_node.weight, 500); + // Likewise, allow up to four times more than the normal value 2000. + EXPECT_LE(child_node.weight, 8000); + // Roughly half of time should be spent under the "some more work" label. + float some_more_work_percentage = + 100.f * child_child_node.weight / child_node.weight; + EXPECT_GE(some_more_work_percentage, 40.0f); + EXPECT_LE(some_more_work_percentage, 60.0f); + } +} + +TEST(ProfilerTest, OneThreadAfterAnother) { + TreeView treeview; + { + ScopeProfile profile; + profile.SetUserTreeView(&treeview); + { + std::thread thread([]() { + ScopeLabel thread_label("thread 0"); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + }); + thread.join(); + } + { + std::thread thread([]() { + ScopeLabel thread_label("thread 1"); + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + }); + thread.join(); + } + } + Print(treeview); + EXPECT_EQ(treeview.thread_roots().size(), 2); +} + +#endif // RUY_PROFILER + +} // namespace +} // namespace profiler +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/profiler/test_instrumented_library.cc b/ruy/profiler/test_instrumented_library.cc new file mode 100644 index 0000000..b017ea9 --- /dev/null +++ b/ruy/profiler/test_instrumented_library.cc @@ -0,0 +1,59 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "ruy/profiler/instrumentation.h" + +namespace { + +void MergeSortRecurse(int level, int size, int* data, int* workspace) { + ruy::profiler::ScopeLabel function_label( + "MergeSortRecurse (level=%d, size=%d)", level, size); + if (size <= 1) { + return; + } + int half_size = size / 2; + MergeSortRecurse(level + 1, half_size, data, workspace); + MergeSortRecurse(level + 1, size - half_size, data + half_size, + workspace + half_size); + + ruy::profiler::ScopeLabel merging_sorted_halves_label( + "Merging sorted halves"); + int dst_index = 0; + int left_index = 0; + int right_index = half_size; + while (dst_index < size) { + int val; + if (left_index < half_size && + ((right_index >= size) || data[left_index] < data[right_index])) { + val = data[left_index++]; + } else { + val = data[right_index++]; + } + workspace[dst_index++] = val; + } + for (int i = 0; i < size; i++) { + data[i] = workspace[i]; + } +} + +} // namespace + +void MergeSort(int size, int* data) { + ruy::profiler::ScopeLabel function_label("MergeSort (size=%d)", size); + std::vector workspace(size); + MergeSortRecurse(0, size, data, workspace.data()); +} diff --git a/ruy/profiler/test_instrumented_library.h b/ruy/profiler/test_instrumented_library.h new file mode 100644 index 0000000..53d204e --- /dev/null +++ b/ruy/profiler/test_instrumented_library.h @@ -0,0 +1,23 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ + +#include "ruy/profiler/instrumentation.h" + +void MergeSort(int size, int* data); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_ diff --git a/ruy/profiler/treeview.cc b/ruy/profiler/treeview.cc new file mode 100644 index 0000000..48d922a --- /dev/null +++ b/ruy/profiler/treeview.cc @@ -0,0 +1,248 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifdef RUY_PROFILER + +#include "ruy/profiler/treeview.h" + +#include +#include +#include +#include +#include + +namespace ruy { +namespace profiler { + +namespace { + +void SortNode(TreeView::Node* node) { + using NodePtr = std::unique_ptr; + std::sort(node->children.begin(), node->children.end(), + [](const NodePtr& n1, const NodePtr& n2) { + return n1->weight > n2->weight; + }); + for (const auto& child : node->children) { + SortNode(child.get()); + } +} + +// Records a stack i.e. a sample in a treeview, by incrementing the weights +// of matching existing nodes and/or by creating new nodes as needed, +// recursively, below the given node. +void AddStack(const detail::Stack& stack, TreeView::Node* node, int level) { + node->weight++; + if (stack.size == level) { + return; + } + TreeView::Node* child_to_add_to = nullptr; + for (const auto& child : node->children) { + if (child->label == stack.labels[level]) { + child_to_add_to = child.get(); + break; + } + } + if (!child_to_add_to) { + child_to_add_to = node->children.emplace_back(new TreeView::Node).get(); + child_to_add_to->label = stack.labels[level]; + } + AddStack(stack, child_to_add_to, level + 1); +} + +// Recursively populates the treeview below the given node with 'other' +// entries documenting for each node the difference between its weight and the +// sum of its children's weight. +void AddOther(TreeView::Node* node) { + int top_level_children_weight = 0; + for (const auto& child : node->children) { + AddOther(child.get()); + top_level_children_weight += child->weight; + } + if (top_level_children_weight != 0 && + top_level_children_weight != node->weight) { + const auto& new_child = node->children.emplace_back(new TreeView::Node); + new_child->label = Label("[other]"); + new_child->weight = node->weight - top_level_children_weight; + } +} + +} // namespace + +void TreeView::Populate(const std::vector& samples_buf_) { + thread_roots_.clear(); + // Populate the treeview with regular nodes coming from samples. + const char* buf_ptr = samples_buf_.data(); + const char* const buf_ptr_end = buf_ptr + samples_buf_.size(); + while (buf_ptr < buf_ptr_end) { + detail::Stack stack; + detail::ReadFromBuffer(buf_ptr, &stack); + // Empty stacks should have been dropped during sampling. + assert(stack.size > 0); + buf_ptr += GetBufferSize(stack); + const int id = stack.id; + if (!thread_roots_[id]) { + thread_roots_[id].reset(new Node); + } + AddStack(stack, thread_roots_[id].get(), 0); + } + // Populate the treeview with additional 'other' nodes, sort, and set + // root labels. + for (const auto& thread_root : thread_roots_) { + std::uint32_t id = thread_root.first; + Node* root = thread_root.second.get(); + AddOther(root); + SortNode(root); + root->label.Set("Thread %x (%d samples)", id, root->weight); + } +} + +// Recursively prints the treeview below the given node. The 'root' node +// argument is only needed to compute weights ratios, with the root ratio +// as denominator. 
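To make the sample-accumulation logic above concrete, here is a small self-contained analogue of AddStack and of the "[other]" bookkeeping performed by AddOther, using a simplified node type instead of ruy's TreeView::Node (an illustrative sketch, not code from this patch):

#include <cstdio>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string label;
  int weight = 0;
  std::vector<std::unique_ptr<Node>> children;
};

// Increments weights along one sampled stack, creating child nodes as needed,
// mirroring AddStack above.
void AddStackSample(const std::vector<std::string>& stack, Node* node,
                    int level) {
  node->weight++;
  if (level == static_cast<int>(stack.size())) {
    return;
  }
  for (const auto& child : node->children) {
    if (child->label == stack[level]) {
      AddStackSample(stack, child.get(), level + 1);
      return;
    }
  }
  node->children.emplace_back(new Node);
  node->children.back()->label = stack[level];
  AddStackSample(stack, node->children.back().get(), level + 1);
}

int main() {
  Node root;
  root.label = "thread";
  // Three samples: two taken inside f() -> g(), one inside f() alone.
  AddStackSample({"f", "g"}, &root, 0);
  AddStackSample({"f", "g"}, &root, 0);
  AddStackSample({"f"}, &root, 0);
  // root.weight == 3, f.weight == 3, g.weight == 2. AddOther would append an
  // "[other]" child of weight 1 under f to account for the f-only sample.
  std::printf("root=%d f=%d g=%d\n", root.weight, root.children[0]->weight,
              root.children[0]->children[0]->weight);
  return 0;
}

AddOther and SortNode then run over the finished tree, which is why "[other]" entries appear only where a node's own weight exceeds the sum of its children's weights.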
+void PrintTreeBelow(const TreeView::Node& node, const TreeView::Node& root, + int level) { + if (&node == &root) { + printf("%s\n\n", node.label.Formatted().c_str()); + } else { + for (int i = 1; i < level; i++) { + printf(" "); + } + printf("* %.2f%% %s\n", 100.0f * node.weight / root.weight, + node.label.Formatted().c_str()); + } + for (const auto& child : node.children) { + PrintTreeBelow(*child, root, level + 1); + } +} + +void Print(const TreeView& treeview) { + printf("\n"); + printf("Profile (%d threads):\n\n", + static_cast(treeview.thread_roots().size())); + for (const auto& thread_root : treeview.thread_roots()) { + const TreeView::Node& root = *thread_root.second; + PrintTreeBelow(root, root, 0); + printf("\n"); + } +} + +int DepthOfTreeBelow(const TreeView::Node& node) { + if (node.children.empty()) { + return 0; + } else { + int max_child_depth = 0; + for (const auto& child : node.children) { + max_child_depth = std::max(max_child_depth, DepthOfTreeBelow(*child)); + } + return 1 + max_child_depth; + } +} + +int WeightBelowNodeMatchingFunction( + const TreeView::Node& node, + const std::function& match) { + int weight = 0; + if (match(node.label)) { + weight += node.weight; + } + for (const auto& child : node.children) { + weight += WeightBelowNodeMatchingFunction(*child, match); + } + return weight; +} + +int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node, + const std::string& format) { + return WeightBelowNodeMatchingFunction( + node, [&format](const Label& label) { return label.format() == format; }); +} + +int WeightBelowNodeMatchingFormatted(const TreeView::Node& node, + const std::string& formatted) { + return WeightBelowNodeMatchingFunction( + node, [&formatted](const Label& label) { + return label.Formatted() == formatted; + }); +} + +void CollapseNode(const TreeView::Node& node_in, int depth, + TreeView::Node* node_out) { + node_out->label = node_in.label; + node_out->weight = node_in.weight; + node_out->children.clear(); + if (depth > 0) { + for (const auto& child_in : node_in.children) { + auto* child_out = new TreeView::Node; + node_out->children.emplace_back(child_out); + CollapseNode(*child_in, depth - 1, child_out); + } + } +} + +void CollapseSubnodesMatchingFunction( + const TreeView::Node& node_in, int depth, + const std::function& match, TreeView::Node* node_out) { + if (match(node_in.label)) { + CollapseNode(node_in, depth, node_out); + } else { + node_out->label = node_in.label; + node_out->weight = node_in.weight; + node_out->children.clear(); + + for (const auto& child_in : node_in.children) { + auto* child_out = new TreeView::Node; + node_out->children.emplace_back(child_out); + CollapseSubnodesMatchingFunction(*child_in, depth, match, child_out); + } + } +} + +void CollapseNodesMatchingFunction( + const TreeView& treeview_in, int depth, + const std::function& match, TreeView* treeview_out) { + treeview_out->mutable_thread_roots()->clear(); + for (const auto& thread_root_in : treeview_in.thread_roots()) { + std::uint32_t id = thread_root_in.first; + const auto& root_in = *thread_root_in.second; + auto* root_out = new TreeView::Node; + treeview_out->mutable_thread_roots()->emplace(id, root_out); + CollapseSubnodesMatchingFunction(root_in, depth, match, root_out); + } +} + +void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth, + const std::string& format, + TreeView* treeview_out) { + CollapseNodesMatchingFunction( + treeview_in, depth, + [&format](const Label& label) { return label.format() == format; }, + 
treeview_out); +} + +void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth, + const std::string& formatted, + TreeView* treeview_out) { + CollapseNodesMatchingFunction( + treeview_in, depth, + [&formatted](const Label& label) { + return label.Formatted() == formatted; + }, + treeview_out); +} + +} // namespace profiler +} // namespace ruy + +#endif // RUY_PROFILER diff --git a/ruy/profiler/treeview.h b/ruy/profiler/treeview.h new file mode 100644 index 0000000..e34b4f9 --- /dev/null +++ b/ruy/profiler/treeview.h @@ -0,0 +1,130 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_ + +#ifdef RUY_PROFILER + +#include +#include +#include +#include + +#include "ruy/profiler/instrumentation.h" + +namespace ruy { +namespace profiler { + +// A tree view of a profile. +class TreeView { + public: + struct Node { + std::vector> children; + Label label; + int weight = 0; + }; + + void Populate(const std::vector& samples_buf_); + + // Intentionally an *ordered* map so that threads are enumerated + // in an order that's consistent and typically putting the 'main thread' + // first. + using ThreadRootsMap = std::map>; + + const ThreadRootsMap& thread_roots() const { return thread_roots_; } + ThreadRootsMap* mutable_thread_roots() { return &thread_roots_; } + + private: + ThreadRootsMap thread_roots_; +}; + +/* Below are API functions for manipulating and printing treeviews. */ + +// Prints the treeview to stdout. +void Print(const TreeView& treeview); + +// Prints the treeview below the given node on stdout. +void PrintTreeBelow(const TreeView::Node& node); + +// Returns the tree depth below the given node. +int DepthOfTreeBelow(const TreeView::Node& node); + +// Returns the sum of weights of nodes below the given node and filtered by +// the `match` predicate. +int WeightBelowNodeMatchingFunction( + const TreeView::Node& node, const std::function& match); + +// Returns the sum of weights of nodes below the given node and whose +// unformatted label (i.e. raw format string) matches the given `format` string. +// +// This allows to aggregate nodes whose labels differ only by parameter values. +int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node, + const std::string& format); + +// Returns the sum of weights of nodes below the given node and whose formatted +// label matches the `formatted` string. +// +// In the case of nodes with parametrized labels, this allows to count only +// nodes with specific parameter values. For that purpose, one may also instead +// use WeightBelowNodeMatchingFunction directly, with a `match` predicate +// comparing raw integer parameter values directly, instead of going through +// formatted strings. 
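As a usage sketch of the query functions declared above (assumptions: RUY_PROFILER is defined, the ruy/profiler targets from this change are linked, and the Work function here is hypothetical; weights are approximate because sampling is periodic):

#include <cstdio>

#include "ruy/profiler/instrumentation.h"
#include "ruy/profiler/profiler.h"
#include "ruy/profiler/treeview.h"

void Work(int size) {
  ruy::profiler::ScopeLabel label("Work (size=%d)", size);
  volatile double acc = 0;
  for (long i = 0; i < 100000000L * size; i++) {
    acc += i;
  }
}

int main() {
  ruy::profiler::TreeView treeview;
  {
    ruy::profiler::ScopeProfile profile;
    profile.SetUserTreeView(&treeview);
    Work(1);
    Work(2);
  }
  const auto& root = *treeview.thread_roots().begin()->second;
  // Matches the raw format string, so it aggregates both Work() calls.
  const int all = ruy::profiler::WeightBelowNodeMatchingUnformatted(
      root, "Work (size=%d)");
  // Matches the formatted label, so it counts only the size==2 call.
  const int only_size_2 = ruy::profiler::WeightBelowNodeMatchingFormatted(
      root, "Work (size=2)");
  std::printf("samples: all=%d size_2=%d\n", all, only_size_2);
  return 0;
}

WeightBelowNodeMatchingUnformatted aggregates every "Work (size=%d)" node regardless of the size argument, while WeightBelowNodeMatchingFormatted counts only nodes whose formatted label is exactly "Work (size=2)".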
+int WeightBelowNodeMatchingFormatted(const TreeView::Node& node, + const std::string& formatted); + +// Produces a `node_out` that is a copy of `node_in` but with tree depth below +// it clamped at `depth`, with further subtrees aggregated into single leaf +// nodes. +void CollapseNode(const TreeView::Node& node_in, int depth, + TreeView::Node* node_out); + +// Calls CollapseNode with the given `depth` on every subnode filtered by the +// `match` predicate. Note that this does NOT limit the tree depth below +// `node_out` to `depth`, since each collapsed node below `node_out` may be +// arbitrarily far below it and `depth` is only used as the collapsing depth +// at that point. +void CollapseSubnodesMatchingFunction( + const TreeView::Node& node_in, int depth, + const std::function& match, TreeView::Node* node_out); + +// Calls CollapseNode with the given `depth` on every node filtered by the +// `match` predicate. Note that this does NOT limit the tree depth below +// `node_out` to `depth`, since each collapsed node below `node_out` may be +// arbitrarily far below it and `depth` is only used as the collapsing depth +// at that point. +void CollapseNodesMatchingFunction( + const TreeView& treeview_in, int depth, + const std::function& match, TreeView* treeview_out); + +// Special case of CollapseNodesMatchingFunction matching unformatted labels, +// i.e. raw format strings. +// See the comment on WeightBelowNodeMatchingUnformatted. +void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth, + const std::string& format, + TreeView* treeview_out); + +// Special case of CollapseNodesMatchingFunction matching formatted labels. +// See the comment on WeightBelowNodeMatchingFormatted. +void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth, + const std::string& formatted, + TreeView* treeview_out); + +} // namespace profiler +} // namespace ruy + +#endif // RUY_PROFILER + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_ diff --git a/ruy/ruy.h b/ruy/ruy.h new file mode 100644 index 0000000..9cafe14 --- /dev/null +++ b/ruy/ruy.h @@ -0,0 +1,42 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the only Ruy header that users should #include. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_ + +#include "ruy/context.h" +#include "ruy/dispatch.h" +#include "ruy/matrix.h" +#include "ruy/path.h" +#include "ruy/spec.h" + +namespace ruy { + +// Performs a multiplication of matrices. This is Ruy's only API entry point. +// Should be self-explanatory given the above documentation for each of Matrix, +// Spec and Context. 
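A minimal end-to-end sketch of calling Mul, in the spirit of ruy/example.cc; MakeSimpleLayout (assumed from ruy/matrix.h) and kAllPaths (assumed from ruy/path.h) are not shown in this hunk, and the default BasicSpec gives a plain float matrix multiplication:

#include "ruy/ruy.h"

int main() {
  const float lhs_data[] = {1, 2, 3, 4};  // 2x2, row-major
  const float rhs_data[] = {1, 0, 0, 1};  // 2x2, column-major
  float dst_data[4];

  ruy::Context context;
  ruy::Matrix<float> lhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, &lhs.layout);
  lhs.data = lhs_data;
  ruy::Matrix<float> rhs;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &rhs.layout);
  rhs.data = rhs_data;
  ruy::Matrix<float> dst;
  ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, &dst.layout);
  dst.data = dst_data;

  ruy::BasicSpec<float, float> spec;
  ruy::Mul<ruy::kAllPaths>(lhs, rhs, spec, &context, &dst);
  return 0;
}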
+template +void Mul(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, Context* context, Matrix* dst) { + DispatchMul( + lhs, rhs, spec, context, dst); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_H_ diff --git a/ruy/ruy_advanced.h b/ruy/ruy_advanced.h new file mode 100644 index 0000000..124ddd2 --- /dev/null +++ b/ruy/ruy_advanced.h @@ -0,0 +1,69 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_ + +#include +#include + +#include "ruy/context.h" +#include "ruy/matrix.h" +#include "ruy/path.h" +#include "ruy/prepack.h" +#include "ruy/side_pair.h" + +namespace ruy { + +// Low-level, explicit pre-packing API. +// +// The cost of packing an input matrix (either the LHS or RHS) is amortized +// across the non-depth dimension of the opposite input matrix. Thus, when the +// LHS has very few rows or the RHS has very few columns, the cost of packing +// the opposite input matrix can become significant. See pack.h for further +// information on packing. +// +// This file provides an API allowing a user to explicitly pack a matrix and +// reuse the pre-packed matrix, avoiding that cost. +// +// See example_prepack.cc for example usage. + +template +void PrePackForMul(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, Context* context, Matrix* dst, + PrepackedMatrix* prepacked_lhs, + PrepackedMatrix* prepacked_rhs, + std::function alloc_fn) { + SidePair prepacked(prepacked_lhs, prepacked_rhs); + PrePackForMulInternal(lhs, rhs, spec, context, dst, prepacked, + alloc_fn); +} + +template +void MulWithPrepacked(const Matrix& lhs, + const Matrix& rhs, const Spec& spec, + Context* context, Matrix* dst, + PrepackedMatrix* prepacked_lhs, + PrepackedMatrix* prepacked_rhs) { + SidePair prepacked(prepacked_lhs, prepacked_rhs); + MulWithPrepackedInternal(lhs, rhs, spec, context, dst, + prepacked); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_RUY_ADVANCED_H_ diff --git a/ruy/ruy_test.bzl b/ruy/ruy_test.bzl new file mode 100644 index 0000000..ef7e8b1 --- /dev/null +++ b/ruy/ruy_test.bzl @@ -0,0 +1,34 @@ +# Provides the ruy_test macro for type-parametrized tests. 
+"""ruy_test is a macro for building a test with multiple paths corresponding to tuples of types for LHS, RHS, accumulator and destination.""" + +def ruy_test(name, srcs, lhs_rhs_accum_dst, copts, tags = [], deps = None): + for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst: + native.cc_test( + name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst), + srcs = srcs, + copts = copts + [ + "-DRUY_TEST_LHSSCALAR=%s" % lhs, + "-DRUY_TEST_RHSSCALAR=%s" % rhs, + "-DRUY_TEST_ACCUMSCALAR=%s" % accum, + "-DRUY_TEST_DSTSCALAR=%s" % dst, + ], + deps = deps, + tags = tags, + ) + +def ruy_benchmark(name, srcs, lhs_rhs_accum_dst, copts, deps = None): + tags = ["req_dep=//third_party/gemmlowp:profiler"] + for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst: + native.cc_binary( + name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst), + testonly = True, + srcs = srcs, + copts = copts + [ + "-DRUY_TEST_LHSSCALAR=%s" % lhs, + "-DRUY_TEST_RHSSCALAR=%s" % rhs, + "-DRUY_TEST_ACCUMSCALAR=%s" % accum, + "-DRUY_TEST_DSTSCALAR=%s" % dst, + ], + deps = deps, + tags = tags, + ) diff --git a/ruy/ruy_test_ext.bzl b/ruy/ruy_test_ext.bzl new file mode 100644 index 0000000..263121f --- /dev/null +++ b/ruy/ruy_test_ext.bzl @@ -0,0 +1,19 @@ +"""Allows to specialize the ruy BUILD to availability of external libraries""" + +def ruy_test_ext_defines(): + return select({ + "//tools/cc_target_os:windows": [], + "//tools/cc_target_os:wasm": [], + "//tools/cc_target_os:chromiumos": ["RUY_TESTING_ON_CHROMIUMOS"], + "//conditions:default": ["RUY_TEST_EXTERNAL_PATHS"], + }) + +def ruy_test_ext_deps(): + return select({ + "//tools/cc_target_os:windows": [], + "//conditions:default": [ + "//third_party/eigen3", + "//third_party/gemmlowp", + "//third_party/lapack:blas", + ], + }) diff --git a/ruy/ruy_test_ext.bzl.opensource b/ruy/ruy_test_ext.bzl.opensource new file mode 100644 index 0000000..5701fff --- /dev/null +++ b/ruy/ruy_test_ext.bzl.opensource @@ -0,0 +1,7 @@ +"""Allows to specialize the ruy BUILD to availability of external libraries""" + +def ruy_test_ext_defines(): + return [] + +def ruy_test_ext_deps(): + return [] diff --git a/ruy/side_pair.h b/ruy/side_pair.h new file mode 100644 index 0000000..e62968b --- /dev/null +++ b/ruy/side_pair.h @@ -0,0 +1,64 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_ + +#include "ruy/check_macros.h" + +namespace ruy { + +// Enumeration of the sides, i.e. the operands 'slots', in a matrix +// multiplication. The numerical values of these enumeration constants matter +// because these will be used as indices into the array underlying a SidePair. +enum class Side { + // Left-hand side + kLhs = 0, + // Right-hand side + kRhs = 1 +}; + +// SidePair is a pair container where the two elements are indexed by a Side +// enum. 
+template +class SidePair final { + public: + SidePair() {} + SidePair(const T& a, const T& b) : elem_{a, b} {} + const T& operator[](Side side) const { + const int index = static_cast(side); + // Technically this check is vacuous, since other values would be + // out-of-range for enum Side. + RUY_DCHECK(index == 0 || index == 1); + return elem_[index]; + } + + T& operator[](Side side) { + const int index = static_cast(side); + // Technically this check is vacuous, since other values would be + // out-of-range for enum Side. + RUY_DCHECK(index == 0 || index == 1); + return elem_[index]; + } + + private: + static_assert(static_cast(Side::kLhs) == 0, ""); + static_assert(static_cast(Side::kRhs) == 1, ""); + T elem_[2]; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIDE_PAIR_H_ diff --git a/ruy/size_util.h b/ruy/size_util.h new file mode 100644 index 0000000..2a4bdb9 --- /dev/null +++ b/ruy/size_util.h @@ -0,0 +1,93 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_ + +#include + +#include "ruy/check_macros.h" + +#ifdef _WIN32 +#include +#endif + +namespace ruy { + +template +inline Integer floor_log2(Integer n) { + static_assert(std::is_integral::value, ""); + static_assert(std::is_signed::value, ""); + static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, ""); + + RUY_DCHECK_GE(n, 1); +#ifdef _WIN32 + unsigned long result; // NOLINT[runtime/int] + if (sizeof(Integer) == 4) { + _BitScanReverse(&result, n); + } else { + _BitScanReverse64(&result, n); + } + return result; +#else + if (sizeof(Integer) == 4) { + return 31 - __builtin_clz(n); + } else { + return 63 - __builtin_clzll(n); + } +#endif +} + +template +Integer ceil_log2(Integer n) { + RUY_DCHECK_GE(n, 1); + return n == 1 ? 0 : floor_log2(n - 1) + 1; +} + +template +bool is_pot(Integer value) { + return (value > 0) && ((value & (value - 1)) == 0); +} + +template +Integer pot_log2(Integer n) { + RUY_DCHECK(is_pot(n)); + return floor_log2(n); +} + +template +Integer round_down_pot(Integer value) { + return static_cast(1) << floor_log2(value); +} + +template +Integer round_up_pot(Integer value) { + return static_cast(1) << ceil_log2(value); +} + +template +Integer round_down_pot(Integer value, Modulo modulo) { + RUY_DCHECK_EQ(modulo & (modulo - 1), 0); + return value & ~(modulo - 1); +} + +template +Integer round_up_pot(Integer value, Modulo modulo) { + return round_down_pot(value + modulo - 1, modulo); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SIZE_UTIL_H_ diff --git a/ruy/size_util_test.cc b/ruy/size_util_test.cc new file mode 100644 index 0000000..54f0c11 --- /dev/null +++ b/ruy/size_util_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. 
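A few worked values for the size_util.h helpers above (hypothetical snippet; the assertions simply restate the documented behavior):

#include <cassert>

#include "ruy/size_util.h"

int main() {
  assert(ruy::floor_log2(37) == 5);
  assert(ruy::ceil_log2(37) == 6);
  assert(ruy::is_pot(64) && !ruy::is_pot(37));
  assert(ruy::round_down_pot(37) == 32);
  assert(ruy::round_up_pot(37) == 64);
  // The two-argument forms round to a multiple of a power-of-two modulo,
  // not to a power of two: round_up_pot(37, 8) rounds up to 40.
  assert(ruy::round_down_pot(37, 8) == 32);
  assert(ruy::round_up_pot(37, 8) == 40);
  return 0;
}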
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/size_util.h" + +#include +#include +#include + +#include "testing/base/public/gunit.h" + +namespace ruy { +namespace { + +template +void SizeUtilTestValue(Integer value) { + if (value == 0) { + return; + } + + EXPECT_LE(0, floor_log2(value)); + EXPECT_LE(floor_log2(value), ceil_log2(value)); + EXPECT_LE(ceil_log2(value), 8 * sizeof(Integer)); + + if (is_pot(value)) { + EXPECT_EQ(floor_log2(value), ceil_log2(value)); + EXPECT_EQ(floor_log2(value), pot_log2(value)); + } else { + EXPECT_EQ(floor_log2(value) + 1, ceil_log2(value)); + } + EXPECT_EQ(value >> floor_log2(value), 1); + EXPECT_EQ(round_down_pot(value), static_cast(1) + << floor_log2(value)); + EXPECT_LE(round_down_pot(value), value); + EXPECT_GE(round_down_pot(value), value >> 1); + EXPECT_TRUE(is_pot(round_down_pot(value))); + + if (ceil_log2(value) < 8 * sizeof(Integer) - 1) { + EXPECT_EQ(value >> ceil_log2(value), is_pot(value) ? 1 : 0); + EXPECT_EQ(round_up_pot(value), static_cast(1) << ceil_log2(value)); + EXPECT_GE(round_up_pot(value), value); + EXPECT_LE(round_up_pot(value) >> 1, value); + EXPECT_TRUE(is_pot(round_up_pot(value))); + } + + for (std::uint8_t modulo : {1, 2, 8, 32, 128}) { + EXPECT_GE(value, round_down_pot(value, modulo)); + EXPECT_EQ(round_down_pot(value, modulo) % modulo, 0); + + if (value <= std::numeric_limits::max() - modulo) { + EXPECT_LE(value, round_up_pot(value, modulo)); + EXPECT_EQ(round_up_pot(value, modulo) % modulo, 0); + } + } +} + +template +void SizeUtilTest() { + for (int exponent = 0; exponent < 8 * sizeof(Integer) - 1; exponent++) { + const Integer pot = static_cast(1) << exponent; + SizeUtilTestValue(pot - 1); + SizeUtilTestValue(pot); + SizeUtilTestValue(pot + 1); + SizeUtilTestValue(pot + 12); + SizeUtilTestValue(pot + 123); + } + SizeUtilTestValue(std::numeric_limits::max() - 1); + SizeUtilTestValue(std::numeric_limits::max()); +} + +TEST(SizeUtilTest, Int) { SizeUtilTest(); } + +TEST(SizeUtilTest, Long) { SizeUtilTest(); } // NOLINT + +TEST(SizeUtilTest, LongLong) { SizeUtilTest(); } // NOLINT + +TEST(SizeUtilTest, Int32) { SizeUtilTest(); } + +TEST(SizeUtilTest, Int64) { SizeUtilTest(); } + +TEST(SizeUtilTest, Ptrdiff) { SizeUtilTest(); } + +} // namespace +} // namespace ruy + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/spec.h b/ruy/spec.h new file mode 100644 index 0000000..d96b6a9 --- /dev/null +++ b/ruy/spec.h @@ -0,0 +1,118 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_ + +#include +#include + +#include "ruy/cpu_cache_size.h" +#include "ruy/matrix.h" + +namespace ruy { + +// Our 'general' loop structure (the default) involves multi-threading and +// complicated loops aiming to optimize cache-friendliness. One may opt out of +// this and pick the 'simple' loop structure instead, which only performs well +// for small matrix sizes and only allows using one thread, in exchange for +// smaller code size. +enum class LoopStructure { kGeneral, kSimple, kAuto }; + +// In general we allow zero_point's to have any Scalar value. This is called +// 'asymmetric' quantization. We do take advantage of the optimization +// opportunities when zero_points happen at runtime to be 'symmetric' (e.g. the +// int8 value 0 or the uint8 value 128), but we still generate code to handle +// the general asymmetric case. By choosing kSymmetric here, one opts out of +// this and supports only the symmetric case, in exchange for smaller code size. +enum class ZeroPointSupport { kGeneral, kSymmetric }; + +// In general we allow all Layout's, even if we may use slow paths for some +// kinds of layouts. By choosing kRCC, one may opt out of this and +// only keep support for the simplest and most efficient combination of +// Layout's, in exchange for smaller code size. The case covered by +// kRCC is where the storage orders are exactly the following: +// - LHS is RowMajor +// - RHS is ColMajor +// - Destination is ColMajor +enum class LayoutSupport { kGeneral, kRCC }; + +// A Spec describes all about a matrix multiplication operation that isn't +// encoded in the LHS, RHS and destination matrices. Some of that information +// is encoded as compile-time constants and types (for instance, the choice +// of accumulator type, AccumScalar). Some of that information is encoded as +// runtime values (for instance, the optional bias vector). +template +struct BasicSpec { + // Accumulator type. The type of accumulators used to compute the dot-products + // before being ultimately casted to the destination type. + using AccumScalar = tAccumScalar; + // The destination scalar type. + using DstScalar = tDstScalar; + // The bias vector data, if not null. + const AccumScalar* bias = nullptr; + // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa) + // of the multiplier by which accumulators are multiplied before being casted + // to the destination type. + AccumScalar multiplier_fixedpoint = 0; + // Only for non-floating-point cases. The exponent part of the aforementioned + // multiplier. + int multiplier_exponent = 0; + // Per-channel variant of multiplier_fixedpoint. If not nullptr, this must + // point to a buffer of as many values as there are rows in the destination + // matrix. Each row of the destination matrix will use the corresponding + // buffer element instead of multiplier_fixedpoint. 
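For the integer-quantized case, here is a sketch of how a real-valued rescaling factor might be decomposed into the multiplier_fixedpoint / multiplier_exponent pair; this assumes ruy follows the usual gemmlowp/TFLite convention where multiplier_fixedpoint is a Q31 mantissa in [2^30, 2^31) and the effective multiplier is (multiplier_fixedpoint / 2^31) * 2^multiplier_exponent:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decomposes a positive real multiplier into a Q31 fixed-point mantissa in
// [2^30, 2^31) and a power-of-two exponent, so that
// real ~= (fixedpoint / 2^31) * 2^exponent.
void DecomposeMultiplier(double real, std::int32_t* fixedpoint, int* exponent) {
  const double mantissa = std::frexp(real, exponent);  // mantissa in [0.5, 1)
  std::int64_t q =
      static_cast<std::int64_t>(std::round(mantissa * (1ll << 31)));
  if (q == (1ll << 31)) {  // rounding overflowed to 2^31: renormalize
    q /= 2;
    ++*exponent;
  }
  *fixedpoint = static_cast<std::int32_t>(q);
}

int main() {
  std::int32_t fixedpoint = 0;
  int exponent = 0;
  DecomposeMultiplier(0.0123, &fixedpoint, &exponent);
  // These would populate BasicSpec::multiplier_fixedpoint / _exponent.
  std::printf("fixedpoint=%d exponent=%d\n", static_cast<int>(fixedpoint),
              exponent);
  return 0;
}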
+ const AccumScalar* multiplier_fixedpoint_perchannel = nullptr; + // Per-channel variant of multiplier_exponent. If not nullptr, this must + // point to a buffer of as many values as there are rows in the destination + // matrix. Each row of the destination matrix will use the corresponding + // buffer element instead of multiplier_exponent. + // + // Either none or both of multiplier_exponent_perchannel and + // multiplier_fixedpoint_perchannel must be nullptr. + const int* multiplier_exponent_perchannel = nullptr; + // min clamp bound of destination values. + DstScalar clamp_min = std::is_floating_point::value + ? -std::numeric_limits::infinity() + : std::numeric_limits::lowest(); + // max clamp bound of destination values. + DstScalar clamp_max = std::is_floating_point::value + ? std::numeric_limits::infinity() + : std::numeric_limits::max(); + // See above enum LoopStructure + static constexpr LoopStructure kLoopStructure = LoopStructure::kAuto; + // See above enum LayoutSupport + static constexpr LayoutSupport kLayoutSupport = LayoutSupport::kGeneral; + // See above enum ZeroPointSupport + static constexpr ZeroPointSupport kZeroPointSupport = + ZeroPointSupport::kGeneral; + // Testing-only, not meant to be used by actual users: + // Used for testing of various kernel layouts. + using StandardCppKernelLhsLayout = FixedKernelLayout; + using StandardCppKernelRhsLayout = FixedKernelLayout; + // Returns (a reasonable estimate of) the local CPU cache size. + // See ruy::LocalDataCacheSize() which returns some coarse, sane default for + // each CPU architecture. + // This may be overridden, either to provide more accurate/runtime values, + // or to test with other values to let testcases have more coverage. + static int local_data_cache_size() { return LocalDataCacheSize(); } + // Same as local_data_cache_size but for the total data cache size accessible + // to each CPU core. See ruy::SharedDataCacheSize(). + static int shared_data_cache_size() { return SharedDataCacheSize(); } +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_SPEC_H_ diff --git a/ruy/test.h b/ruy/test.h new file mode 100644 index 0000000..649a0d9 --- /dev/null +++ b/ruy/test.h @@ -0,0 +1,2125 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "testing/base/public/gunit.h" // IWYU pragma: export +#include "ruy/matrix.h" // IWYU pragma: export +#include "ruy/platform.h" +#include "ruy/pmu.h" +#include "ruy/ruy.h" +#include "ruy/ruy_advanced.h" +#include "ruy/spec.h" // IWYU pragma: export +#include "ruy/time.h" + +#ifdef RUY_TEST_EXTERNAL_PATHS +#define EIGEN_USE_THREADS +#define EIGEN_USE_CUSTOM_THREAD_POOL +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "third_party/gemmlowp/public/gemmlowp.h" +#include "third_party/lapack/blas.h" +#endif + +#ifdef RUY_PROFILER +#include "ruy/profiler/profiler.h" +#endif + +namespace ruy { + +const float kClampRatio = 0.1f; + +enum class ExternalPath { kNone, kGemmlowp, kEigen, kEigenTensor, kOpenBlas }; + +inline std::vector* CoveredPaths() { + static std::vector covered_paths; + return &covered_paths; +} + +inline const char* PathName(Path path) { +#define RUY_PATHNAME_CASE(NAME) \ + case Path::NAME: \ + return #NAME; + switch (path) { + RUY_PATHNAME_CASE(kReference) + RUY_PATHNAME_CASE(kStandardCpp) +#if RUY_PLATFORM(NEON) + RUY_PATHNAME_CASE(kNeon) + RUY_PATHNAME_CASE(kNeonDotprod) +#elif RUY_PLATFORM(X86) + RUY_PATHNAME_CASE(kSse42) + RUY_PATHNAME_CASE(kAvx2) + RUY_PATHNAME_CASE(kAvx512) + RUY_PATHNAME_CASE(kAvxVnni) +#endif + default: + RUY_CHECK(false); + return nullptr; + } +#undef RUY_PATHNAME_CASE +} + +inline const char* TuningName(Tuning tuning) { +#define RUY_SUBPATHNAME_CASE(NAME) \ + case Tuning::NAME: \ + return #NAME; + switch (tuning) { + RUY_SUBPATHNAME_CASE(kInOrder) + RUY_SUBPATHNAME_CASE(kOutOfOrder) + default: + RUY_CHECK(false); + return nullptr; + } +#undef RUY_SUBPATHNAME_CASE +} + +inline const char* PathName(ExternalPath path) { +#define RUY_PATHNAME_CASE(NAME) \ + case ExternalPath::NAME: \ + return #NAME; + switch (path) { + RUY_PATHNAME_CASE(kGemmlowp) + RUY_PATHNAME_CASE(kEigen) + RUY_PATHNAME_CASE(kEigenTensor) + RUY_PATHNAME_CASE(kOpenBlas) + default: + RUY_CHECK(false); + return nullptr; + } +#undef RUY_PATHNAME_CASE +} + +inline std::ostream& operator<<(std::ostream& stream, Path path) { + return stream << PathName(path); +} + +inline std::ostream& operator<<(std::ostream& stream, + ExternalPath external_path) { + return stream << PathName(external_path); +} + +template +std::string Join(const ContainerType& container) { + if (container.empty()) { + return ""; + } + std::ostringstream stream; + auto it = container.begin(); + stream << *it++; + for (; it != container.end(); ++it) { + stream << ", "; + stream << *it; + } + return stream.str(); +} + +struct LogCoveredPathsOnDestruction final { + ~LogCoveredPathsOnDestruction() { + std::cerr << "Covered paths: " << Join(*CoveredPaths()) << std::endl; + + // When testing on ARM64 ChromiumOS emulator, make sure that we covered + // the dotprod path. We're getting such coverage at the moment thanks to + // using a sufficiently recent emulator, and we don't want to regress that. 
+#if RUY_PLATFORM(ARM_64) && defined RUY_TESTING_ON_CHROMIUMOS + bool found_dotprod = false; + for (const std::string& covered_path : *CoveredPaths()) { + if (covered_path == "kNeonDotprod") { + found_dotprod = true; + } + } + if (!found_dotprod) { + std::cerr + << "Error: we haven't tested the kNeonDotprod path as we should " + "have. At the moment, this is required on ChromiumOS as this is " + "what we run emulator tests in, that currently supports " + "dot-product " + "instructions, and we care very much about not regressing that. " + "If this test was run in an emulator, please upgrade to a newer " + "emulator version. If this test was run on an actual device, and " + "you need to be able to run ruy tests on devices not supporting " + "dot-product instructions, get in touch with us.\n" + << std::endl; + abort(); + } +#endif + } + static void Singleton() { static LogCoveredPathsOnDestruction singleton; } +}; + +enum class RandomRange { + kGeneral, + kAvoidMinValue, + kOffCenterAvoidMinValue, + kReasonableSrcZeroPoint, + kReasonableDstZeroPoint, + kBias +}; + +template ::value> +struct RandomRangeBounds {}; + +template +struct RandomRangeBounds { + static Scalar GetMinBound(RandomRange range) { + switch (range) { + case RandomRange::kGeneral: + return -1; + case RandomRange::kAvoidMinValue: + return -1; + case RandomRange::kOffCenterAvoidMinValue: + return -1; + case RandomRange::kReasonableSrcZeroPoint: + return 0; + case RandomRange::kReasonableDstZeroPoint: + return 0; + case RandomRange::kBias: + return -1; + default: + RUY_CHECK(false); + return 0; + } + } + static Scalar GetMaxBound(RandomRange range) { + switch (range) { + case RandomRange::kGeneral: + return 1; + case RandomRange::kAvoidMinValue: + return 1; + case RandomRange::kOffCenterAvoidMinValue: + return 1; + case RandomRange::kReasonableSrcZeroPoint: + return 0; + case RandomRange::kReasonableDstZeroPoint: + return 0; + case RandomRange::kBias: + return 1; + default: + RUY_CHECK(false); + return 0; + } + } +}; + +template +Scalar WeightedSum(Scalar s1, float weight1, Scalar s2, float weight2) { + float sum = s1 * weight1 + s2 * weight2; + float clamped = std::min( + std::numeric_limits::max(), + std::max(std::numeric_limits::lowest(), sum)); + return static_cast(clamped); +} + +template +Scalar Parametrized(float param) { + return WeightedSum(std::numeric_limits::max(), param, + std::numeric_limits::lowest(), 1 - param); +} + +template +struct RandomRangeBounds { + static Scalar GetMinBound(RandomRange range) { + static constexpr double offcenteredness = + 0.02; // Shift lower limit by about 5 for range of 255. + switch (range) { + case RandomRange::kGeneral: + return std::numeric_limits::lowest(); + case RandomRange::kAvoidMinValue: + return 1 + std::numeric_limits::lowest(); + case RandomRange::kOffCenterAvoidMinValue: + return 1 + std::numeric_limits::lowest() + + static_cast( + offcenteredness * std::numeric_limits::max() - + offcenteredness * + (std::numeric_limits::lowest() + 1)); + case RandomRange::kReasonableSrcZeroPoint: + return std::numeric_limits::lowest(); + case RandomRange::kReasonableDstZeroPoint: + return Parametrized(0.4); + case RandomRange::kBias: + return std::is_same::value + ? 
static_cast(-10000) + : 0; + default: + RUY_CHECK(false); + return 0; + } + } + static Scalar GetMaxBound(RandomRange range) { + switch (range) { + case RandomRange::kGeneral: + return std::numeric_limits::max(); + case RandomRange::kAvoidMinValue: + return std::numeric_limits::max(); + case RandomRange::kOffCenterAvoidMinValue: + return std::numeric_limits::max(); + case RandomRange::kReasonableSrcZeroPoint: + return std::numeric_limits::max(); + case RandomRange::kReasonableDstZeroPoint: + return Parametrized(0.6); + case RandomRange::kBias: + return std::is_same::value + ? static_cast(10000) + : 0; + default: + RUY_CHECK(false); + return 0; + } + } +}; + +inline std::default_random_engine& global_random_engine() { + static std::default_random_engine engine; + return engine; +} + +template +struct UniformRandomDistribution { + UniformRandomDistribution(RandomRange range) + : dist(RandomRangeBounds::GetMinBound(range), + RandomRangeBounds::GetMaxBound(range)) {} + Scalar Get() { return dist(global_random_engine()); } + // std::uniform_int_distribution is specified not to support char types, + // only short and wider types. MSVC actually generates an error on + // std::uniform_int_distribution. + using StdDistType = typename std::conditional< + std::is_floating_point::value, + std::uniform_real_distribution, + std::uniform_int_distribution>::type; + StdDistType dist; +}; + +template +void MakeRandomScalar(UniformRandomDistribution* uniform_dist, + Scalar* dst) { + *dst = uniform_dist->Get(); +} + +template +void MakeRandomVector(UniformRandomDistribution* uniform_dist, int size, + std::vector* dst) { + dst->resize(size); + for (auto& x : *dst) { + MakeRandomScalar(uniform_dist, &x); + } +} + +template +void MakeRandomScalar(RandomRange range, Scalar* dst) { + UniformRandomDistribution dist(range); + *dst = dist.Get(); + if (range == RandomRange::kReasonableDstZeroPoint || + range == RandomRange::kReasonableSrcZeroPoint) { + if (global_random_engine()() & 1) { + *dst = SymmetricZeroPoint(); + } + } +} + +template +void MakeRandomVector(RandomRange range, int size, std::vector* dst) { + UniformRandomDistribution dist(range); + dst->resize(size); + for (auto& x : *dst) { + MakeRandomScalar(&dist, &x); + } +} + +enum class LayoutStyle { kPackedLinear, kLinear }; + +inline void MakeLayout(int rows, int cols, Order order, + LayoutStyle layout_style, Layout* layout) { + layout->rows = rows; + layout->cols = cols; + layout->order = order; + + const int packed_stride = order == Order::kColMajor ? 
rows : cols; + + RUY_CHECK(layout_style == LayoutStyle::kPackedLinear || + layout_style == LayoutStyle::kLinear); + if (layout_style == LayoutStyle::kPackedLinear) { + layout->stride = packed_stride; + } else { + layout->stride = packed_stride + 1; + } +} + +template +struct StorageMatrix { + StorageMatrix() = default; + StorageMatrix(const StorageMatrix&) = delete; + void operator=(const StorageMatrix&) = delete; + std::vector data; + Matrix matrix; +}; + +template +void VerifyConsistentFields(const StorageMatrix& storage_matrix) { + if (storage_matrix.data.empty()) { + RUY_CHECK_EQ(storage_matrix.matrix.data.get(), nullptr); + RUY_CHECK_EQ(storage_matrix.matrix.layout.rows, 0); + RUY_CHECK_EQ(storage_matrix.matrix.layout.cols, 0); + } else { + RUY_CHECK_EQ(storage_matrix.matrix.data.get(), storage_matrix.data.data()); + RUY_CHECK_EQ(FlatSize(storage_matrix.matrix.layout), + storage_matrix.data.size()); + } +} + +template +void MakeRandom(int rows, int cols, Order order, Scalar zero_point, + LayoutStyle layout_style, RandomRange range, + StorageMatrix* storage_matrix) { + MakeLayout(rows, cols, order, layout_style, &storage_matrix->matrix.layout); + storage_matrix->matrix.zero_point = zero_point; + UniformRandomDistribution data_dist(range); + MakeRandomVector(&data_dist, FlatSize(storage_matrix->matrix.layout), + &storage_matrix->data); + storage_matrix->matrix.data = storage_matrix->data.data(); + VerifyConsistentFields(*storage_matrix); +} + +template +struct TestResult { + void operator=(const TestResult&) = delete; + void operator=(const TestResult&&) = delete; + StorageMatrix storage_matrix; + Path path = Path::kNone; + Tuning tuning = Tuning::kAuto; + ExternalPath external_path = ExternalPath::kNone; + float latency; + float l1_refill_rate; + float l2_refill_rate; + float l3_refill_rate; + float l1tlb_refill_rate; + float l2tlb_refill_rate; + float mispred_rate; + float frontend_stall_rate; + float backend_stall_rate; + + // Per-path data for pre-packing. + // This is not used by external paths or by Path::kReference. 
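A self-contained sketch of the distribution-type workaround described a little earlier in this header: std::uniform_int_distribution is not defined for char-sized types, so narrow scalars are drawn through a wider distribution and cast (names here are hypothetical, not ruy's):

#include <cstdint>
#include <cstdio>
#include <random>
#include <type_traits>

template <typename Scalar>
Scalar DrawUniform(std::default_random_engine& engine, Scalar lo, Scalar hi) {
  // Floating-point scalars use a real distribution; all integer scalars,
  // including int8/uint8, go through a 32-bit integer distribution.
  using DistType = typename std::conditional<
      std::is_floating_point<Scalar>::value,
      std::uniform_real_distribution<Scalar>,
      std::uniform_int_distribution<std::int32_t>>::type;
  DistType dist(lo, hi);
  return static_cast<Scalar>(dist(engine));
}

int main() {
  std::default_random_engine engine;
  std::int8_t v = DrawUniform<std::int8_t>(engine, -128, 127);
  float f = DrawUniform<float>(engine, -1.0f, 1.0f);
  std::printf("%d %f\n", v, f);
  return 0;
}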
+ Allocator allocator; + PrepackedMatrix prepacked_lhs; + PrepackedMatrix prepacked_rhs; + bool use_prepacked_lhs = false; + bool use_prepacked_rhs = false; +}; + +template +std::string PathName(const TestResult& result) { + std::string pathname; + if (result.path != Path::kNone) { + pathname.assign(PathName(result.path)); + } else if (result.external_path != ExternalPath::kNone) { + pathname.assign(PathName(result.external_path)); + } else { + RUY_CHECK(false); + } + if (result.tuning != Tuning::kAuto) { + pathname.append("/"); + pathname.append(TuningName(result.tuning)); + } + return pathname; +} + +enum class ExpectedOutcome { kSuccess, kDeath }; + +template +struct TestSet final { + using LhsScalar = tLhsScalar; + using RhsScalar = tRhsScalar; + using AccumScalar = typename SpecType::AccumScalar; + using DstScalar = typename SpecType::DstScalar; + using Spec = SpecType; + using TestResultType = TestResult; + + void Run() { + MakeZeroPoints(); + MakeLhsRhs(); + MakeSpec(); + MakeOtherParams(); + MakeResultPaths(); + MakePrepackedMatrices(); + Eval(); + Verify(); + } + + private: + void MakeZeroPoints(); + void MakeLhsRhs(); + void MakeSpec(); + void MakeResultPaths(); + void MakePrepackedMatrices(); + void MakeOtherParams(); + void EvalAndVerify(); + void Eval(); + void Verify(); + + void EvalResult(TestResultType* result); + void EvalRuy(TestResultType* result); + void DoMul(TestResultType* result); + void Benchmark(TestResultType* result); + void VerifyTestResults() const; + + public: + enum class LifeStage { + kInitial, + kHasZeroPoints, + kHasLhsRhs, + kHasSpec, + kHasOtherParams, + kHasResultPaths, + kHasPrepackedMatrices, + kEvaluated, + kFinal + }; + + ~TestSet() { + RUY_CHECK_EQ(life_stage, LifeStage::kFinal); + LogCoveredPathsOnDestruction::Singleton(); + } + + LifeStage life_stage = LifeStage::kInitial; + + int rows = 0; + int cols = 0; + int depth = 0; + Order lhs_order = Order::kRowMajor; + Order rhs_order = Order::kColMajor; + Order dst_order = Order::kColMajor; + LayoutStyle layout_style = LayoutStyle::kPackedLinear; + ExpectedOutcome expected_outcome = ExpectedOutcome::kSuccess; + + bool use_specified_zero_points = false; + LhsScalar lhs_zero_point = 0; + RhsScalar rhs_zero_point = 0; + DstScalar dst_zero_point = 0; + + std::vector per_channel_multiplier_fixedpoint; + std::vector per_channel_multiplier_exponent; + + StorageMatrix lhs; + StorageMatrix rhs; + Spec spec; + std::vector bias_data; + std::vector> results; + + std::vector paths; + std::vector external_paths; + + bool benchmark = false; + bool perchannel = false; + int max_num_threads = 0; + bool benchmark_prepack_lhs = false; + bool benchmark_prepack_rhs = false; +}; + +inline PmuEvents& GlobalPmuEvents() { + static PmuEvents pmu; + return pmu; +} + +inline Context& GlobalContext() { + // Ensure that GlobalPmuEvents is constructed before we create any context. + // This ensures that pmu counters are opened before we create any worker + // thread, which is necessary to count events from worker threads. 
+ GlobalPmuEvents(); + + static Context context; + return context; +} + +#if defined(__has_feature) +#if __has_feature(thread_sanitizer) +#define RUY_TSAN +#endif +#if __has_feature(address_sanitizer) +#define RUY_ASAN +#endif +#endif // defined(__has_feature) + +template +void TestSet::DoMul(TestResultType* result) { + Context* context = &GlobalContext(); + + if (!result->use_prepacked_lhs && !result->use_prepacked_rhs) { + Mul(lhs.matrix, rhs.matrix, spec, context, + &result->storage_matrix.matrix); + return; + } + + // If we prepacked an input matrix, null out its data pointer to check + // that we don't access any data through it. + Matrix null_data_lhs = lhs.matrix; + Matrix null_data_rhs = rhs.matrix; + if (result->use_prepacked_lhs) { + null_data_lhs.data = nullptr; + } + if (result->use_prepacked_rhs) { + null_data_rhs.data = nullptr; + } + + // Do the multiplication with pre-packed matrices. + PrepackedMatrix* prepacked_lhs_ptr = + result->use_prepacked_lhs ? &result->prepacked_lhs : nullptr; + PrepackedMatrix* prepacked_rhs_ptr = + result->use_prepacked_rhs ? &result->prepacked_rhs : nullptr; + MulWithPrepacked(null_data_lhs, null_data_rhs, spec, context, + &result->storage_matrix.matrix, prepacked_lhs_ptr, + prepacked_rhs_ptr); +} + +// When building for WAsm, ASSERT_DEATH is not defined. +#ifdef ASSERT_DEATH +#define RUY_ASSERT_DEATH(CONDITION, MESSAGE) ASSERT_DEATH(CONDITION, MESSAGE) +#else +#define RUY_ASSERT_DEATH(CONDITION, MESSAGE) +#endif + +template +void TestSet::EvalRuy(TestResultType* result) { + GlobalContext().explicit_tuning = result->tuning; + if (max_num_threads) { + GlobalContext().max_num_threads = max_num_threads; + } else if (benchmark) { + GlobalContext().max_num_threads = 1; + } else { + GlobalContext().max_num_threads = 1 + global_random_engine()() % 8; + } + GlobalContext().SetRuntimeEnabledPaths(result->path); + if (expected_outcome == ExpectedOutcome::kSuccess) { + DoMul(result); + RUY_CHECK_EQ(GlobalContext().last_taken_path, result->path); + } else if (expected_outcome == ExpectedOutcome::kDeath) { + // TODO(benoitjacob) TSan and ASan seem to be breaking ASSERT_DEATH. + // Report a bug? +#if (!defined NDEBUG) && (!defined RUY_ASAN) && (!defined RUY_TSAN) + RUY_ASSERT_DEATH(DoMul(result), ""); +#endif + } else { + RUY_CHECK(false); + } + GlobalContext().explicit_tuning = Tuning::kAuto; + GlobalContext().max_num_threads = 1; +} + +#ifdef RUY_TEST_EXTERNAL_PATHS + +template +void WrapGemmlowp(const Matrix& src, + gemmlowp::MatrixMap* dst) { + RUY_CHECK(src.layout.order == (tOrder == gemmlowp::MapOrder::ColMajor + ? Order::kColMajor + : Order::kRowMajor)); + *dst = gemmlowp::MatrixMap( + src.data.get(), src.layout.rows, src.layout.cols, src.layout.stride); +} + +template +void WrapGemmlowpMutable(Matrix* src, + gemmlowp::MatrixMap* dst) { + RUY_CHECK(src->layout.order == (tOrder == gemmlowp::MapOrder::ColMajor + ? 
Order::kColMajor + : Order::kRowMajor)); + *dst = gemmlowp::MatrixMap( + src->data.get(), src->layout.rows, src->layout.cols, src->layout.stride); +} + +template +struct GemmlowpOrder {}; + +template <> +struct GemmlowpOrder { + static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::ColMajor; +}; + +template <> +struct GemmlowpOrder { + static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::RowMajor; +}; + +inline gemmlowp::GemmContext& GlobalGemmlowpContext() { + static gemmlowp::GemmContext context; + return context; +} + +template +void EvalGemmlowp(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, int max_num_threads, + Matrix* dst) { + static constexpr gemmlowp::MapOrder kGemmlowpLhsOrder = + GemmlowpOrder::kValue; + static constexpr gemmlowp::MapOrder kGemmlowpRhsOrder = + GemmlowpOrder::kValue; + static constexpr gemmlowp::MapOrder kGemmlowpDstOrder = + GemmlowpOrder::kValue; + gemmlowp::MatrixMap gemmlowp_lhs; + gemmlowp::MatrixMap gemmlowp_rhs; + gemmlowp::MatrixMap gemmlowp_dst; + WrapGemmlowp(lhs, &gemmlowp_lhs); + WrapGemmlowp(rhs, &gemmlowp_rhs); + WrapGemmlowpMutable(dst, &gemmlowp_dst); + + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage; + quantize_down_stage.result_offset_after_shift = dst->zero_point; + quantize_down_stage.result_fixedpoint_multiplier = spec.multiplier_fixedpoint; + quantize_down_stage.result_exponent = spec.multiplier_exponent; + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC< + gemmlowp::VectorShape::Col> + quantize_down_stage_pc; + quantize_down_stage_pc.result_offset_after_shift = dst->zero_point; + using ColVectorMap = + gemmlowp::VectorMap; + quantize_down_stage_pc.result_fixedpoint_multiplier = + ColVectorMap(spec.multiplier_fixedpoint_perchannel, lhs.layout.rows); + quantize_down_stage_pc.result_exponent = + ColVectorMap(spec.multiplier_exponent_perchannel, lhs.layout.rows); + + gemmlowp::OutputStageClamp clamp_stage; + clamp_stage.min = spec.clamp_min; + clamp_stage.max = spec.clamp_max; + using OutputStageSaturatingCast = typename std::conditional< + std::is_same::value, + gemmlowp::OutputStageSaturatingCastToUint8, + gemmlowp::OutputStageSaturatingCastToInt16>::type; + OutputStageSaturatingCast saturating_cast_stage; + + GlobalGemmlowpContext().set_max_num_threads(max_num_threads ? 
max_num_threads + : 1); + if (spec.bias) { + using ColVectorMap = + gemmlowp::VectorMap; + gemmlowp::OutputStageBiasAddition bias_add_stage; + bias_add_stage.bias_vector = ColVectorMap(spec.bias, dst->layout.rows); +#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE + if (spec.multiplier_exponent_perchannel) { + const auto& output_pipeline = + std::make_tuple(bias_add_stage, quantize_down_stage_pc, clamp_stage, + saturating_cast_stage); + gemmlowp::GemmWithOutputPipeline< + LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, + -lhs.zero_point, -rhs.zero_point, output_pipeline); + } else // NOLINT[readability/braces] +#endif + { + const auto& output_pipeline = + std::make_tuple(bias_add_stage, quantize_down_stage, clamp_stage, + saturating_cast_stage); + gemmlowp::GemmWithOutputPipeline< + LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, + -lhs.zero_point, -rhs.zero_point, output_pipeline); + } + } else { +#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE + if (spec.multiplier_exponent_perchannel) { + const auto& output_pipeline = std::make_tuple( + quantize_down_stage_pc, clamp_stage, saturating_cast_stage); + gemmlowp::GemmWithOutputPipeline< + LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, + -lhs.zero_point, -rhs.zero_point, output_pipeline); + } else // NOLINT[readability/braces] +#endif + { + const auto& output_pipeline = std::make_tuple( + quantize_down_stage, clamp_stage, saturating_cast_stage); + gemmlowp::GemmWithOutputPipeline< + LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, + -lhs.zero_point, -rhs.zero_point, output_pipeline); + } + } +} + +inline constexpr int Mash(Order LhsOrder, Order RhsOrder, Order DstOrder) { + return (LhsOrder == Order::kRowMajor ? 4 : 0) + + (RhsOrder == Order::kRowMajor ? 2 : 0) + + (DstOrder == Order::kRowMajor ? 
1 : 0); +} + +template +void EvalGemmlowp(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, int max_num_threads, + Matrix* dst) { + int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order); + switch (index) { +#define EVALGEMMLOWP_CASE3(LHS, RHS, DST) \ + case Mash(LHS, RHS, DST): \ + return EvalGemmlowp(lhs, rhs, spec, max_num_threads, dst); +#define EVALGEMMLOWP_CASE2(LHS, RHS) \ + EVALGEMMLOWP_CASE3(LHS, RHS, Order::kColMajor) \ + EVALGEMMLOWP_CASE3(LHS, RHS, Order::kRowMajor) +#define EVALGEMMLOWP_CASE1(LHS) \ + EVALGEMMLOWP_CASE2(LHS, Order::kColMajor) \ + EVALGEMMLOWP_CASE2(LHS, Order::kRowMajor) + + EVALGEMMLOWP_CASE1(Order::kColMajor) + EVALGEMMLOWP_CASE1(Order::kRowMajor) + +#undef EVALGEMMLOWP_CASE1 +#undef EVALGEMMLOWP_CASE2 +#undef EVALGEMMLOWP_CASE3 + + default: + RUY_CHECK(false); + } +} + +template +struct EigenOrder {}; + +template <> +struct EigenOrder { + static constexpr int kValue = Eigen::ColMajor; +}; + +template <> +struct EigenOrder { + static constexpr int kValue = Eigen::RowMajor; +}; + +template +void EvalEigen(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, int max_num_threads, Matrix* dst) { + RUY_CHECK_EQ(lhs.zero_point, 0); + RUY_CHECK_EQ(rhs.zero_point, 0); + RUY_CHECK_EQ(dst->zero_point, 0); + RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0); + RUY_CHECK_EQ(spec.multiplier_exponent, 0); + + static constexpr int kEigenLhsOrder = EigenOrder::kValue; + static constexpr int kEigenRhsOrder = EigenOrder::kValue; + static constexpr int kEigenDstOrder = EigenOrder::kValue; + + using EigenLhsType = typename Eigen::Matrix:: + template StridedConstMapType>::type; + using EigenRhsType = typename Eigen::Matrix:: + template StridedConstMapType>::type; + using EigenDstType = typename Eigen::Matrix:: + template StridedMapType>::type; + using EigenBiasType = + typename Eigen::Matrix::ConstMapType; + + EigenLhsType eigen_lhs(lhs.data.get(), lhs.layout.rows, lhs.layout.cols, + Eigen::OuterStride(lhs.layout.stride)); + EigenRhsType eigen_rhs(rhs.data.get(), rhs.layout.rows, rhs.layout.cols, + Eigen::OuterStride(rhs.layout.stride)); + EigenDstType eigen_dst( + dst->data.get(), dst->layout.rows, dst->layout.cols, + Eigen::OuterStride(dst->layout.stride)); + Eigen::setNbThreads(max_num_threads ? 
max_num_threads : 1); + + if (spec.bias) { + EigenBiasType eigen_bias(spec.bias, dst->layout.rows); + if (spec.clamp_max == std::numeric_limits::infinity() && + spec.clamp_min == -std::numeric_limits::infinity()) { + eigen_dst.noalias() = (eigen_lhs * eigen_rhs).colwise() + eigen_bias; + } else { + eigen_dst.noalias() = ((eigen_lhs * eigen_rhs).colwise() + eigen_bias) + .cwiseMin(spec.clamp_max) + .cwiseMax(spec.clamp_min); + } + } else { + if (spec.clamp_max == std::numeric_limits::infinity() && + spec.clamp_min == -std::numeric_limits::infinity()) { + eigen_dst.noalias() = eigen_lhs * eigen_rhs; + } else { + eigen_dst.noalias() = (eigen_lhs * eigen_rhs) + .cwiseMin(spec.clamp_max) + .cwiseMax(spec.clamp_min); + } + } +} + +template +void EvalEigen(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, int max_num_threads, Matrix* dst) { + int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order); + switch (index) { +#define EVALEIGEN_CASE3(LHS, RHS, DST) \ + case Mash(LHS, RHS, DST): \ + return EvalEigen(lhs, rhs, spec, max_num_threads, dst); +#define EVALEIGEN_CASE2(LHS, RHS) \ + EVALEIGEN_CASE3(LHS, RHS, Order::kColMajor) \ + EVALEIGEN_CASE3(LHS, RHS, Order::kRowMajor) +#define EVALEIGEN_CASE1(LHS) \ + EVALEIGEN_CASE2(LHS, Order::kColMajor) \ + EVALEIGEN_CASE2(LHS, Order::kRowMajor) + + EVALEIGEN_CASE1(Order::kColMajor) + EVALEIGEN_CASE1(Order::kRowMajor) + +#undef EVALEIGEN_CASE1 +#undef EVALEIGEN_CASE2 +#undef EVALEIGEN_CASE3 + + default: + RUY_CHECK(false); + } +} + +template +void EvalEigenTensor(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, int max_num_threads, + Matrix* dst) { + RUY_CHECK_EQ(lhs.zero_point, 0); + RUY_CHECK_EQ(rhs.zero_point, 0); + RUY_CHECK_EQ(dst->zero_point, 0); + RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0); + RUY_CHECK_EQ(spec.multiplier_exponent, 0); + + // Eigen::TensorMap only supports packed layouts + RUY_CHECK(IsPacked(lhs.layout)); + RUY_CHECK(IsPacked(rhs.layout)); + RUY_CHECK(IsPacked(dst->layout)); + + using TensorLhsType = + Eigen::TensorMap>; + using TensorRhsType = + Eigen::TensorMap>; + using TensorDstType = + Eigen::TensorMap>; + using TensorBiasType = + Eigen::TensorMap>; + + const bool tr = DstOrder == Order::kRowMajor; + const auto& contract_lhs = tr ? rhs : lhs; + const auto& contract_rhs = tr ? lhs : rhs; + + TensorLhsType tensor_lhs( + contract_lhs.data.get(), + LhsOrder == Order::kColMajor ? contract_lhs.layout.rows + : contract_lhs.layout.cols, + LhsOrder == Order::kColMajor ? contract_lhs.layout.cols + : contract_lhs.layout.rows); + TensorRhsType tensor_rhs( + contract_rhs.data.get(), + RhsOrder == Order::kColMajor ? contract_rhs.layout.rows + : contract_rhs.layout.cols, + RhsOrder == Order::kColMajor ? contract_rhs.layout.cols + : contract_rhs.layout.rows); + TensorDstType tensor_dst( + dst->data.get(), + DstOrder == Order::kColMajor ? dst->layout.rows : dst->layout.cols, + DstOrder == Order::kColMajor ? dst->layout.cols : dst->layout.rows); + using DimPair = + typename Eigen::Tensor::DimensionPair; + Eigen::array contract_dims( + {DimPair((LhsOrder == Order::kColMajor) ? 1 : 0, + (RhsOrder == Order::kColMajor) ? 0 : 1)}); + Eigen::array shuffle(DstOrder == Order::kColMajor ? 0 : 1, + DstOrder == Order::kColMajor ? 1 : 0); + static Eigen::ThreadPool pool(max_num_threads ? max_num_threads : 1); + static Eigen::ThreadPoolDevice device(&pool, pool.NumThreads()); + if (spec.bias) { + TensorBiasType tensor_bias(spec.bias, dst->layout.rows); + Eigen::array bias_2d_shape(tr ? 1 : dst->layout.rows, + tr ? 
dst->layout.rows : 1); + Eigen::array bcast(tr ? dst->layout.cols : 1, + tr ? 1 : dst->layout.cols); + if (spec.clamp_max == std::numeric_limits::infinity() && + spec.clamp_min == -std::numeric_limits::infinity()) { + tensor_dst.device(device) = + tensor_lhs.contract(tensor_rhs, contract_dims); + } else { + tensor_dst.device(device) = + (tensor_lhs.contract(tensor_rhs, contract_dims) + + tensor_bias.reshape(bias_2d_shape).broadcast(bcast)) + .cwiseMin(spec.clamp_max) + .cwiseMax(spec.clamp_min); + } + } else { + if (spec.clamp_max == std::numeric_limits::infinity() && + spec.clamp_min == -std::numeric_limits::infinity()) { + tensor_dst.device(device) = + tensor_lhs.contract(tensor_rhs, contract_dims); + } else { + tensor_dst.device(device) = tensor_lhs.contract(tensor_rhs, contract_dims) + .cwiseMin(spec.clamp_max) + .cwiseMax(spec.clamp_min); + } + } +} + +template +void EvalEigenTensor(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, int max_num_threads, + Matrix* dst) { + int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order); + switch (index) { +#define EVALEIGENTENSOR_CASE3(LHS, RHS, DST) \ + case Mash(LHS, RHS, DST): \ + return EvalEigenTensor(lhs, rhs, spec, max_num_threads, dst); +#define EVALEIGENTENSOR_CASE2(LHS, RHS) \ + EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kColMajor) \ + EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kRowMajor) +#define EVALEIGENTENSOR_CASE1(LHS) \ + EVALEIGENTENSOR_CASE2(LHS, Order::kColMajor) \ + EVALEIGENTENSOR_CASE2(LHS, Order::kRowMajor) + + EVALEIGENTENSOR_CASE1(Order::kColMajor) + EVALEIGENTENSOR_CASE1(Order::kRowMajor) + +#undef EVALEIGENTENSOR_CASE1 +#undef EVALEIGENTENSOR_CASE2 +#undef EVALEIGENTENSOR_CASE3 + + default: + RUY_CHECK(false); + } +} + +template +struct GenericBlasGemm {}; + +template <> +struct GenericBlasGemm { + static void Run(char* transa, char* transb, lapack::integer* m, + lapack::integer* n, lapack::integer* k, + lapack::doublereal* alpha, lapack::doublereal* a, + lapack::integer* lda, lapack::doublereal* b, + lapack::integer* ldb, lapack::doublereal* beta, + lapack::doublereal* c, lapack::integer* ldc) { + dgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } +}; + +template <> +struct GenericBlasGemm { + static void Run(char* transa, char* transb, lapack::integer* m, + lapack::integer* n, lapack::integer* k, lapack::real* alpha, + lapack::real* a, lapack::integer* lda, lapack::real* b, + lapack::integer* ldb, lapack::real* beta, lapack::real* c, + lapack::integer* ldc) { + sgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + } +}; + +template +void EvalOpenBlas(const Matrix& lhs, const Matrix& rhs, + const Spec& spec, int max_num_threads, Matrix* dst) { + RUY_CHECK_EQ(lhs.zero_point, 0); + RUY_CHECK_EQ(rhs.zero_point, 0); + RUY_CHECK_EQ(dst->zero_point, 0); + RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0); + RUY_CHECK_EQ(spec.multiplier_exponent, 0); + + Matrix gemm_lhs; + Matrix gemm_rhs; + Matrix gemm_dst; + gemm_dst = *dst; + + // Use Transpose to reduce to the all-column-major case. + // Notice that ruy::Matrix merely holds a pointer, does not own data, + // so Transpose is cheap -- no actual matrix data is being transposed here. 
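+  // For illustration, a minimal sketch of what such a view-level Transpose
+  // amounts to (an assumed sketch for exposition, not necessarily ruy's
+  // actual helper):
+  //
+  //   template <typename Scalar>
+  //   void Transpose(Matrix<Scalar>* m) {
+  //     m->layout.order = (m->layout.order == Order::kColMajor)
+  //                           ? Order::kRowMajor
+  //                           : Order::kColMajor;
+  //     std::swap(m->layout.rows, m->layout.cols);
+  //   }
+  //
+  // Only layout metadata changes; the underlying buffer is untouched.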
+ if (dst->layout.order == Order::kColMajor) { + gemm_lhs = lhs; + gemm_rhs = rhs; + } else { + gemm_lhs = rhs; + gemm_rhs = lhs; + Transpose(&gemm_lhs); + Transpose(&gemm_rhs); + Transpose(&gemm_dst); + } + bool transposed_lhs = false; + bool transposed_rhs = false; + + if (gemm_lhs.layout.order == Order::kRowMajor) { + Transpose(&gemm_lhs); + transposed_lhs = true; + } + if (gemm_rhs.layout.order == Order::kRowMajor) { + Transpose(&gemm_rhs); + transposed_rhs = true; + } + + RUY_CHECK_EQ(gemm_lhs.layout.order, Order::kColMajor); + RUY_CHECK_EQ(gemm_rhs.layout.order, Order::kColMajor); + RUY_CHECK_EQ(gemm_dst.layout.order, Order::kColMajor); + + char transa = transposed_lhs ? 'T' : 'N'; + char transb = transposed_rhs ? 'T' : 'N'; + int m = gemm_lhs.layout.rows; + int n = gemm_rhs.layout.cols; + int k = gemm_lhs.layout.cols; + float alpha = 1; + Scalar* a = gemm_lhs.data.get(); + int lda = gemm_lhs.layout.stride; + Scalar* b = gemm_rhs.data.get(); + int ldb = gemm_rhs.layout.stride; + float beta = 0; + Scalar* c = gemm_dst.data.get(); + int ldc = gemm_dst.layout.stride; + GenericBlasGemm::Run(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, + &ldb, &beta, c, &ldc); + + // BLAS does not allow us to express the bias-addition and clamping, so + // we use Eigen for that. + + using EigenDstType = + typename Eigen::Matrix:: + template StridedMapType>::type; + using EigenBiasType = + typename Eigen::Matrix::ConstMapType; + + EigenDstType eigen_dst( + gemm_dst.data.get(), gemm_dst.layout.rows, gemm_dst.layout.cols, + Eigen::OuterStride(gemm_dst.layout.stride)); + Eigen::setNbThreads(max_num_threads ? max_num_threads : 1); + + if (spec.bias) { + EigenBiasType eigen_bias(spec.bias, dst->layout.rows); + if (spec.clamp_max == std::numeric_limits::infinity() && + spec.clamp_min == -std::numeric_limits::infinity()) { + eigen_dst.noalias() = eigen_dst.colwise() + eigen_bias; + } else { + eigen_dst.noalias() = (eigen_dst.colwise() + eigen_bias) + .cwiseMin(spec.clamp_max) + .cwiseMax(spec.clamp_min); + } + } else { + if (spec.clamp_max == std::numeric_limits::infinity() && + spec.clamp_min == -std::numeric_limits::infinity()) { + } else { + eigen_dst.noalias() = + eigen_dst.cwiseMin(spec.clamp_max).cwiseMax(spec.clamp_min); + } + } +} + +template +struct SupportsGemmlowp { + static constexpr bool kValue = + std::is_same::value && + std::is_same::value; +}; + +template +struct UsesSingleScalarType { + static constexpr bool kValue = + std::is_same::value && + std::is_same::value && + std::is_same::value; +}; + +template ::value, + bool EnableGemmlowp = SupportsGemmlowp::kValue, + bool SingleScalarType = UsesSingleScalarType::kValue> +struct EvalExternalPathImpl { + using DstScalar = typename TestSetType::DstScalar; + static void Run(TestSetType*, TestResult*) { RUY_CHECK(false); } +}; + +template +struct EvalExternalPathImpl { + using DstScalar = typename TestSetType::DstScalar; + static void Run(TestSetType* test_set, TestResult* test_result) { + if (test_result->external_path == ExternalPath::kEigen) { + EvalEigen(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec, + test_set->max_num_threads, &test_result->storage_matrix.matrix); + } else if (test_result->external_path == ExternalPath::kEigenTensor) { + EvalEigenTensor(test_set->lhs.matrix, test_set->rhs.matrix, + test_set->spec, test_set->max_num_threads, + &test_result->storage_matrix.matrix); + } else if (test_result->external_path == ExternalPath::kOpenBlas) { + EvalOpenBlas(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec, + 
test_set->max_num_threads, + &test_result->storage_matrix.matrix); + } else { + RUY_CHECK(false); + } + } +}; + +template +struct EvalExternalPathImpl { + using DstScalar = typename TestSetType::DstScalar; + static void Run(TestSetType* test_set, TestResult* test_result) { + if (test_result->external_path == ExternalPath::kGemmlowp) { + EvalGemmlowp(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec, + test_set->max_num_threads, + &test_result->storage_matrix.matrix); + } else { + RUY_CHECK(false); + } + } +}; + +template +void EvalExternalPath( + TestSetType* test_set, + TestResult* test_result) { + EvalExternalPathImpl::Run(test_set, test_result); +} + +#endif // RUY_TEST_EXTERNAL_PATHS + +template +bool Agree(const Matrix& matrix1, const Matrix& matrix2, + int depth) { + RUY_CHECK_EQ(matrix1.layout.rows, matrix2.layout.rows); + RUY_CHECK_EQ(matrix1.layout.cols, matrix2.layout.cols); + RUY_CHECK_EQ(matrix1.zero_point, matrix2.zero_point); + const int size = matrix1.layout.rows * matrix1.layout.cols; + double tolerated_max_diff = 0; + double tolerated_mean_diff = 0; + if (std::is_floating_point::value) { + // TODO: replace hardcoded 100 by something more sensible, probably + // roughly sqrt(depth) based on central limit theorem. + double max_abs_val = 0; + for (int row = 0; row < matrix1.layout.rows; row++) { + for (int col = 0; col < matrix1.layout.cols; col++) { + max_abs_val = + std::max(max_abs_val, + std::abs(static_cast(Element(matrix1, row, col)))); + max_abs_val = + std::max(max_abs_val, + std::abs(static_cast(Element(matrix2, row, col)))); + } + } + tolerated_max_diff = max_abs_val * std::numeric_limits::epsilon() * + 64 * std::sqrt(static_cast(depth)); + tolerated_mean_diff = tolerated_max_diff / std::sqrt(size); + } else if (RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)) { + tolerated_max_diff = 1; + // totally empirical + tolerated_mean_diff = std::min(1.0, 2.0 * std::pow(size, -0.2)); + } + double sum_diff = 0; + for (int row = 0; row < matrix1.layout.rows; row++) { + for (int col = 0; col < matrix1.layout.cols; col++) { + double elem1 = Element(matrix1, row, col); + double elem2 = Element(matrix2, row, col); + double diff = elem1 - elem2; + + sum_diff += diff; + // Test (std::abs(diff) > tolerated_max_diff), but also true if diff is + // NaN. 
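+      // (IEEE comparisons involving NaN are false, so the negated form below
+      // is also true when diff is NaN, whereas a plain
+      // 'diff > tolerated_max_diff' test would let a NaN result pass.)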
+ if (!(std::abs(diff) <= tolerated_max_diff)) { + return false; + } + } + } + double mean_diff = sum_diff / size; + if (std::abs(mean_diff) > tolerated_mean_diff) { + return false; + } + return true; +} + +template +bool Agree(const StorageMatrix& storage_matrix1, + const StorageMatrix& storage_matrix2, int depth) { + VerifyConsistentFields(storage_matrix1); + VerifyConsistentFields(storage_matrix2); + return Agree(storage_matrix1.matrix, storage_matrix2.matrix, depth); +} + +template +bool Agree(const TestResult& result1, const TestResult& result2, + int depth) { + return Agree(result1.storage_matrix, result2.storage_matrix, depth); +} + +struct Stats { + double median; + double mean; + double min; + double max; +}; + +inline std::string StatsAsString(const Stats& stats) { + char buf[256]; + snprintf(buf, sizeof(buf), "(median = %g, mean = %g, min = %g, max = %g)", + stats.median, stats.mean, stats.min, stats.max); + return std::string(buf); +} + +template +void GetMatrixStats(const Matrix& matrix, Stats* stats) { + double min = std::numeric_limits::infinity(); + double max = -std::numeric_limits::infinity(); + double sum = 0; + std::vector allvals; + for (int row = 0; row < matrix.layout.rows; row++) { + for (int col = 0; col < matrix.layout.cols; col++) { + double val = Element(matrix, row, col); + min = std::min(min, val); + max = std::max(max, val); + sum += val; + allvals.push_back(val); + } + } + std::sort(allvals.begin(), allvals.end()); + stats->min = min; + stats->max = max; + stats->mean = sum / allvals.size(); + stats->median = allvals[allvals.size() / 2]; +} + +struct ErrorAnalysis { + Stats stats_good; + Stats stats_bad; + // The below is to help document departure from bit exactness. It's probably + // not going to be relevant to floating-point. 
+ std::set error_rows; + std::set error_cols; + int row_of_first_error = 0; + int col_of_first_error = 0; + double first_error_good_value = 0; + double first_error_bad_value = 0; +}; + +template +void AnalyzeTestError(const TestSetType& test_set, int first_bad_result_index, + ErrorAnalysis* error_analysis) { + const auto& good_matrix = test_set.results[0]->storage_matrix.matrix; + const auto& bad_matrix = + test_set.results[first_bad_result_index]->storage_matrix.matrix; + GetMatrixStats(good_matrix, &error_analysis->stats_good); + GetMatrixStats(bad_matrix, &error_analysis->stats_bad); + bool found_first_error = false; + for (int row = 0; row < good_matrix.layout.rows; row++) { + for (int col = 0; col < good_matrix.layout.cols; col++) { + if (Element(good_matrix, row, col) != Element(bad_matrix, row, col)) { + if (!found_first_error) { + found_first_error = true; + error_analysis->row_of_first_error = row; + error_analysis->col_of_first_error = col; + error_analysis->first_error_good_value = + Element(good_matrix, row, col); + error_analysis->first_error_bad_value = Element(bad_matrix, row, col); + } + error_analysis->error_rows.insert(row); + error_analysis->error_cols.insert(col); + } + } + } +} + +template +void ComputeReasonableMultiplier( + const Matrix& lhs, + const Matrix& rhs, double* multiplier) { + using LhsScalar = typename TestSetType::LhsScalar; + using RhsScalar = typename TestSetType::RhsScalar; + using DstScalar = typename TestSetType::DstScalar; + if (std::is_floating_point::value || + std::is_same::value) { + *multiplier = 0; + return; + } + *multiplier = static_cast(std::numeric_limits::max()) / + (static_cast(lhs.layout.cols) * + std::numeric_limits::max() * + std::numeric_limits::max()); +} + +inline void QuantizeMultiplier(double multiplier_double, + std::int32_t* multiplier_fixedpoint, + int* multiplier_exponent) { + RUY_CHECK_GT(multiplier_double, 0); + if (multiplier_double == 0.) { + *multiplier_fixedpoint = 0; + *multiplier_exponent = 0; + return; + } + const double q = std::frexp(multiplier_double, multiplier_exponent); + auto q_fixed = static_cast(std::round(q * (1ll << 31))); + RUY_CHECK_LE(q_fixed, (1ll << 31)); + if (q_fixed == (1ll << 31)) { + q_fixed /= 2; + ++*multiplier_exponent; + } + RUY_CHECK_LE(q_fixed, std::numeric_limits::max()); + *multiplier_fixedpoint = static_cast(q_fixed); +} + +template +void SwitchMultiplierToPerChannel(TestSetType* test_set) { + test_set->per_channel_multiplier_fixedpoint.resize(test_set->rows); + test_set->per_channel_multiplier_exponent.resize(test_set->rows); + for (int i = 0; i < test_set->rows; i++) { + // multipliers typically range in [2^30 ; 2^31 - 1]. + // Values in [0, 2^30 - 1] are normally unused, but harmless. + // Thus a good way to randomize multipliers is to subtract from them + // a random value smaller than 2^30 but still significant compared to it. 
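+    // (Concretely: subtracting a value below 2^26 from a multiplier of at
+    // least 2^30 perturbs it by at most 2^26 / 2^30 = 1/16, i.e. roughly 6%,
+    // enough to exercise genuinely different per-channel multipliers without
+    // distorting their overall magnitude.)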
+ std::int32_t nudged_multiplier = test_set->spec.multiplier_fixedpoint - + (global_random_engine()() % (1 << 26)); + int nudged_exponent = + test_set->spec.multiplier_exponent - 1 + (global_random_engine()() % 4); + test_set->per_channel_multiplier_fixedpoint[i] = nudged_multiplier; + test_set->per_channel_multiplier_exponent[i] = nudged_exponent; + } + test_set->spec.multiplier_fixedpoint_perchannel = + test_set->per_channel_multiplier_fixedpoint.data(); + test_set->spec.multiplier_exponent_perchannel = + test_set->per_channel_multiplier_exponent.data(); + test_set->spec.multiplier_fixedpoint = 0; + test_set->spec.multiplier_exponent = 0; +} + +template < + typename TestSetType, + bool IsApplicable = + std::is_same::value && + !std::is_same::value> +struct MakeSpecMultiplierFieldsImpl {}; + +template +struct MakeSpecMultiplierFieldsImpl { + static void Run(TestSetType* test_set) { + double multiplier; + ComputeReasonableMultiplier(test_set->lhs.matrix, + test_set->rhs.matrix, &multiplier); + QuantizeMultiplier(multiplier, &test_set->spec.multiplier_fixedpoint, + &test_set->spec.multiplier_exponent); + if (!test_set->benchmark) { + test_set->perchannel = global_random_engine()() & 1; + } + if (test_set->perchannel) { + SwitchMultiplierToPerChannel(test_set); + } + } +}; + +template +struct MakeSpecMultiplierFieldsImpl { + static void Run(TestSetType* test_set) { + test_set->spec.multiplier_fixedpoint = 0; + test_set->spec.multiplier_exponent = 0; + } +}; + +template +void MakeSpecClampFields(Spec* spec) { + using AccumScalar = typename Spec::AccumScalar; + using DstScalar = typename Spec::DstScalar; + + if (std::is_same::value) { + // Returning raw accumulators, clamping is not supported. + spec->clamp_min = std::numeric_limits::lowest(); + spec->clamp_max = std::numeric_limits::max(); + return; + } + + if (getenv("BENCHMARK_ONLY_MATMUL")) { + if (std::is_floating_point::value) { + spec->clamp_min = -std::numeric_limits::infinity(); + spec->clamp_max = std::numeric_limits::infinity(); + } else { + spec->clamp_min = std::numeric_limits::lowest(); + spec->clamp_max = std::numeric_limits::max(); + } + return; + } + + spec->clamp_min = std::numeric_limits::lowest() + 1; + spec->clamp_max = std::numeric_limits::max() - 1; +} + +template +void TestSet::MakeZeroPoints() { + RUY_CHECK_EQ(life_stage, LifeStage::kInitial); + if (!benchmark && !use_specified_zero_points) { + MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &lhs_zero_point); + MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &rhs_zero_point); + // If destination is std::int32_t, no dst_zero_point is necessary. 
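+    // (With a std::int32_t destination, raw accumulators are returned without
+    // requantization -- see MakeSpecClampFields above -- so there is no
+    // output zero point to apply.)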
+ if (std::is_same::value) { + dst_zero_point = 0; + } else { + MakeRandomScalar(RandomRange::kReasonableDstZeroPoint, &dst_zero_point); + } + } + life_stage = LifeStage::kHasZeroPoints; +} + +template +void TestSet::MakeLhsRhs() { + RUY_CHECK_EQ(life_stage, LifeStage::kHasZeroPoints); + MakeRandom(rows, depth, lhs_order, lhs_zero_point, layout_style, + RandomRange::kOffCenterAvoidMinValue, &lhs); + MakeRandom(depth, cols, rhs_order, rhs_zero_point, layout_style, + RandomRange::kGeneral, &rhs); + life_stage = LifeStage::kHasLhsRhs; +} + +template +void TestSet::MakeSpec() { + RUY_CHECK_EQ(life_stage, LifeStage::kHasLhsRhs); + + if (!getenv("BENCHMARK_ONLY_MATMUL") && + (benchmark || (global_random_engine()() & 1))) { + MakeRandomVector(RandomRange::kBias, rows, &bias_data); + spec.bias = bias_data.data(); + } + if (lhs.matrix.zero_point == std::numeric_limits::lowest() && + rhs.matrix.zero_point == std::numeric_limits::lowest()) { + lhs.matrix.zero_point += 1; + } + MakeSpecMultiplierFieldsImpl::Run(this); + MakeSpecClampFields(&spec); + life_stage = LifeStage::kHasSpec; +} + +inline int GetIntEnvVarOrZero(const char* name) { + const char* val = getenv(name); + if (!val) { + return 0; + } + return std::stoi(val); +} + +inline float GetFloatEnvVarOrZero(const char* name) { + const char* val = getenv(name); + if (!val) { + return 0; + } + return std::stof(val); +} + +inline int GetHexIntEnvVarOrZero(const char* name) { + const char* val = getenv(name); + if (!val) { + return 0; + } + return std::stoi(val, nullptr, 16); +} + +inline bool GetBoolEnvVarOrFalse(const char* name) { + return static_cast(GetIntEnvVarOrZero(name)); +} + +template +void TestSet::MakeOtherParams() { + RUY_CHECK_EQ(life_stage, LifeStage::kHasSpec); + if (max_num_threads == 0) { + max_num_threads = GetIntEnvVarOrZero("THREADS"); + } + life_stage = LifeStage::kHasOtherParams; +} + +inline std::vector PathsBitfieldAsVector(Path paths_bitfield) { + std::vector result; + std::uint32_t remaining_paths = static_cast(paths_bitfield); + std::uint32_t test_bit = 1; + while (remaining_paths) { + if (remaining_paths & test_bit) { + result.push_back(static_cast(test_bit)); + } + remaining_paths &= ~test_bit; + test_bit <<= 1; + } + return result; +} + +inline std::vector EnumerateTuningsForPath(Path path, bool benchmark) { + if (benchmark) { + return {Tuning::kAuto}; + } +#if RUY_PLATFORM(ARM) + if (path == Path::kNeon || path == Path::kNeonDotprod) { + return {Tuning::kInOrder, Tuning::kOutOfOrder, Tuning::kAuto}; + } +#endif + return {Tuning::kAuto}; +} + +template +void TestSet::MakePrepackedMatrices() { + RUY_CHECK_EQ(life_stage, LifeStage::kHasResultPaths); + + // Prepacked matrices are Path-dependent, so create them for each test result. + for (auto& result : results) { + // If this result uses an external path, then skip this entirely. + if (result->path == Path::kNone) { + continue; + } + // Pre-packing doesn't make sense for Path::kReference. + // TODO(silvasean): Make Path::kReference an ExternalPath? + if (result->path == Path::kReference) { + continue; + } + + // Determine whether we should create/use prepacked matrices. + if (benchmark) { + // For benchmarking, do as requested. + result->use_prepacked_lhs = benchmark_prepack_lhs; + result->use_prepacked_rhs = benchmark_prepack_rhs; + } else { + // When testing, randomly pre-pack sometimes. But don't do it too often. 
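+      // (The '& 7' tests below give each side roughly a 1-in-8 chance of
+      // being pre-packed in any given test run.)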
+ result->use_prepacked_lhs = (global_random_engine()() & 7) == 0; + result->use_prepacked_rhs = (global_random_engine()() & 7) == 0; + } + + // Create the pre-packed matrices. + PrepackedMatrix* prepacked_lhs_ptr = + result->use_prepacked_lhs ? &result->prepacked_lhs : nullptr; + PrepackedMatrix* prepacked_rhs_ptr = + result->use_prepacked_rhs ? &result->prepacked_rhs : nullptr; + auto alloc_fn = [&result](std::size_t num_bytes) { + return result->allocator.AllocateBytes(num_bytes); + }; + // Use a dst with a null data pointer to check that the pre-packing + // invocation doesn't write into it. + Matrix null_data_dst = result->storage_matrix.matrix; + null_data_dst.data = nullptr; + GlobalContext().SetRuntimeEnabledPaths(result->path); + PrePackForMul(lhs.matrix, rhs.matrix, spec, &GlobalContext(), + &null_data_dst, prepacked_lhs_ptr, + prepacked_rhs_ptr, alloc_fn); + RUY_CHECK_EQ(GlobalContext().last_taken_path, result->path); + } + + life_stage = LifeStage::kHasPrepackedMatrices; +} + +template +void TestSet::MakeResultPaths() { + RUY_CHECK_EQ(life_stage, LifeStage::kHasOtherParams); + + Path paths_bitfield = static_cast(GetHexIntEnvVarOrZero("PATHS")); + + if (paths_bitfield == Path::kNone) { + // Use a dummy Context just to perform the resolution of specific runtime + // enabled paths. + Context context; + paths_bitfield = context.GetRuntimeEnabledPaths(); + } + + // Trim bits that don't correspond to a compiled path, + // to allow specifying e.g. ffff to mean 'all paths' regardless of whether all + // those bits exist as actual paths. + paths_bitfield = paths_bitfield & kAllPaths; + RUY_CHECK_NE(paths_bitfield, Path::kNone); + paths = PathsBitfieldAsVector(paths_bitfield); + +#ifdef RUY_TEST_EXTERNAL_PATHS + + using TestSetType = TestSet; + + if (!GetBoolEnvVarOrFalse("NOEXT")) { + if (SupportsGemmlowp::kValue) { +#ifdef GEMMLOWP_SSE4 + const bool gemmlowp_supported = !spec.multiplier_fixedpoint_perchannel; +#else + const bool gemmlowp_supported = true; +#endif + if (gemmlowp_supported) { + external_paths.push_back(ExternalPath::kGemmlowp); + } + } + if (UsesSingleScalarType::kValue && + std::is_floating_point::value) { + external_paths.push_back(ExternalPath::kEigen); + if (layout_style == LayoutStyle::kPackedLinear) { + external_paths.push_back(ExternalPath::kEigenTensor); + } +// We link against a generic BLAS target that only maps to OpenBLAS on specific +// architectures. +#if RUY_PLATFORM(ARM_32) || RUY_PLATFORM(ARM_64) + // OpenBLAS multi-threading is disabled, so avoid mixing single-threaded + // and multi-threaded benchmark results. 
+ if (max_num_threads == 1 && !getenv("NO_OPENBLAS")) { + external_paths.push_back(ExternalPath::kOpenBlas); + } +#endif + } + } + +#endif // RUY_TEST_EXTERNAL_PATHS + + for (Path path : paths) { + for (Tuning tuning : EnumerateTuningsForPath(path, benchmark)) { + results.emplace_back(new TestResultType); + TestResultType& result = *results.back(); + result.path = path; + result.tuning = tuning; + MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style, + RandomRange::kGeneral, &result.storage_matrix); + } + } + + for (ExternalPath external_path : external_paths) { + results.emplace_back(new TestResultType); + TestResultType& result = *results.back(); + result.external_path = external_path; + MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style, + RandomRange::kGeneral, &result.storage_matrix); + } + + life_stage = LifeStage::kHasResultPaths; +} + +template +void TestSet::EvalResult( + TestResult* result) { + RUY_CHECK(result->path != Path::kNone || + result->external_path != ExternalPath::kNone); + if (result->path != Path::kNone) { + EvalRuy(result); + } else { +#ifdef RUY_TEST_EXTERNAL_PATHS + using TestSetType = TestSet; + EvalExternalPath(this, result); +#endif + } + const std::string& pathname = PathName(*result); + if (std::find(CoveredPaths()->begin(), CoveredPaths()->end(), pathname) == + CoveredPaths()->end()) { + CoveredPaths()->push_back(pathname); + } +} + +using f32 = float; +using f64 = double; +using u8 = std::uint8_t; +using i8 = std::int8_t; +using u16 = std::uint16_t; +using i16 = std::int16_t; +using u32 = std::uint32_t; +using i32 = std::int32_t; +using u64 = std::uint64_t; +using i64 = std::int64_t; + +template +const char* TypeName() { + return nullptr; +} + +#define RUY_TYPENAME(TYPE) \ + template <> \ + const char* TypeName() { \ + return #TYPE; \ + } + +RUY_TYPENAME(f32) +RUY_TYPENAME(f64) +RUY_TYPENAME(u8) +RUY_TYPENAME(i8) +RUY_TYPENAME(u16) +RUY_TYPENAME(i16) +RUY_TYPENAME(u32) +RUY_TYPENAME(i32) +RUY_TYPENAME(u64) +RUY_TYPENAME(i64) + +#undef RUY_TYPENAME + +template +const char* SymmetryName(const Matrix& matrix) { + if (matrix.zero_point == SymmetricZeroPoint()) { + return "symm"; + } else { + return "asymm"; + } +} + +template +int StorageSize(const Matrix& matrix) { + return sizeof(Scalar) * FlatSize(matrix.layout); +} + +// Helper that replicates a buffer and gives out pointers to the replicas. +// This is useful when one wants to traverse data so that it is cold in cache. +// By having a sufficiently large value of num_repeats, one can ensure that the +// working set covered by the replicas is greater than the cache size. 
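+// Illustrative usage, mirroring the Benchmark() code further below (a sketch
+// only; 'num_repeats' is a placeholder name):
+//
+//   RepeatedBuffer<LhsScalar> cold_lhs;
+//   cold_lhs.Init(lhs.matrix.data.get(), FlatSize(lhs.matrix.layout),
+//                 num_repeats);
+//   ...
+//   lhs.matrix.data = cold_lhs.Next();  // a different replica each iteration
+//
+// With the 100 MiB working-set target used in Benchmark(), num_repeats is
+// chosen so that the replicas collectively dwarf any L1/L2/L3 cache.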
+template +class RepeatedBuffer { + public: + RepeatedBuffer() = default; + void Init(const T* elems, std::size_t num_elems, int num_repeats) { + buffers_.clear(); + allocator_.FreeAll(); + for (int i = 0; i < num_repeats; i++) { + T* p; + allocator_.Allocate(num_elems, &p); + memcpy(p, elems, num_elems * sizeof(T)); + buffers_.push_back(p); + } + } + T* Next() { + T* ret = buffers_[current_]; + current_ = (current_ + 1) % buffers_.size(); + return ret; + } + + private: + Allocator allocator_; + std::vector buffers_; + int current_ = 0; +}; + +template +void TestSet::Benchmark( + TestResult* result) { + using DstScalar = typename SpecType::DstScalar; + + const bool cold = getenv("RUY_BENCHMARK_COLD"); + LhsScalar* orig_lhs_data = lhs.matrix.data.get(); + RhsScalar* orig_rhs_data = rhs.matrix.data.get(); + DstScalar* orig_dst_data = result->storage_matrix.matrix.data.get(); + void* orig_prepacked_lhs_data = result->prepacked_lhs.data; + void* orig_prepacked_rhs_data = result->prepacked_rhs.data; + + int num_matmul_sets = 0; + + RepeatedBuffer cold_lhs; + RepeatedBuffer cold_rhs; + RepeatedBuffer cold_dst; + RepeatedBuffer cold_prepacked_lhs; + RepeatedBuffer cold_prepacked_rhs; + + if (cold) { + const int kWorkingSetSize = 100 << 20; + const int each_matmul_set_size = StorageSize(lhs.matrix) + + StorageSize(rhs.matrix) + + StorageSize(result->storage_matrix.matrix); + num_matmul_sets = + (kWorkingSetSize + each_matmul_set_size - 1) / each_matmul_set_size; + + cold_lhs.Init(lhs.matrix.data.get(), FlatSize(lhs.matrix.layout), + num_matmul_sets); + cold_rhs.Init(rhs.matrix.data.get(), FlatSize(rhs.matrix.layout), + num_matmul_sets); + cold_dst.Init(result->storage_matrix.matrix.data.get(), + FlatSize(result->storage_matrix.matrix.layout), + num_matmul_sets); + if (benchmark_prepack_lhs) { + cold_prepacked_lhs.Init(static_cast(result->prepacked_lhs.data), + result->prepacked_lhs.data_size, num_matmul_sets); + } + if (benchmark_prepack_rhs) { + cold_prepacked_rhs.Init(static_cast(result->prepacked_rhs.data), + result->prepacked_rhs.data_size, num_matmul_sets); + } + } + const bool record_pmu = GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU"); + int repeats = GetIntEnvVarOrZero("RUY_BENCHMARK_REPEATS"); + if (!repeats) { + repeats = 4; + } + float benchmark_min_secs = GetFloatEnvVarOrZero("RUY_BENCHMARK_MIN_SECS"); + if (!benchmark_min_secs) { + benchmark_min_secs = 0.5; + } +#ifdef RUY_PROFILER + { + const char* lhstype = TypeName(); + const char* lhssymm = SymmetryName(lhs.matrix); + const char* rhstype = TypeName(); + const char* rhssymm = SymmetryName(rhs.matrix); + + printf("Profiling path=%s shape=(%dx%dx%d) lhs=(%s,%s) rhs=(%s,%s)\n", + PathName(*result).c_str(), rows, depth, cols, lhstype, lhssymm, + rhstype, rhssymm); + ruy::profiler::ScopeProfile profile; +#endif + + float latency = std::numeric_limits::infinity(); + float l1_refill_rate = std::numeric_limits::infinity(); + float l2_refill_rate = std::numeric_limits::infinity(); + float l3_refill_rate = std::numeric_limits::infinity(); + float l1tlb_refill_rate = std::numeric_limits::infinity(); + float l2tlb_refill_rate = std::numeric_limits::infinity(); + float mispred_rate = std::numeric_limits::infinity(); + float frontend_stall_rate = std::numeric_limits::infinity(); + float backend_stall_rate = std::numeric_limits::infinity(); + + for (int repeat = 0; repeat < repeats; repeat++) { + auto& pmu_events = GlobalPmuEvents(); + if (record_pmu) { + pmu_events.StartRecording(); + } + TimePoint time_start = Now(); + TimePoint t = time_start; + 
int iters = 0; + int iters_at_a_time = 1; + while (ToFloatSeconds(t - time_start) < benchmark_min_secs) { + for (int i = 0; i < iters_at_a_time; i++) { + if (cold) { + lhs.matrix.data = cold_lhs.Next(); + rhs.matrix.data = cold_rhs.Next(); + result->storage_matrix.matrix.data = cold_dst.Next(); + if (benchmark_prepack_lhs) { + result->prepacked_lhs.data = cold_prepacked_lhs.Next(); + } + if (benchmark_prepack_rhs) { + result->prepacked_rhs.data = cold_prepacked_rhs.Next(); + } + } + EvalResult(result); + iters++; + } + iters_at_a_time *= 2; + t = Now(); + } + latency = std::min( + latency, static_cast(ToFloatSeconds(t - time_start) / iters)); + if (record_pmu) { + pmu_events.StopRecording(); + const float normalization_factor = + 1.0f / (static_cast(iters) * rows * cols * depth); + l1_refill_rate = std::min( + l1_refill_rate, pmu_events.L1RefillCount() * normalization_factor); + l2_refill_rate = std::min( + l2_refill_rate, pmu_events.L2RefillCount() * normalization_factor); + l3_refill_rate = std::min( + l3_refill_rate, pmu_events.L3RefillCount() * normalization_factor); + l1tlb_refill_rate = + std::min(l1tlb_refill_rate, + pmu_events.L1TLBRefillCount() * normalization_factor); + l2tlb_refill_rate = + std::min(l2tlb_refill_rate, + pmu_events.L2TLBRefillCount() * normalization_factor); + mispred_rate = + std::min(mispred_rate, pmu_events.BranchMispredictionCount() * + normalization_factor); + frontend_stall_rate = + std::min(frontend_stall_rate, + pmu_events.FrontendStallCount() * normalization_factor); + backend_stall_rate = + std::min(backend_stall_rate, + pmu_events.BackendStallCount() * normalization_factor); + } + } + result->latency = latency; + if (record_pmu) { + result->l1_refill_rate = l1_refill_rate; + result->l2_refill_rate = l2_refill_rate; + result->l3_refill_rate = l3_refill_rate; + result->l1tlb_refill_rate = l1tlb_refill_rate; + result->l2tlb_refill_rate = l2tlb_refill_rate; + result->mispred_rate = mispred_rate; + result->frontend_stall_rate = frontend_stall_rate; + result->backend_stall_rate = backend_stall_rate; + } + +#ifdef RUY_PROFILER + } + fflush(stdout); +#endif + + if (cold) { + lhs.matrix.data = orig_lhs_data; + rhs.matrix.data = orig_rhs_data; + memcpy(orig_dst_data, result->storage_matrix.matrix.data.get(), + StorageSize(result->storage_matrix.matrix)); + result->storage_matrix.matrix.data = orig_dst_data; + result->prepacked_lhs.data = orig_prepacked_lhs_data; + result->prepacked_rhs.data = orig_prepacked_rhs_data; + } +} + +template +void TestSet::Eval() { + RUY_CHECK_EQ(life_stage, LifeStage::kHasPrepackedMatrices); + for (auto& result : results) { + if (benchmark) { + Benchmark(result.get()); + } else { + EvalResult(result.get()); + } + } + life_stage = LifeStage::kEvaluated; +} + +template +std::string DumpRegion(const Matrix& matrix, int center_row, + int center_col) { + static constexpr int kRadius = 20; + int first_row = std::max(0, center_row - kRadius); + int last_row = std::min(matrix.layout.rows - 1, center_row + kRadius); + int first_col = std::max(0, center_col - kRadius); + int last_col = std::min(matrix.layout.cols - 1, center_col + kRadius); + std::ostringstream stream; + for (int row = first_row; row <= last_row; row++) { + for (int col = first_col; col <= last_col; col++) { + stream << static_cast(Element(matrix, row, col)) << " "; + } + stream << "\n"; + } + return stream.str(); +} + +template +void TestSet::VerifyTestResults() const { + const int depth = lhs.matrix.layout.cols; + for (int i = 0; i < results.size() - 1; i++) { + if 
(!Agree(*results[i], *results[i + 1], depth)) { + std::string paths_in_agreement; + paths_in_agreement.append(PathName(*results[0])); + for (int j = 1; j <= i; j++) { + paths_in_agreement.append(", "); + paths_in_agreement.append(PathName(*results[j])); + } + ErrorAnalysis error_analysis; + AnalyzeTestError(*this, i + 1, &error_analysis); + std::cerr << "Error: path (" << PathName(*results[i + 1]) + << ") disagrees with the other paths (" << paths_in_agreement + << "), which agree with each other." << std::endl; + std::cerr << "Shape: rows = " << rows << ", cols = " << cols + << ", depth = " << depth << std::endl; + std::cerr << "Stats of the good result matrix: " + << StatsAsString(error_analysis.stats_good) << std::endl; + std::cerr << "Stats of the bad result matrix: " + << StatsAsString(error_analysis.stats_bad) << std::endl; + if (error_analysis.error_rows.size() < rows) { + std::cerr << "Rows containing errors: " + << Join(error_analysis.error_rows) << std::endl; + } else { + std::cerr << "Errors found in ALL rows." << std::endl; + } + if (error_analysis.error_cols.size() < cols) { + std::cerr << "Cols containing errors: " + << Join(error_analysis.error_cols) << std::endl; + } else { + std::cerr << "Errors found in ALL cols." << std::endl; + } + std::cerr << "The first error occurs at row " + << error_analysis.row_of_first_error << ", col " + << error_analysis.col_of_first_error << std::endl; + std::cerr << "Good value: " << error_analysis.first_error_good_value + << std::endl; + std::cerr << "Bad value : " << error_analysis.first_error_bad_value + << std::endl; + std::cerr << "Region of Good result matrix around first error:\n\n" + << DumpRegion(results[0]->storage_matrix.matrix, + error_analysis.row_of_first_error, + error_analysis.col_of_first_error) + << std::endl; + std::cerr << "Region of Bad result matrix around first error:\n\n" + << DumpRegion(results[i + 1]->storage_matrix.matrix, + error_analysis.row_of_first_error, + error_analysis.col_of_first_error) + << std::endl; + RUY_CHECK(false); + } + } +} + +template +void TestSet::Verify() { + RUY_CHECK_EQ(life_stage, LifeStage::kEvaluated); + if (expected_outcome == ExpectedOutcome::kSuccess) { + VerifyTestResults(); + } + life_stage = LifeStage::kFinal; +} + +template +void TestRCC(int rows, int depth, int cols, ExpectedOutcome expected_outcome) { + TestSetType test_set; + test_set.rows = rows; + test_set.depth = depth; + test_set.cols = cols; + test_set.lhs_order = Order::kRowMajor; + test_set.rhs_order = Order::kColMajor; + test_set.dst_order = Order::kColMajor; + test_set.layout_style = LayoutStyle::kPackedLinear; + test_set.expected_outcome = expected_outcome; + test_set.Run(); +} + +template +void TestRCC(int rows, int depth, int cols) { + TestRCC(rows, depth, cols, ExpectedOutcome::kSuccess); +} + +template +void TestNonRCC(int rows, int depth, int cols, + ExpectedOutcome expected_outcome) { + TestSetType test_set; + test_set.rows = rows; + test_set.depth = depth; + test_set.cols = cols; + test_set.lhs_order = Order::kColMajor; + test_set.rhs_order = Order::kColMajor; + test_set.dst_order = Order::kColMajor; + test_set.layout_style = LayoutStyle::kPackedLinear; + test_set.expected_outcome = expected_outcome; + test_set.Run(); +} + +template +void TestLinearAllOrders(int rows, int depth, int cols, + ExpectedOutcome expected_outcome) { + const std::vector orders{Order::kColMajor, Order::kRowMajor}; + + for (Order lhs_order : orders) { + for (Order rhs_order : orders) { + for (Order dst_order : orders) { + TestSetType 
test_set; + test_set.rows = rows; + test_set.depth = depth; + test_set.cols = cols; + test_set.lhs_order = lhs_order; + test_set.rhs_order = rhs_order; + test_set.dst_order = dst_order; + test_set.layout_style = LayoutStyle::kLinear; + test_set.expected_outcome = expected_outcome; + test_set.Run(); + } + } + } +} + +template +void TestLinearAllOrders(int rows, int depth, int cols) { + TestLinearAllOrders(rows, depth, cols, + ExpectedOutcome::kSuccess); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TEST_H_ diff --git a/ruy/test_fast.cc b/ruy/test_fast.cc new file mode 100644 index 0000000..d1c1308 --- /dev/null +++ b/ruy/test_fast.cc @@ -0,0 +1,110 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This test contains cheap test cases, completes in a few seconds. + +#include + +#include "ruy/test.h" + +namespace ruy { + +using LhsScalar = RUY_TEST_LHSSCALAR; +using RhsScalar = RUY_TEST_RHSSCALAR; +using AccumScalar = RUY_TEST_ACCUMSCALAR; +using DstScalar = RUY_TEST_DSTSCALAR; + +using TestSetType = + TestSet>; + +TEST(RuyTest, TestSquareMuls) { + const std::vector sizes{ + // small sizes + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + // multiplies of 16 + 16, + 32, + 48, + 64, + // pot-minus-1 sizes + 15, + 31, + 63, + // pot-plus-1 sizes + 17, + 33, + 65, + }; + + for (int size : sizes) { + TestRCC(size, size, size); + TestLinearAllOrders(size, size, size); + } +} + +TEST(RuyTest, TestMiscMuls) { + const int shapes[][3] = { + {2, 3, 4}, {7, 6, 5}, {12, 23, 6}, {19, 3, 11}, {3, 10, 17}, + {30, 21, 43}, {7, 57, 9}, {49, 69, 71}, {38, 111, 29}, {87, 98, 76}, + {16, 96, 16}, {16, 88, 16}, {16, 84, 16}, {16, 92, 16}, {16, 82, 16}, + {16, 81, 16}, {16, 95, 16}, {3, 128, 5}}; + for (const auto& shape : shapes) { + TestLinearAllOrders(shape[0], shape[1], shape[2]); + } +} + +TEST(RuyTest, TestDeepMuls) { + // TODO(b/137649322): clarify what's the max allowed matrix size. + TestRCC(1, 32767, 1); + TestLinearAllOrders(5, 5001, 4); + TestLinearAllOrders(9, 1025, 10); +} + +TEST(RuyTest, TestShallowMuls) { + TestLinearAllOrders(101, 1, 103); + TestLinearAllOrders(71, 2, 53); + TestLinearAllOrders(51, 3, 73); + TestLinearAllOrders(51, 4, 43); +} + +TEST(RuyTest, TestNarrowMuls) { + for (int width : {1, 2, 3, 4, 5, 8}) { + TestLinearAllOrders(width, 12, 13); + TestLinearAllOrders(15, 19, width); + TestLinearAllOrders(width, 123, 137); + TestLinearAllOrders(158, 119, width); + } +} + +TEST(RuyTest, TestGEMV) { + for (int size = 1; size < 1024; size *= 2) { + for (int depth = 1; depth < 500; depth += 47) { + TestLinearAllOrders(size, depth, 1); + } + } + TestLinearAllOrders(5, 5001, 1); + TestLinearAllOrders(8193, 17, 1); +} + +} // namespace ruy diff --git a/ruy/test_slow.cc b/ruy/test_slow.cc new file mode 100644 index 0000000..9f0f218 --- /dev/null +++ b/ruy/test_slow.cc @@ -0,0 +1,71 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This test contains more expensive test cases. + +#include "ruy/test.h" + +namespace ruy { + +using LhsScalar = RUY_TEST_LHSSCALAR; +using RhsScalar = RUY_TEST_RHSSCALAR; +using AccumScalar = RUY_TEST_ACCUMSCALAR; +using DstScalar = RUY_TEST_DSTSCALAR; + +using TestSetType = + TestSet>; + +TEST(RuyTest, TestBigNarrowMuls) { + for (int width : {1, 2, 3, 4, 5, 8}) { + TestRCC(width, 401, 601); + TestRCC(587, 443, width); + } + TestRCC(7, 45984, + 5); // Large enough to trigger row-sum overflows. + TestRCC(512, 256, 16); +} + +TEST(RuyTest, TestBigShallowMuls) { + TestLinearAllOrders(501, 1, 321); + TestLinearAllOrders(301, 5, 403); + TestLinearAllOrders(256, 32, 512); +} + +TEST(RuyTest, TestBigMuls) { + TestRCC(225, 303, 199); + TestLinearAllOrders(256, 192, 128); +} + +TEST(RuyTest, TestBigPowerOfTwoDepthWithAvoidAliasing) { + // Important to test some power-of-two depths: that's when the + // RUY_AVOID_ALIASING optimization kicks in and makes packed matrices + // strided, exposing bugs in kernels mixing up size and stride. + // Moreover, it's important that the test matrices be sufficiently wide + // that they will result in multiple blocks, exposing bugs in the + // computation of the base address of each block. + TestLinearAllOrders(70, 1024, 80); + TestLinearAllOrders(60, 2048, 70); + TestLinearAllOrders(40, 4096, 50); +} + +TEST(RuyTest, TestGEMV) { + for (int size = 1025; size <= 1409; size += 384) { + for (int depth = 350; depth < 500; depth += 47) { + TestLinearAllOrders(size, depth, 1); + } + } +} + +} // namespace ruy diff --git a/ruy/test_special_specs.cc b/ruy/test_special_specs.cc new file mode 100644 index 0000000..a621d0e --- /dev/null +++ b/ruy/test_special_specs.cc @@ -0,0 +1,163 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This test covers non-basic specs. 
+ +#include "ruy/test.h" + +namespace ruy { + +template +struct LoopStructureSpec : BasicSpec { + static constexpr LoopStructure kLoopStructure = tLoopStructure; +}; + +template +struct ZeroPointSupportSpec : BasicSpec { + static constexpr ZeroPointSupport kZeroPointSupport = tZeroPointSupport; +}; + +template +struct RCCSpec : BasicSpec { + static constexpr LayoutSupport kLayoutSupport = LayoutSupport::kRCC; +}; + +template +struct StandardCppKernelLayoutSpec : BasicSpec { + using StandardCppKernelLhsLayout = LhsKernelLayout; + using StandardCppKernelRhsLayout = RhsKernelLayout; + static int local_data_cache_size() { return 1; } + static int shared_data_cache_size() { return 1; } +}; + +using LhsScalar = RUY_TEST_LHSSCALAR; +using RhsScalar = RUY_TEST_RHSSCALAR; +using AccumScalar = RUY_TEST_ACCUMSCALAR; +using DstScalar = RUY_TEST_DSTSCALAR; + +template +void TestLoopStructure() { + using SpecType = LoopStructureSpec; + using TestSetType = TestSet; + for (int size = 1; size < 10; size++) { + TestLinearAllOrders(size, size, size); + } + TestLinearAllOrders(3, 5, 78); + TestLinearAllOrders(19, 91, 7); + TestLinearAllOrders(71, 26, 44); + TestLinearAllOrders(81, 93, 72); +} + +TEST(TestSpecialSpecs, LoopStructure) { + static_assert(BasicSpec::kLoopStructure == + LoopStructure::kAuto, + ""); + static_assert(BasicSpec::kLoopStructure == LoopStructure::kAuto, + ""); + TestLoopStructure(); + TestLoopStructure(); +} + +template +void TestZeroPointSupport(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point, + DstScalar dst_zero_point, + ExpectedOutcome expected_outcome) { + using SpecType = + ZeroPointSupportSpec; + using TestSetType = TestSet; + TestSetType test_set; + test_set.rows = 11; + test_set.depth = 12; + test_set.cols = 13; + test_set.lhs_order = Order::kRowMajor; + test_set.rhs_order = Order::kColMajor; + test_set.dst_order = Order::kColMajor; + test_set.layout_style = LayoutStyle::kPackedLinear; + test_set.expected_outcome = expected_outcome; + test_set.lhs_zero_point = lhs_zero_point; + test_set.rhs_zero_point = rhs_zero_point; + test_set.dst_zero_point = dst_zero_point; + test_set.use_specified_zero_points = true; + test_set.Run(); +} + +TEST(TestSpecialSpecs, ZeroPointSupport) { + // Sanity check + RUY_CHECK_EQ(SymmetricZeroPoint(), 128); + RUY_CHECK_EQ(SymmetricZeroPoint(), 0); + + if (std::is_floating_point::value) { + return; + } + + TestZeroPointSupport( + SymmetricZeroPoint(), SymmetricZeroPoint(), + SymmetricZeroPoint(), ExpectedOutcome::kSuccess); + TestZeroPointSupport( + SymmetricZeroPoint() - 1, SymmetricZeroPoint(), + SymmetricZeroPoint(), ExpectedOutcome::kSuccess); + TestZeroPointSupport( + SymmetricZeroPoint(), SymmetricZeroPoint(), + SymmetricZeroPoint(), ExpectedOutcome::kSuccess); + TestZeroPointSupport( + SymmetricZeroPoint() + 1, SymmetricZeroPoint(), + SymmetricZeroPoint(), ExpectedOutcome::kDeath); + TestZeroPointSupport( + SymmetricZeroPoint(), SymmetricZeroPoint() + 1, + SymmetricZeroPoint(), ExpectedOutcome::kDeath); + TestZeroPointSupport( + SymmetricZeroPoint(), SymmetricZeroPoint(), + SymmetricZeroPoint() - 1, ExpectedOutcome::kDeath); +} + +TEST(TestSpecialSpecs, RCC) { + using RCCSpec = RCCSpec; + using RCCTestSet = TestSet; + TestRCC(81, 93, 72); + TestNonRCC(81, 93, 72, ExpectedOutcome::kDeath); +} + +template +void TestStandardCppKernelLayout() { + using SpecType = + StandardCppKernelLayoutSpec; + using TestSetType = TestSet; + for (int size = 1; size < 10; size++) { + TestLinearAllOrders(size, size, size); + } + TestLinearAllOrders(87, 34, 56); + 
TestLinearAllOrders(123, 234, 78); +} + +TEST(TestSpecialSpecs, StandardCppKernelLayoutTrivial1x1) { + TestStandardCppKernelLayout, + FixedKernelLayout>(); +} + +TEST(TestSpecialSpecs, StandardCppKernelLayoutSquare4x4) { + TestStandardCppKernelLayout, + FixedKernelLayout>(); +} + +TEST(TestSpecialSpecs, StandardCppKernelLayoutRectangular4x8) { + TestStandardCppKernelLayout, + FixedKernelLayout>(); +} + +} // namespace ruy diff --git a/ruy/thread_pool.cc b/ruy/thread_pool.cc new file mode 100644 index 0000000..d09bf1e --- /dev/null +++ b/ruy/thread_pool.cc @@ -0,0 +1,200 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/thread_pool.h" + +#include +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) +#include +#include +#include +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) + +#include "ruy/check_macros.h" +#include "ruy/wait.h" + +namespace ruy { + +// A worker thread. +class Thread { + public: + enum class State { + Startup, // The initial state before the thread main loop runs. + Ready, // Is not working, has not yet received new work to do. + HasWork, // Has work to do. + ExitAsSoonAsPossible // Should exit at earliest convenience. + }; + + explicit Thread(BlockingCounter* counter_to_decrement_when_ready) + : task_(nullptr), + state_(State::Startup), + counter_to_decrement_when_ready_(counter_to_decrement_when_ready) { + thread_.reset(new std::thread(ThreadFunc, this)); + } + + ~Thread() { + ChangeState(State::ExitAsSoonAsPossible); + thread_->join(); + } + + // Changes State; may be called from either the worker thread + // or the master thread; however, not all state transitions are legal, + // which is guarded by assertions. + // + // The Task argument is to be used only with new_state==HasWork. + // It specifies the Task being handed to this Thread. + void ChangeState(State new_state, Task* task = nullptr) { + state_mutex_.lock(); + State old_state = state_.load(std::memory_order_relaxed); + RUY_DCHECK_NE(old_state, new_state); + switch (old_state) { + case State::Startup: + RUY_DCHECK_EQ(new_state, State::Ready); + break; + case State::Ready: + RUY_DCHECK(new_state == State::HasWork || + new_state == State::ExitAsSoonAsPossible); + break; + case State::HasWork: + RUY_DCHECK(new_state == State::Ready || + new_state == State::ExitAsSoonAsPossible); + break; + default: + abort(); + } + switch (new_state) { + case State::Ready: + if (task_) { + // Doing work is part of reverting to 'ready' state. 
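+          // (Running the pending task before the state becomes Ready matters:
+          // ChangeState only decrements the master's BlockingCounter once the
+          // new state is Ready, so the master cannot observe completion until
+          // the work has actually been done.)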
+ task_->Run(); + task_ = nullptr; + } + break; + case State::HasWork: + RUY_DCHECK(!task_); + task_ = task; + break; + default: + break; + } + state_.store(new_state, std::memory_order_relaxed); + state_cond_.notify_all(); + state_mutex_.unlock(); + if (new_state == State::Ready) { + counter_to_decrement_when_ready_->DecrementCount(); + } + } + + static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); } + + // Called by the master thead to give this thread work to do. + void StartWork(Task* task) { ChangeState(State::HasWork, task); } + + private: + // Thread entry point. + void ThreadFuncImpl() { + ChangeState(State::Ready); + + // Thread main loop + while (true) { + // In the 'Ready' state, we have nothing to do but to wait until + // we switch to another state. + const auto& condition = [this]() { + return state_.load(std::memory_order_acquire) != State::Ready; + }; + Wait(condition, &state_cond_, &state_mutex_); + + // Act on new state. + switch (state_.load(std::memory_order_acquire)) { + case State::HasWork: + // Got work to do! So do it, and then revert to 'Ready' state. + ChangeState(State::Ready); + break; + case State::ExitAsSoonAsPossible: + return; + default: + abort(); + } + } + } + + // The underlying thread. + std::unique_ptr thread_; + + // The task to be worked on. + Task* task_; + + // The condition variable and mutex guarding state changes. + std::condition_variable state_cond_; + std::mutex state_mutex_; + + // The state enum tells if we're currently working, waiting for work, etc. + // Its concurrent accesses by the thread and main threads are guarded by + // state_mutex_, and can thus use memory_order_relaxed. This still needs + // to be a std::atomic because we use WaitForVariableChange. + std::atomic state_; + + // pointer to the master's thread BlockingCounter object, to notify the + // master thread of when this thread switches to the 'Ready' state. + BlockingCounter* const counter_to_decrement_when_ready_; +}; + +void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) { + RUY_DCHECK_GE(task_count, 1); + + // Case of 1 thread: just run the single task on the current thread. + if (task_count == 1) { + (tasks + 0)->Run(); + return; + } + + // Task #0 will be run on the current thread. + CreateThreads(task_count - 1); + counter_to_decrement_when_ready_.Reset(task_count - 1); + for (int i = 1; i < task_count; i++) { + auto task_address = reinterpret_cast(tasks) + i * stride; + threads_[i - 1]->StartWork(reinterpret_cast(task_address)); + } + + // Execute task #0 immediately on the current thread. + (tasks + 0)->Run(); + + // Wait for the threads submitted above to finish. + counter_to_decrement_when_ready_.Wait(); +} + +// Ensures that the pool has at least the given count of threads. +// If any new thread has to be created, this function waits for it to +// be ready. +void ThreadPool::CreateThreads(int threads_count) { + if (threads_.size() >= threads_count) { + return; + } + counter_to_decrement_when_ready_.Reset(threads_count - threads_.size()); + while (threads_.size() < threads_count) { + threads_.push_back(new Thread(&counter_to_decrement_when_ready_)); + } + counter_to_decrement_when_ready_.Wait(); +} + +ThreadPool::~ThreadPool() { + for (auto w : threads_) { + delete w; + } +} + +} // end namespace ruy diff --git a/ruy/thread_pool.h b/ruy/thread_pool.h new file mode 100644 index 0000000..04c201c --- /dev/null +++ b/ruy/thread_pool.h @@ -0,0 +1,102 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file is a fork of gemmlowp's multi_thread_gemm.h, under Apache 2.0 +// license. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ + +#include + +#include "ruy/blocking_counter.h" + +namespace ruy { + +// A workload for a thread. +struct Task { + virtual ~Task() {} + virtual void Run() = 0; +}; + +class Thread; + +// A simple pool of threads, that only allows the very +// specific parallelization pattern that we use here: +// One thread, which we call the 'main thread', calls Execute, distributing +// a Task each to N threads, being N-1 'worker threads' and the main thread +// itself. After the main thread has completed its own Task, it waits for +// the worker threads to have all completed. That is the only synchronization +// performed by this ThreadPool. +// +// In particular, there is a naive 1:1 mapping of Tasks to threads. +// This ThreadPool considers it outside of its own scope to try to work +// with fewer threads than there are Tasks. The idea is that such N:M mappings +// of tasks to threads can be implemented as a higher-level feature on top of +// the present low-level 1:1 threadpool. For example, a user might have a +// Task subclass referencing a shared atomic counter indexing into a vector of +// finer-granularity subtasks. Different threads would then concurrently +// increment this atomic counter, getting each their own subtasks to work on. +// That approach is the one used in ruy's multi-thread matrix multiplication +// implementation --- see ruy's TrMulTask. +class ThreadPool { + public: + ThreadPool() {} + + ~ThreadPool(); + + // Executes task_count tasks on task_count threads. + // Grows the threadpool as needed to have at least (task_count-1) threads. + // The 0-th task is run on the thread on which Execute is called: that + // is by definition what we call the "main thread". Synchronization of all + // threads is performed before this function returns. + // + // As explained in the class comment, there is a 1:1 mapping of tasks to + // threads. If you need something smarter than that, for instance if you + // want to run an unbounded number of tasks on a bounded number of threads, + // then you need something higher-level than this ThreadPool, that can + // be layered on top of it by appropriately subclassing Tasks. + // + // TaskType must be a subclass of ruy::Task. That is implicitly guarded by + // the static_cast in this inline implementation. + template + void Execute(int task_count, TaskType* tasks) { + ExecuteImpl(task_count, sizeof(TaskType), static_cast(tasks)); + } + + private: + // Ensures that the pool has at least the given count of threads. + // If any new thread has to be created, this function waits for it to + // be ready. + void CreateThreads(int threads_count); + + // Non-templatized implementation of the public Execute method. 
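To make the 1:1 task-to-thread contract described above concrete, here is a minimal usage sketch (not part of the patch). It assumes only the ruy::Task and ruy::ThreadPool interfaces declared in this header; AddRangeTask and ParallelSum are hypothetical names introduced for illustration. Execute runs task 0 on the calling thread and the remaining tasks on worker threads, returning only once all of them have finished.

#include <vector>

#include "ruy/thread_pool.h"

// Hypothetical task: sums data[begin, end) into *out.
struct AddRangeTask final : ruy::Task {
  void Run() override {
    long long sum = 0;
    for (int i = begin; i < end; ++i) sum += data[i];
    *out = sum;
  }
  const int* data = nullptr;
  int begin = 0;
  int end = 0;
  long long* out = nullptr;
};

// task_count must be >= 1, matching the RUY_DCHECK in ExecuteImpl.
long long ParallelSum(ruy::ThreadPool* pool, const int* data, int size,
                      int task_count) {
  std::vector<AddRangeTask> tasks(task_count);
  std::vector<long long> partial_sums(task_count, 0);
  for (int i = 0; i < task_count; ++i) {
    tasks[i].data = data;
    tasks[i].begin = size * i / task_count;
    tasks[i].end = size * (i + 1) / task_count;
    tasks[i].out = &partial_sums[i];
  }
  // Task 0 runs on the current thread; the other (task_count - 1) tasks run
  // on worker threads. Execute returns after all tasks have completed.
  pool->Execute(task_count, tasks.data());
  long long total = 0;
  for (long long s : partial_sums) total += s;
  return total;
}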
+ // See the inline implementation of Execute for how this is used. + void ExecuteImpl(int task_count, int stride, Task* tasks); + + // copy construction disallowed + ThreadPool(const ThreadPool&) = delete; + + // The threads in this pool. They are owned by the pool: + // the pool creates threads and destroys them in its destructor. + std::vector threads_; + + // The BlockingCounter used to wait for the threads. + BlockingCounter counter_to_decrement_when_ready_; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_THREAD_POOL_H_ diff --git a/ruy/time.h b/ruy/time.h new file mode 100644 index 0000000..9dba75e --- /dev/null +++ b/ruy/time.h @@ -0,0 +1,81 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_ + +#include // NOLINT(build/c++11) +#include // IWYU pragma: keep +#include // NOLINT(build/c++11) + +#ifdef __linux__ +#include +// IWYU pragma: no_include + +#include +#endif + +namespace ruy { + +using InternalDefaultClock = std::chrono::steady_clock; + +using TimePoint = InternalDefaultClock::time_point; +using Duration = InternalDefaultClock::duration; + +template +Duration DurationFromSeconds(RepresentationType representation) { + return std::chrono::duration_cast( + std::chrono::duration(representation)); +} + +template +Duration DurationFromMilliseconds(RepresentationType representation) { + return std::chrono::duration_cast( + std::chrono::duration(representation)); +} + +template +Duration DurationFromNanoseconds(RepresentationType representation) { + return std::chrono::duration_cast( + std::chrono::duration(representation)); +} + +inline float ToFloatSeconds(const Duration& duration) { + return std::chrono::duration_cast>(duration) + .count(); +} + +inline std::int64_t ToInt64Nanoseconds(const Duration& duration) { + return std::chrono::duration_cast< + std::chrono::duration>(duration) + .count(); +} + +inline TimePoint Now() { return InternalDefaultClock::now(); } + +inline TimePoint CoarseNow() { +#ifdef __linux__ + timespec t; + clock_gettime(CLOCK_MONOTONIC_COARSE, &t); + return TimePoint( + DurationFromNanoseconds(1000000000LL * t.tv_sec + t.tv_nsec)); +#else + return Now(); +#endif +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TIME_H_ diff --git a/ruy/trace.cc b/ruy/trace.cc new file mode 100644 index 0000000..1822cdb --- /dev/null +++ b/ruy/trace.cc @@ -0,0 +1,325 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
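As a small illustration of the ruy/time.h helpers introduced above (a usage sketch, not part of the patch): TimePoint and Duration come from the steady clock, and the conversion helpers translate to and from seconds, milliseconds and nanoseconds.

#include <cstdio>

#include "ruy/time.h"

int main() {
  const ruy::TimePoint t0 = ruy::Now();
  // Placeholder workload, just to have something to time.
  volatile double x = 0;
  for (int i = 0; i < 1000000; ++i) x = x + 1e-6;
  const ruy::TimePoint t1 = ruy::Now();

  const ruy::Duration elapsed = t1 - t0;
  std::printf("elapsed: %.6f s (%lld ns)\n", ruy::ToFloatSeconds(elapsed),
              static_cast<long long>(ruy::ToInt64Nanoseconds(elapsed)));

  // Durations can also be built from plain numbers, e.g. a 2 ms budget.
  const ruy::Duration budget = ruy::DurationFromMilliseconds(2);
  if (elapsed < budget) {
    std::printf("finished within the 2 ms budget\n");
  }
  return 0;
}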
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/trace.h" + +#include +#include // IWYU pragma: keep +#include +#include +#include +#include + +#include "ruy/check_macros.h" +#include "ruy/side_pair.h" +#include "ruy/time.h" + +namespace ruy { + +#ifdef RUY_TRACE + +enum class TraceEvent : std::uint8_t { + kNone, + kThreadStart, + kThreadLoopStart, + kThreadEnd, + kBlockReserved, + kBlockPackedLhs, + kBlockPackedRhs, + kBlockFinished +}; + +struct TraceEntry { + TimePoint time_point; + TraceEvent event; + // ruy-internal thread id i.e. contiguous index into array of threads, + // with 0 designating the main thread. + std::uint16_t thread_id = 0; + // Additional parameters whose meaning depends on the 'event' type. + std::uint32_t params[1]; +}; + +struct Trace { + BlockMap block_map; + // During recording, to avoid having to use locks or atomics, we let + // each thread append to its own specific vector. + std::vector> thread_specific_entries; + // Global vector of entries into which we coalesce thread_specific_entries + // after recording is finished, when dumping a trace. See + // AggregateThreadSpecificEntries. + std::vector entries; + TimePoint time_start; + TimePoint time_execute; + TimePoint time_end; +}; + +namespace { + +// Coalesce Trace::thread_specific_entries into Trace::entries. +void AggregateThreadSpecificEntries(Trace* trace) { + RUY_CHECK(trace->entries.empty()); + for (auto& thread_specific_entries_vector : trace->thread_specific_entries) { + for (const TraceEntry& entry : thread_specific_entries_vector) { + trace->entries.push_back(entry); + } + thread_specific_entries_vector.clear(); + } +} + +// Sort Trace::entries by ascending time. In case of equal timepoints, +// sort by some semi-arbitrary ordering of event types. +void Sort(Trace* trace) { + std::sort(std::begin(trace->entries), std::end(trace->entries), + [](const TraceEntry& a, const TraceEntry& b) -> bool { + return a.time_point < b.time_point || + (a.time_point == b.time_point && + static_cast(a.event) < static_cast(b.event)); + }); +} + +// Dump a trace. Assumes that AggregateThreadSpecificEntries and Sort have +// already been called on it. +// +// On some architectures long long ints are not same as std::int64_t, and +// time is printed as %lld, so static_casts are necessary. +void Dump(const Trace& trace) { + const char* trace_filename = getenv("RUY_TRACE_FILE"); + FILE* trace_file = trace_filename ? 
fopen(trace_filename, "w") : stderr; + if (!trace_file) { + fprintf(stderr, "Failed to open %s for write, errno=%d\n", trace_filename, + errno); + RUY_CHECK(false); + } + fprintf(trace_file, "thread_count:%d\n", trace.block_map.thread_count); + fprintf(trace_file, "rows:%d\n", trace.block_map.dims[Side::kLhs]); + fprintf(trace_file, "cols:%d\n", trace.block_map.dims[Side::kRhs]); + fprintf(trace_file, "Execute: %lld\n", + static_cast( + ToInt64Nanoseconds(trace.time_execute - trace.time_start))); + for (const TraceEntry& entry : trace.entries) { + long long int time = static_cast( + ToInt64Nanoseconds(entry.time_point - trace.time_start)); + switch (entry.event) { + case TraceEvent::kThreadStart: + fprintf(trace_file, "ThreadStart: %lld, %d\n", time, entry.thread_id); + break; + case TraceEvent::kThreadLoopStart: + fprintf(trace_file, "ThreadLoopStart: %lld, %d\n", time, + entry.thread_id); + break; + case TraceEvent::kThreadEnd: + fprintf(trace_file, "ThreadEnd: %lld, %d\n", time, entry.thread_id); + break; + case TraceEvent::kBlockReserved: { + std::uint32_t block_id = entry.params[0]; + SidePair block; + GetBlockByIndex(trace.block_map, block_id, &block); + SidePair start, end; + GetBlockMatrixCoords(trace.block_map, block, &start, &end); + fprintf(trace_file, + "BlockReserved: %lld, %d, %d, %d, %d, %d, %d, %d, %d\n", time, + entry.thread_id, block_id, block[Side::kLhs], block[Side::kRhs], + start[Side::kLhs], start[Side::kRhs], end[Side::kLhs], + end[Side::kRhs]); + break; + } + case TraceEvent::kBlockPackedLhs: { + std::uint32_t block = entry.params[0]; + int start, end; + GetBlockMatrixCoords(Side::kLhs, trace.block_map, block, &start, &end); + fprintf(trace_file, "BlockPackedLhs: %lld, %d, %d, %d, %d\n", time, + entry.thread_id, block, start, end); + break; + } + case TraceEvent::kBlockPackedRhs: { + std::uint32_t block = entry.params[0]; + int start, end; + GetBlockMatrixCoords(Side::kRhs, trace.block_map, block, &start, &end); + fprintf(trace_file, "BlockPackedRhs: %lld, %d, %d, %d, %d\n", time, + entry.thread_id, block, start, end); + break; + } + case TraceEvent::kBlockFinished: { + std::uint32_t block_id = entry.params[0]; + SidePair block; + GetBlockByIndex(trace.block_map, block_id, &block); + fprintf(trace_file, "BlockFinished: %lld, %d, %d, %d, %d\n", time, + entry.thread_id, block_id, block[Side::kLhs], + block[Side::kRhs]); + break; + } + default: + RUY_CHECK(false); + } + } + fprintf(trace_file, "End: %lld\n", + static_cast( + ToInt64Nanoseconds(trace.time_end - trace.time_start))); + if (trace_filename) { + fclose(trace_file); + } +} + +} // anonymous namespace + +// Get a Trace object to record to, or null of tracing is not enabled. 
+Trace* NewTraceOrNull(TracingContext* tracing, int rows, int depth, int cols) { + if (!tracing->initialized) { + tracing->initialized = true; + tracing->enabled = getenv("RUY_TRACE"); + if (!tracing->enabled) { + return nullptr; + } + if (getenv("RUY_TRACE_FILTER_ROWS")) { + tracing->filter_shape_rows = std::stoi(getenv("RUY_TRACE_FILTER_ROWS")); + } + if (getenv("RUY_TRACE_FILTER_DEPTH")) { + tracing->filter_shape_depth = std::stoi(getenv("RUY_TRACE_FILTER_DEPTH")); + } + if (getenv("RUY_TRACE_FILTER_COLS")) { + tracing->filter_shape_cols = std::stoi(getenv("RUY_TRACE_FILTER_COLS")); + } + } + if (!tracing->enabled) { + return nullptr; + } + if (tracing->filter_shape_rows && rows != tracing->filter_shape_rows) { + return nullptr; + } + if (tracing->filter_shape_depth && depth != tracing->filter_shape_depth) { + return nullptr; + } + if (tracing->filter_shape_cols && cols != tracing->filter_shape_cols) { + return nullptr; + } + // Delete any existing trace. + delete tracing->trace; + // Create a new one. + tracing->trace = new Trace; + return tracing->trace; +} + +// The trace recorded on a context is finalized and dumped by +// this TracingContext destructor. +// +// The idea of dumping on context destructor is that typically one wants to +// run many matrix multiplications, e.g. to hit a steady state in terms of +// performance characteristics, but only trace the last repetition of the +// workload, when that steady state was attained. +TracingContext::~TracingContext() { + if (trace) { + AggregateThreadSpecificEntries(trace); + Sort(trace); + Dump(*trace); + } + delete trace; +} + +void TraceRecordStart(Trace* trace) { + if (trace) { + trace->time_start = Now(); + } +} + +void TraceRecordExecute(const BlockMap& block_map, Trace* trace) { + if (trace) { + trace->time_execute = Now(); + trace->block_map = block_map; + trace->thread_specific_entries.resize(block_map.thread_count); + for (int thread = 0; thread < block_map.thread_count; thread++) { + trace->thread_specific_entries[thread].clear(); + // Reserve some large size to avoid frequent heap allocations + // affecting the recorded timings. + trace->thread_specific_entries[thread].reserve(16384); + } + } +} + +void TraceRecordEnd(Trace* trace) { + if (trace) { + trace->time_end = Now(); + } +} + +void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace) { + if (trace) { + TraceEntry entry; + entry.event = TraceEvent::kThreadStart; + entry.time_point = Now(); + entry.thread_id = thread_id; + trace->thread_specific_entries[thread_id].push_back(entry); + } +} + +void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace) { + if (trace) { + TraceEntry entry; + entry.event = TraceEvent::kThreadLoopStart; + entry.time_point = Now(); + entry.thread_id = thread_id; + trace->thread_specific_entries[thread_id].push_back(entry); + } +} + +void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id, + Trace* trace) { + if (trace) { + TraceEntry entry; + entry.event = TraceEvent::kBlockReserved; + entry.time_point = Now(); + entry.thread_id = thread_id; + entry.params[0] = block_id; + trace->thread_specific_entries[thread_id].push_back(entry); + } +} + +void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block, + Trace* trace) { + if (trace) { + TraceEntry entry; + entry.event = side == Side::kLhs ? 
TraceEvent::kBlockPackedLhs + : TraceEvent::kBlockPackedRhs; + entry.time_point = Now(); + entry.thread_id = thread_id; + entry.params[0] = block; + trace->thread_specific_entries[thread_id].push_back(entry); + } +} + +void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id, + Trace* trace) { + if (trace) { + TraceEntry entry; + entry.event = TraceEvent::kBlockFinished; + entry.time_point = Now(); + entry.thread_id = thread_id; + entry.params[0] = block_id; + trace->thread_specific_entries[thread_id].push_back(entry); + } +} + +void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace) { + if (trace) { + TraceEntry entry; + entry.event = TraceEvent::kThreadEnd; + entry.time_point = Now(); + entry.thread_id = thread_id; + trace->thread_specific_entries[thread_id].push_back(entry); + } +} + +#endif + +} // namespace ruy diff --git a/ruy/trace.h b/ruy/trace.h new file mode 100644 index 0000000..d2cc51d --- /dev/null +++ b/ruy/trace.h @@ -0,0 +1,73 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_ + +#include + +#include "ruy/block_map.h" +#include "ruy/side_pair.h" + +namespace ruy { + +struct Trace; + +#ifdef RUY_TRACE + +struct TracingContext { + bool initialized = false; + bool enabled = false; + int filter_shape_rows = 0; + int filter_shape_cols = 0; + int filter_shape_depth = 0; + Trace* trace = nullptr; + ~TracingContext(); +}; + +Trace* NewTraceOrNull(TracingContext* context, int rows, int depth, int cols); +void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace); +void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace); +void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id, + Trace* trace); +void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block, + Trace* trace); +void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id, + Trace* trace); +void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace); +void TraceRecordStart(Trace* trace); +void TraceRecordExecute(const BlockMap& block_map, Trace* trace); +void TraceRecordEnd(Trace* trace); + +#else + +struct TracingContext {}; + +inline Trace* NewTraceOrNull(TracingContext*, int, int, int) { return nullptr; } +inline void TraceRecordThreadStart(std::uint32_t, Trace*) {} +inline void TraceRecordThreadLoopStart(std::uint32_t, Trace*) {} +inline void TraceRecordBlockReserved(std::uint32_t, std::uint32_t, Trace*) {} +inline void TraceRecordBlockPacked(std::uint32_t, Side, int, Trace*) {} +inline void TraceRecordBlockFinished(std::uint32_t, std::uint32_t, Trace*) {} +inline void TraceRecordThreadEnd(std::uint32_t, Trace*) {} +inline void TraceRecordStart(Trace*) {} +inline void TraceRecordExecute(const BlockMap&, Trace*) {} +inline void TraceRecordEnd(Trace*) {} + +#endif + +} // namespace ruy + 
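The recording scheme used by trace.cc above — one entry vector per thread so that recording needs no locks or atomics, followed by a single aggregation and time-sort before dumping — generalizes beyond ruy. Below is a self-contained miniature of the same pattern using only the standard library (illustrative, not ruy code); it mirrors what AggregateThreadSpecificEntries and Sort do.

#include <algorithm>
#include <chrono>
#include <cstdio>
#include <thread>
#include <vector>

struct Event {
  std::chrono::steady_clock::time_point time;
  int thread_id;
  int payload;
};

int main() {
  constexpr int kThreads = 4;
  // One vector per thread: each thread only touches its own slot, so no
  // locking is needed while recording.
  std::vector<std::vector<Event>> per_thread(kThreads);
  std::vector<std::thread> threads;
  for (int t = 0; t < kThreads; ++t) {
    threads.emplace_back([t, &per_thread] {
      for (int i = 0; i < 3; ++i) {
        per_thread[t].push_back({std::chrono::steady_clock::now(), t, i});
      }
    });
  }
  for (auto& th : threads) th.join();

  // Aggregate into one vector, then sort by timestamp before dumping.
  std::vector<Event> all;
  for (auto& v : per_thread) all.insert(all.end(), v.begin(), v.end());
  std::sort(all.begin(), all.end(),
            [](const Event& a, const Event& b) { return a.time < b.time; });
  for (const Event& e : all) {
    std::printf("thread %d, payload %d\n", e.thread_id, e.payload);
  }
  return 0;
}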
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRACE_H_ diff --git a/ruy/trmul.cc b/ruy/trmul.cc new file mode 100644 index 0000000..a3ba46a --- /dev/null +++ b/ruy/trmul.cc @@ -0,0 +1,401 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/trmul.h" + +#include +#include +#include +#include +#include + +#include "ruy/allocator.h" +#include "ruy/block_map.h" +#include "ruy/check_macros.h" +#include "ruy/common.h" +#include "ruy/internal_matrix.h" +#include "ruy/matrix.h" +#include "ruy/opt_set.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/side_pair.h" +#include "ruy/size_util.h" +#include "ruy/spec.h" +#include "ruy/thread_pool.h" +#include "ruy/trace.h" +#include "ruy/tune.h" + +namespace ruy { + +namespace { + +enum class PackingStatus : std::uint8_t { kNotStarted, kInProgress, kFinished }; + +struct TrMulTask final : Task { + TrMulTask(TrMulParams* params_, const BlockMap& block_map_, + std::atomic* atomic_block_id_, int thread_id_, + bool need_atomics_, + SidePair*> packing_status_, + TuningResolver* tuning_resolver_, Allocator* local_allocator_, + Trace* trace_) + : params(params_), + block_map(block_map_), + atomic_block_id(atomic_block_id_), + thread_id(thread_id_), + need_atomics(need_atomics_), + packing_status(packing_status_), + tuning_resolver(tuning_resolver_), + local_allocator(local_allocator_), + trace(trace_), + local_packed{nullptr, nullptr} {} + + void Run() override { + TraceRecordThreadStart(thread_id, trace); + + for (Side side : {Side::kLhs, Side::kRhs}) { + if (!params->is_prepacked[side]) { + const int size = NumBlocksPerSide(side, block_map); + local_allocator->Allocate(size, &local_packed[side]); + memset(local_packed[side], 0, size * sizeof(bool)); + } + } + + const int num_blocks = NumBlocks(block_map); + + const Tuning tuning = tuning_resolver->Resolve(); + + TraceRecordThreadLoopStart(thread_id, trace); + + SidePair block; + SidePair start; + SidePair end; + + // Each thread starts by initially reserving the block whose id + // is the thread id. + int block_id = thread_id; + TraceRecordBlockReserved(thread_id, block_id, trace); + + while (block_id < num_blocks) { + // Reserve the next block to handle. In order to hide the latency + // (typically comparable to an access to the level of data cache that + // is shared among CPU cores, e.g. 60 cycles on an ARM CPU as of 2019) + // of this atomic operation, we structure this code so as to avoid + // immediately depending on the `next_n` result. + const int next_block_id = + atomic_block_id->fetch_add(1, std::memory_order_relaxed); + TraceRecordBlockReserved(thread_id, next_block_id, trace); + // Get coordinates of the current block to handle, in "block space". + GetBlockByIndex(block_map, block_id, &block); + // Get coordinates of the current block to handle, in matrix space. 
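The loop above distributes blocks to threads through a single shared atomic counter: each thread seeds itself with block_id == thread_id, and on every iteration it reserves the *next* block up front with fetch_add so that the latency of the atomic is hidden behind the current block's work. Here is a stripped-down, self-contained sketch of that scheduling pattern (illustrative only; Worker and the variable names are hypothetical):

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

// Each worker processes blocks handed out by a shared atomic counter.
void Worker(int thread_id, int num_blocks, std::atomic<int>* next_block,
            std::vector<int>* owner) {
  int block_id = thread_id;  // Initial block: one per thread, by convention.
  while (block_id < num_blocks) {
    // Reserve the next block before working on the current one, so the cost
    // of the atomic increment overlaps with useful work.
    const int next_block_id =
        next_block->fetch_add(1, std::memory_order_relaxed);
    (*owner)[block_id] = thread_id;  // Stand-in for "work on this block".
    block_id = next_block_id;
  }
}

int main() {
  constexpr int kThreads = 4;
  constexpr int kBlocks = 16;
  // The counter starts at kThreads because blocks [0, kThreads) are already
  // implicitly reserved, one per thread.
  std::atomic<int> next_block(kThreads);
  std::vector<int> owner(kBlocks, -1);
  std::vector<std::thread> threads;
  for (int t = 0; t < kThreads; ++t) {
    threads.emplace_back(Worker, t, kBlocks, &next_block, &owner);
  }
  for (auto& th : threads) th.join();
  for (int b = 0; b < kBlocks; ++b) {
    std::printf("block %d handled by thread %d\n", b, owner[b]);
  }
  return 0;
}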
+ GetBlockMatrixCoords(block_map, block, &start, &end); + // Maybe pack the current LHS/RHS block, if not already packed. + EnsurePacked(block, start, end, tuning); + // Actually do matrix multiplication work + params->RunKernel(tuning, start, end); + TraceRecordBlockFinished(thread_id, block_id, trace); + // Move on to the next block as obtained by the atomic increment + // at the start of this while loop iteration. + block_id = next_block_id; + } + + local_allocator->FreeAll(); + + TraceRecordThreadEnd(thread_id, trace); + } + + private: + // Tries to pack a block, without blocking. + // If the block was already packed, returns true. + // If the block was not started packing, packs it and returns true. + // If the block was being packed by another thread, returns false. + bool TryPack(Side side, int block, int start, int end, Tuning tuning) { + if (params->is_prepacked[side]) { + return true; + } + if (!local_packed[side][block]) { + if (need_atomics) { + // Explanation of this compare_exchange_strong operation: + // This atomically performs all of the following: + // 1. Read `status` with "acquire" memory order. + // * That this read uses "acquire" is because both memory orders + // specified have "acquire" as their read-component. + // 2. Compare (bitwise) with `exchanged_status`. + // 3. If equal, stores the value kInProgress to `status` with "release" + // memory order, and returns true, so we take this 'if' branch. + // * That this store uses "release" is because of the _rel part in + // memory_order_acq_rel passed as the first memory order argument. + // 4. If not equal, stores the loaded value of `status` to + // `exchanged_status` with "relaxed" semantics, and returns false, + // so we take the 'else' branch. + // * That this store uses "relaxed" is because the second memory + // order argument, memory_order_acquire, implies no particular + // store semantics. "relaxed" is acceptable here because this + // stores to a local stack variable. + // + // Rationale for compare_exchange_strong as opposed to + // compare_exchange_weak: + // The spurious-failure case with compare_exchange_weak will actually + // happen a lot here, because the atomic 'status' bytes are stored + // contiguously in arrays and neighboring values will be accessed + // by multiple threads concurrently. On a typical ARM CPU, an exclusives + // reservation granule is 64 bytes, so a lot of false-sharing may + // happen. Using compare_exchange_weak would thus result in often having + // TryPack return 'false' when it could instead have done the packing + // work and returned 'true'. Heuristically, that is not a good thing. + // Moreover, this changes the TryPack contract, loosening it and making + // it harder for the caller to reason about. Finally, the overhead of + // atomic operations is mitigated by the enclosing check on + // local_packed, so maybe the overhead of compare_exchange_strong isn't + // such a problem. But we don't really know for sure, that would be + // interesting to experiment more with. + PackingStatus exchanged_status = PackingStatus::kNotStarted; + std::atomic& status = packing_status[side][block]; + if (status.compare_exchange_strong( + exchanged_status, PackingStatus::kInProgress, + std::memory_order_acq_rel, std::memory_order_acquire)) { + // In this branch, the status was kNotStarted and we just atomically + // changed it to kInProgress as we are about to handle the packing + // ourselves. 
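To make the compare_exchange_strong discussion above easier to follow in isolation, here is a reduced sketch of the same three-state protocol (not ruy code; TryDoWork is a hypothetical name): each slot moves kNotStarted -> kInProgress -> kFinished, exactly one thread wins the right to do the work, and losers can tell whether the work is already finished or still in flight.

#include <atomic>
#include <cstdio>

enum class PackingStatus : unsigned char { kNotStarted, kInProgress, kFinished };

// Returns true if the work guarded by `status` is done by the time we return
// (either we did it ourselves, or it was already done); returns false if
// another thread is still doing it.
template <typename DoWorkFn>
bool TryDoWork(std::atomic<PackingStatus>* status, DoWorkFn&& do_work) {
  PackingStatus expected = PackingStatus::kNotStarted;
  if (status->compare_exchange_strong(expected, PackingStatus::kInProgress,
                                      std::memory_order_acq_rel,
                                      std::memory_order_acquire)) {
    // We atomically moved kNotStarted -> kInProgress: the work is ours.
    do_work();
    // Publish with release semantics so that readers observing kFinished
    // (with acquire) also observe the work's side effects.
    status->store(PackingStatus::kFinished, std::memory_order_release);
    return true;
  }
  if (expected == PackingStatus::kInProgress) {
    // Another thread is doing the work right now; don't wait for it.
    return false;
  }
  // expected == kFinished: the work was already done earlier.
  return true;
}

int main() {
  std::atomic<PackingStatus> status(PackingStatus::kNotStarted);
  const bool first = TryDoWork(&status, [] { std::puts("doing the work"); });
  const bool second = TryDoWork(&status, [] { std::puts("never runs twice"); });
  std::printf("first=%d second=%d\n", first, second);
  return 0;
}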
+ params->RunPack(side, tuning, start, end); + TraceRecordBlockPacked(thread_id, side, block, trace); + status.store(PackingStatus::kFinished, std::memory_order_release); + } else if (exchanged_status == PackingStatus::kInProgress) { + // Another thread is currently packing this block. + return false; + } + RUY_DCHECK(status.load(std::memory_order_acquire) == + PackingStatus::kFinished); + } else { + // Single-threaded case: no need for expensive atomics, local_packed + // is the truth already. + params->RunPack(side, tuning, start, end); + TraceRecordBlockPacked(thread_id, side, block, trace); + } + local_packed[side][block] = true; + } + return true; + } + + // Ensures that both the LHS and RHS blocks required by the specified block + // are packed. In the event that they are already being packed on another + // threads, this function may perform the packing of some other block while + // waiting for that other thread to finish packing the requested block. + void EnsurePacked(const SidePair& block, const SidePair& start, + const SidePair& end, Tuning tuning) { +#if RUY_OPT_ENABLED(RUY_OPT_PACK_AHEAD) + SidePair next_runahead_block{block[Side::kLhs] + 1, + block[Side::kRhs] + 1}; + Side next_runahead_side = Side::kLhs; +#endif + while (true) { + bool both_sides_packed = true; + for (Side side : {Side::kLhs, Side::kRhs}) { + both_sides_packed &= + TryPack(side, block[side], start[side], end[side], tuning); + } + if (both_sides_packed) { + break; + } +#if RUY_OPT_ENABLED(RUY_OPT_PACK_AHEAD) + const Side runahead_side = next_runahead_side; + const int runahead_block = next_runahead_block[runahead_side]; + next_runahead_side = + next_runahead_side == Side::kLhs ? Side::kRhs : Side::kLhs; + if (runahead_block >= NumBlocksPerSide(runahead_side, block_map)) { + continue; + } + int runahead_block_start, runahead_block_end; + GetBlockMatrixCoords(runahead_side, block_map, runahead_block, + &runahead_block_start, &runahead_block_end); + TryPack(runahead_side, runahead_block, runahead_block_start, + runahead_block_end, tuning); + next_runahead_block[runahead_side] = runahead_block + 1; +#endif + } + } + + TrMulParams* params; + const BlockMap& block_map; + std::atomic* atomic_block_id; + int thread_id; + bool need_atomics; + SidePair*> packing_status; + TuningResolver* tuning_resolver; + Allocator* local_allocator; + Trace* trace; + + // Local indicators of packedness to avoid the overhead of atomic ops. + SidePair local_packed; +}; + +void AllocatePMatrix(Allocator* allocator, PMatrix* packed) { + packed->data = allocator->AllocateBytes(DataSize(*packed)); + packed->sums = allocator->AllocateBytes(SumsSize(*packed)); +} + +int GetThreadCount(Context* context, int rows, int cols, int depth) { +#if RUY_PLATFORM(EMSCRIPTEN) + // b/139927184, std::thread constructor raises exception + return 1; +#endif + // Empirically determined rule for reasonable number of + // threads to use. This is proportional to the number of arithmetic ops + // in this Mul (product of the 3 sizes). 
+ static constexpr int kDivisorLog2 = 15; + const int guess_log2 = std::max( + 0, ceil_log2(rows) + ceil_log2(cols) + ceil_log2(depth) - kDivisorLog2); + return std::min(1 << guess_log2, context->max_num_threads); +} + +LoopStructure GetLoopStructure(int tentative_thread_count, int rows, int cols, + int depth, int lhs_scalar_size, + int rhs_scalar_size, int local_data_cache_size, + int shared_data_cache_size) { + if (tentative_thread_count == 1) { + const BlockMapTraversalOrder traversal_order = + GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size, + local_data_cache_size, shared_data_cache_size); + // If we are in the GEMV case or the block_map would be using linear + // traversal anyway, use the simple loop. + if ((cols == 1) || traversal_order == BlockMapTraversalOrder::kLinear) { + return LoopStructure::kSimple; + } + } + return LoopStructure::kGeneral; +} + +} // namespace + +void TrMul(TrMulParams* params, Context* context) { + profiler::ScopeLabel label( + "TrMul (Path=0x%x, max_num_threads=%d, is_prepacked=(%d,%d))", + static_cast(params->path), context->max_num_threads, + params->is_prepacked[Side::kLhs], params->is_prepacked[Side::kRhs]); + + PMatrix& packed_lhs = params->packed[Side::kLhs]; + PMatrix& packed_rhs = params->packed[Side::kRhs]; + DMatrix& lhs = params->src[Side::kLhs]; + DMatrix& rhs = params->src[Side::kRhs]; + + const int rows = lhs.layout.cols; + const int cols = rhs.layout.cols; + const int depth = lhs.layout.rows; + + const int tentative_thread_count = GetThreadCount(context, rows, cols, depth); + const auto loop_structure = GetLoopStructure( + tentative_thread_count, rows, cols, depth, lhs.data_type.size, + rhs.data_type.size, params->local_data_cache_size, + params->shared_data_cache_size); + Allocator* allocator = context->GetMainAllocator(); + + // Allocate packed matrices + for (Side side : {Side::kLhs, Side::kRhs}) { + if (!params->is_prepacked[side]) { + AllocatePMatrix(allocator, ¶ms->packed[side]); + } + } + + // Case of running this TrMul as a simple loop. + // This is a good place to start reading this function: all the rest + // of this function is just an optimized, but functionally equivalent, + // version of that. + if (loop_structure == LoopStructure::kSimple) { + profiler::ScopeLabel label_simple("TrMulImpl, simple loop"); + Tuning tuning = context->GetMainThreadTuning(); + + const SidePair origin{0, 0}; + const SidePair rounded_dims{packed_lhs.layout.cols, + packed_rhs.layout.cols}; + for (Side side : {Side::kLhs, Side::kRhs}) { + if (!params->is_prepacked[side]) { + params->RunPack(side, tuning, origin[side], rounded_dims[side]); + } + } + params->RunKernel(tuning, origin, rounded_dims); + + allocator->FreeAll(); + return; + } + + profiler::ScopeLabel label_general("TrMulImpl, general case"); + + auto* trace = NewTraceOrNull(&context->tracing, rows, depth, cols); + TraceRecordStart(trace); + + // Initialize block map. + BlockMap block_map; + MakeBlockMap(packed_lhs.layout.cols, packed_rhs.layout.cols, depth, + packed_lhs.layout.kernel.cols, packed_rhs.layout.kernel.cols, + packed_lhs.data_type.size, packed_rhs.data_type.size, + tentative_thread_count, params->path, + params->local_data_cache_size, params->shared_data_cache_size, + &block_map); + + // Initialize per-thread state. 
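The GetThreadCount heuristic above sizes the thread count by the log2 of the number of multiply-add operations, rows*cols*depth, relative to a 2^15 divisor. A worked sketch follows (illustrative only; CeilLog2 is a stand-in for the ceil_log2 helper from ruy/size_util.h, and GuessThreadCount is a hypothetical name mirroring only the shape-dependent part):

#include <algorithm>
#include <cstdio>

// Stand-in for ruy's ceil_log2: smallest n such that (1 << n) >= x.
int CeilLog2(int x) {
  int n = 0;
  while ((1 << n) < x) ++n;
  return n;
}

int GuessThreadCount(int rows, int cols, int depth, int max_num_threads) {
  constexpr int kDivisorLog2 = 15;
  const int guess_log2 = std::max(
      0, CeilLog2(rows) + CeilLog2(cols) + CeilLog2(depth) - kDivisorLog2);
  return std::min(1 << guess_log2, max_num_threads);
}

int main() {
  // A 32x32x32 product (~2^15 ops) stays single-threaded; a 512x512x512
  // product (~2^27 ops) would ask for 2^12 threads but gets capped by
  // max_num_threads (4 here).
  std::printf("%d\n", GuessThreadCount(32, 32, 32, 4));
  std::printf("%d\n", GuessThreadCount(512, 512, 512, 4));
  return 0;
}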
+ const int thread_count = block_map.thread_count; + const bool need_atomics = thread_count > 1; + context->EnsureNPerThreadStates(thread_count); + for (auto& per_thread_state : context->per_thread_states) { + per_thread_state->tuning_resolver.SetTuning(context->explicit_tuning); + } + + // In the need_atomics case, allocate and initialize atomic values tracking + // the packing status of blocks. + SidePair*> packing_status{nullptr, nullptr}; + if (need_atomics) { + for (Side side : {Side::kLhs, Side::kRhs}) { + if (!params->is_prepacked[side]) { + const int size = NumBlocksPerSide(side, block_map); + allocator->Allocate(size, &packing_status[side]); + for (int i = 0; i < size; i++) { + packing_status[side][i].store(PackingStatus::kNotStarted, + std::memory_order_relaxed); + } + } + } + } + + // Create the atomic block id, allocate it using Allocator so that + // we get the alignment ensuring that it sits alone in its exclusives + // reservation granule. + std::atomic* atomic_block_id; + allocator->Allocate(1, &atomic_block_id); + + // Create task objects. + TrMulTask* tasks; + allocator->Allocate(thread_count, &tasks); + + atomic_block_id->store(thread_count); + + for (int i = 0; i < thread_count; i++) { + new (tasks + i) TrMulTask(params, block_map, atomic_block_id, i, + need_atomics, packing_status, + &context->per_thread_states[i]->tuning_resolver, + &context->per_thread_states[i]->allocator, trace); + } + + // Do the computation. + TraceRecordExecute(block_map, trace); + context->workers_pool.Execute(thread_count, tasks); + + // Finish up. + for (int i = 0; i < thread_count; i++) { + tasks[i].~TrMulTask(); + } + + allocator->FreeAll(); + TraceRecordEnd(trace); +} + +} // namespace ruy diff --git a/ruy/trmul.h b/ruy/trmul.h new file mode 100644 index 0000000..f50bb0c --- /dev/null +++ b/ruy/trmul.h @@ -0,0 +1,38 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// As a matrix multiplication library, Ruy offers a Mul entry point, performing +// matrix multiplication. For implementation purposes, it is much nicer to +// be dealing with the transpose-and-multiply operation, doing +// Destination = Transpose(LHS) * RHS +// Indeed, the latter is performing dot-products between the *columns* of LHS +// and the columns of RHS, whereas a plain matrix multiplication is performing +// dot-products between the *rows* of LHS and the columns of RHS. +// That is why TrMul is nicer to implement, allowing for a more symmetric +// treatment of LHS and RHS. 
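For readers new to the TrMul formulation described above, a naive reference implementation of the semantics may help: Destination = Transpose(LHS) * RHS means entry (r, c) of the destination is the dot product of column r of LHS with column c of RHS. The sketch below is purely illustrative (float only, no packing, blocking or multithreading) and assumes column-major storage for all three matrices.

#include <cstdio>
#include <vector>

// dst (rows x cols) = Transpose(lhs) * rhs, where lhs is (depth x rows) and
// rhs is (depth x cols), all column-major.
void NaiveTrMul(const float* lhs, const float* rhs, float* dst, int rows,
                int cols, int depth) {
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      float acc = 0;
      for (int k = 0; k < depth; ++k) {
        // Column r of lhs dotted with column c of rhs.
        acc += lhs[r * depth + k] * rhs[c * depth + k];
      }
      dst[c * rows + r] = acc;
    }
  }
}

int main() {
  // depth=3, rows=cols=2: both operands are 3x2, column-major.
  const std::vector<float> lhs = {1, 2, 3, 4, 5, 6};
  const std::vector<float> rhs = {1, 0, 1, 0, 1, 0};
  std::vector<float> dst(4, 0);
  NaiveTrMul(lhs.data(), rhs.data(), dst.data(), 2, 2, 3);
  std::printf("%g %g\n%g %g\n", dst[0], dst[2], dst[1], dst[3]);
  return 0;
}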
+ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_ + +#include "ruy/context.h" +#include "ruy/trmul_params.h" + +namespace ruy { + +void TrMul(TrMulParams* params, Context* context); + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_H_ diff --git a/ruy/trmul_params.h b/ruy/trmul_params.h new file mode 100644 index 0000000..47537b7 --- /dev/null +++ b/ruy/trmul_params.h @@ -0,0 +1,67 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_ + +#include "ruy/internal_matrix.h" +#include "ruy/side_pair.h" +#include "ruy/tune.h" + +namespace ruy { + +using RunKernelFn = void(Tuning, const SidePair&, void*, + const SidePair&, const SidePair&, DMatrix*); + +using RunPackFn = void(Tuning, const DMatrix&, PMatrix*, int, int); + +// Type-erased data needed for implementing TrMul. +struct TrMulParams { + TrMulParams() : run_pack{nullptr, nullptr}, is_prepacked{false, false} {} + // Helper functions for invoking the function pointers. + void RunPack(Side side, Tuning tuning, int start, int end) { + run_pack[side](tuning, src[side], &packed[side], start, end); + } + void RunKernel(Tuning tuning, const SidePair& start, + const SidePair& end) { + run_kernel(tuning, packed, spec, start, end, &dst); + } + + // path id, can be useful info for some fine-tuning, e.g. to guess reasonable + // cache sizes when not runtime-detectable. + Path path; + + // See Spec::local_data_cache_size(). + int local_data_cache_size = 0; + // See Spec::shared_data_cache_size(). + int shared_data_cache_size = 0; + + // Function pointers to type-erased entry points for kernels and packers. + SidePair run_pack; + RunKernelFn* run_kernel = nullptr; + + // Matrices and packed matrices. + SidePair src; + DMatrix dst; + SidePair packed; + SidePair is_prepacked; + + // Type-erased Spec. + void* spec = nullptr; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TRMUL_PARAMS_H_ diff --git a/ruy/tune.cc b/ruy/tune.cc new file mode 100644 index 0000000..a89242f --- /dev/null +++ b/ruy/tune.cc @@ -0,0 +1,161 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
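TrMulParams above is the type-erasure boundary: the templated packing and kernel code is reduced to plain function pointers (run_pack, run_kernel) plus untyped data, so that trmul.cc itself can stay un-templated. Below is a minimal sketch of that idea, reduced to a single erased "run" function (illustrative only; ErasedParams and ScaleKernel are hypothetical names, not the ruy signatures):

#include <cstdio>

// The type-erased parameter block: a function pointer plus untyped state.
struct ErasedParams {
  using RunFn = void(const void* state, int start, int end);
  RunFn* run = nullptr;
  const void* state = nullptr;
  void Run(int start, int end) const { run(state, start, end); }
};

// A concrete, typed implementation...
struct ScaleKernel {
  float factor;
};
void RunScaleKernel(const void* state, int start, int end) {
  const auto* kernel = static_cast<const ScaleKernel*>(state);
  std::printf("scaling [%d, %d) by %g\n", start, end, kernel->factor);
}

// ...is bound into the erased struct once, up front, so that later code can
// call it without knowing any of the concrete types involved.
int main() {
  ScaleKernel kernel{2.5f};
  ErasedParams params;
  params.run = RunScaleKernel;
  params.state = &kernel;
  params.Run(0, 16);  // No templates involved past this point.
  return 0;
}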
+==============================================================================*/ + +#include "ruy/tune.h" + +#include +#include + +namespace ruy { + +#ifdef RUY_IMPLEMENT_TUNING + +namespace { + +void PoorlyOrderedKernel(int iters) { + asm volatile( + "mov w0, %w[iters]\n" + "1:\n" + "subs w0, w0, #1\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "bne 1b\n" ::[iters] "r"(iters) + : "cc", "x0", "v0", "v1", "v2", "v3"); +} + +void NicelyOrderedKernel(int iters) { + asm volatile( + "mov w0, %w[iters]\n" + "1:\n" + "subs w0, w0, #1\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "mul v0.4s, v0.4s, v0.4s\n" + "mul v1.4s, v1.4s, v1.4s\n" + "mul v2.4s, v2.4s, v2.4s\n" + "mul v3.4s, v3.4s, v3.4s\n" + "bne 1b\n" ::[iters] "r"(iters) + : "cc", "x0", "v0", "v1", "v2", "v3"); +} + +} // namespace + +float TuningResolver::EvalRatio() { + // With the current settings, 400 iterations and 4 repeats, this test has + // a latency of roughly 80 microseconds on a Cortex-A53 at 1.4 GHz. + static constexpr int kLoopIters = 400; + static constexpr int kRepeats = 4; + + Duration timing_poorly_ordered = Duration::max(); + Duration timing_nicely_ordered = Duration::max(); + + for (int r = 0; r < kRepeats; r++) { + TimePoint t0 = Now(); + PoorlyOrderedKernel(kLoopIters); + TimePoint t1 = Now(); + NicelyOrderedKernel(kLoopIters); + TimePoint t2 = Now(); + timing_poorly_ordered = std::min(timing_poorly_ordered, t1 - t0); + timing_nicely_ordered = std::min(timing_nicely_ordered, t2 - t1); + } + + return ToFloatSeconds(timing_nicely_ordered) / + ToFloatSeconds(timing_poorly_ordered); +} + +float TuningResolver::ThresholdRatio() { + // Empirically (see :tune_tool) determined threshold to distinguish in-order + // Cortex-A53/A55 cores from out-of-order Cortex-A57/A73/A75/A76 cores. Based + // on these experimental results, which were obtained with much lower + // (kLoopIters=1000, kRepeats=1) so as to make them resilient to noise, we + // have: + // + // CPU core type | in/out of order | observed ratio + // --------------+-----------------+----------------------------------------- + // Cortex-A53 | in-order | 0.32 -- 0.329 + // Cortex-A55 | in-order | 0.319 -- 0.325 + // Cortex-A55r1 | in-order | 0.319 -- 0.325 + // Cortex-A57 | out-of-order | 0.99 -- 1.01 + // Cortex-A73 | out-of-order | 0.922 -- 0.927 + // Cortex-A75 | out-of-order | 0.921 -- 0.93 + // Cortex-A76 | out-of-order | 1 + // Kryo (pixel1) | out-of-order | 0.73 -- 0.76 + // + // Thus the allowable range for the threshold is [0.35 .. 0.70]. + // We pick a value closer to the upper bound because really any out-of-order + // CPU should by definition produce a ratio close to 1. 
+ return 0.65f; +} + +Tuning TuningResolver::ResolveNow() { + const bool is_probably_inorder = EvalRatio() < ThresholdRatio(); + return is_probably_inorder ? Tuning::kInOrder : Tuning::kOutOfOrder; +} + +#else // not defined RUY_IMPLEMENT_TUNING + +float TuningResolver::EvalRatio() { return 0; } +float TuningResolver::ThresholdRatio() { return 0; } + +Tuning TuningResolver::ResolveNow() { return Tuning::kOutOfOrder; } + +#endif + +TuningResolver::TuningResolver() + : expiry_duration_(DurationFromMilliseconds(250)) {} + +Tuning TuningResolver::Resolve() { +#ifdef RUY_IMPLEMENT_TUNING + if (unresolved_tuning_ != Tuning::kAuto) { + return unresolved_tuning_; + } + TimePoint new_timepoint = CoarseNow(); + if (last_resolved_tuning_ != Tuning::kAuto && + (new_timepoint - last_resolved_timepoint_) < expiry_duration_) { + return last_resolved_tuning_; + } + last_resolved_timepoint_ = new_timepoint; + last_resolved_tuning_ = ResolveNow(); + return last_resolved_tuning_; +#else + return Tuning::kOutOfOrder; +#endif +} + +} // namespace ruy diff --git a/ruy/tune.h b/ruy/tune.h new file mode 100644 index 0000000..e6a0ee8 --- /dev/null +++ b/ruy/tune.h @@ -0,0 +1,163 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Library doing minimal CPU detection to decide what to tune asm code for. +// +// # Tuning vs Path +// +// Tunings are merely local variations of optimized code paths, that are +// drop-in replacements for each other --- the input and output data layouts +// are identical. By contrast, what ruy calls a Path dictates its own +// data layouts. For example, Path::kNeonDotprod will use different +// layouts compared to Path::kNeon; but within each, different tunings +// will share that same layout. +// +// # Tuning is for now only based on 1 bit: OutOfOrder / InOrder +// +// In practice, each of our asm code paths only needs one bit information to +// decide on tuning: whether the CPU is out-of-order or in-order. +// That is because out-of-order CPUs are by definition relatively insensitive +// to small-scale asm details (which is what "tuning" is about); and for each +// asm code path, there tends to be one main in-order CPU architecture that +// we focus our tuning effort on. Examples: +// * For Path::kNeon, the main in-order CPU is Cortex-A53/A55 (pre-dotprod) +// * For Path::kNeonDotprod, the main in-order CPU is Cortex-A55r1 (dotprod) +// +// Because having tuned code paths is a compromise of efficiency gains +// versus implementation effort and code size, we are happy to stop at just this +// single bit of information, OutOfOrder/InOrder, at least in the current CPU +// landscape. This could change in the future. +// +// # Implementation notes and alternatives. +// +// The current implementation uses a nano-benchmark, see tune.cc. +// That is why it's quite expensive, making caching / +// statefulness necessary (see TuningResolver class comment). 
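Resolve() above amortizes the cost of the nano-benchmark by caching its result for expiry_duration_ (250 ms). The same expiry-based caching pattern, isolated from tuning, looks like the sketch below (illustrative only; CachedValue is a hypothetical class, reusing only the ruy/time.h helpers shown earlier):

#include <cstdio>

#include "ruy/time.h"

// Caches the result of an expensive evaluation for a fixed duration.
class CachedValue {
 public:
  explicit CachedValue(ruy::Duration expiry) : expiry_(expiry) {}

  int Get() {
    const ruy::TimePoint now = ruy::CoarseNow();
    if (has_value_ && (now - last_eval_) < expiry_) {
      return value_;  // Still fresh: reuse the cached result.
    }
    last_eval_ = now;
    value_ = ExpensiveEval();
    has_value_ = true;
    return value_;
  }

 private:
  int ExpensiveEval() {
    std::puts("(re)evaluating...");
    return 42;  // Placeholder for e.g. a nano-benchmark result.
  }
  const ruy::Duration expiry_;
  ruy::TimePoint last_eval_;
  int value_ = 0;
  bool has_value_ = false;
};

int main() {
  CachedValue cached(ruy::DurationFromMilliseconds(250));
  cached.Get();  // Evaluates.
  cached.Get();  // Served from the cache within the 250 ms window.
  return 0;
}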
+// +// An interesting alternative, which was explained to us by Marat Dukhan +// (maratek@) after this was implemented, would be to use the +// getcpu(2) system call on Linux. This returns a +// numeric CPU identifier that could be mapped to a OutOfOrder/InOrder +// classification given additional information about the CPU. Such +// additional information could be obtained by the cpuinfo library, +// https://github.com/pytorch/cpuinfo +// which obtains this information mainly from parsing /proc/cpuinfo. +// Pros: +// * Would remove the need for the relatively expensive nano-benchmark +// (dozens of microseconds, which have to be reevaluated again several +// times per second). +// * Would conceivably be more reliable. +// Cons: +// * Linux-specific. +// * Modest binary size increase (Marat mentioned the cpuinfo lib is 20k). +// * Won't support exactly 100% of devices (nonstandard /proc/cpuinfo etc). +// +// We could also have both: +// * Maybe by trying getcpu first if supported, then falling back to a +// nano-benchmark. +// * Maybe using getcpu in conjunction with the nano-benchmark to cache +// per-CPU-id nano-benchmark results. +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_ + +#include "ruy/opt_set.h" +#include "ruy/platform.h" +#include "ruy/time.h" + +// Tuning only implemented on NEON_64 at the moment (see assembly code +// in the nano-benchmark) and not on Apple (some Apple CPUs produce incorrect +// results on in-order-tuned kernels combining ARM and NEON load instructions +// and NEON `ins` instructions). +// +// When tuning is not implemented, we simply always use Tuning::kOutOfOrder. +#if RUY_OPT_ENABLED(RUY_OPT_TUNING) && RUY_PLATFORM(NEON_64) && \ + !RUY_PLATFORM(APPLE) +#define RUY_IMPLEMENT_TUNING +#endif + +namespace ruy { + +enum class Tuning { + // kAuto means please use auto-detection. It's the default in the + // user-visible parts (see Context). It's meant to be resolved to an + // actual tuning at some point by means of TuningResolver. + kAuto, + // Target an out-order CPU. Example: ARM Cortex-A75. + kOutOfOrder, + // Target an in-order CPU. Example: ARM Cortex-A55. + kInOrder +}; + +// Why a TuningResolver class? +// +// Ideally, this Library would offer a single function, +// Tuning GetCurrentCPUTuning(); +// +// However, determining information about the current CPU is not necessarily, +// cheap, so we currently cache that and only invalidate/reevaluate after +// a fixed amount of time. This need to store state is why this library +// has to expose a class, TuningResolver, not just a function. +class TuningResolver { + public: + TuningResolver(); + + // Allows the user to specify an explicit Tuning value, bypassing auto + // detection; or to specify Tuning::kAuto, reverting to auto detection. + void SetTuning(Tuning tuning) { unresolved_tuning_ = tuning; } + + // Get an actual tuning --- that is the function that this class wanted to be. + Tuning Resolve(); + + private: + TuningResolver(const TuningResolver&) = delete; + + // TuningTool is a demo/tool used to tweak the tuning implementation to + // specific devices. It needs to access some finer granularity information + // than just the Tuning returned by Resolve. Nothing else should need + // access to that. + friend class TuneTool; + // Actually runs a nano-benchmark, producing a real number called 'ratio' + // whose meaning is generally opaque / implementation defined. 
Typically, + // this would be the ratio between the latencies of two different + // pieces of asm code differing only by the ordering of instructions, + // revealing whether the CPU cares about such ordering details. + // An implementation may just return a dummy value if it is not based on + // such nanobenchmarking / ratio evaluation. + float EvalRatio(); + // Empirically determined threshold on ratio values delineating + // out-of-order (ratios closer to 1) from in-order (ratios farther from 1). + // An implementation may just return a dummy value if it is not based on + // such nanobenchmarking / ratio evaluation. + float ThresholdRatio(); + // Perform the tuning resolution now. That may typically use EvalRatio and + // ThresholdRatio, but an implementation may use a different approach instead. + Tuning ResolveNow(); + + // The tuning as specified by the user, before actual resolution happens + // i.e. before querying any specifics of the current CPU. + // The default value kAuto means try to auto-detect. Other values mean + // bypass auto-detect, use explicit value instead. See SetTuning(). + Tuning unresolved_tuning_ = Tuning::kAuto; + // Cached last resolved tuning. + Tuning last_resolved_tuning_ = Tuning::kAuto; + // Timepoint of cached last resolved tuning, for invalidation purposes. + TimePoint last_resolved_timepoint_; + // Cached last resolved tunings that are older than this age are invalid. + const Duration expiry_duration_; +}; + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_TUNE_H_ diff --git a/ruy/tune_test.cc b/ruy/tune_test.cc new file mode 100644 index 0000000..ebd86e0 --- /dev/null +++ b/ruy/tune_test.cc @@ -0,0 +1,53 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/tune.h" + +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) + +#include "testing/base/public/gunit.h" + +namespace ruy { +namespace { + +TEST(TuneTest, TuneTest) { + TuningResolver tuning_resolver; + ASSERT_FALSE(tuning_resolver.Resolve() == Tuning::kAuto); + // 1 second is likely higher than TuningResolver's internal cache expiry, + // exercising the logic invalidating earlier tuning resolutions. + std::this_thread::sleep_for(std::chrono::seconds(1)); + ASSERT_FALSE(tuning_resolver.Resolve() == Tuning::kAuto); + + tuning_resolver.SetTuning(Tuning::kAuto); + +#ifdef RUY_IMPLEMENT_TUNING + for (auto tuning : {Tuning::kOutOfOrder, Tuning::kInOrder}) { + tuning_resolver.SetTuning(tuning); + ASSERT_TRUE(tuning_resolver.Resolve() == tuning); + // See above comment about 1 second. 
+ std::this_thread::sleep_for(std::chrono::seconds(1)); + ASSERT_TRUE(tuning_resolver.Resolve() == tuning); + } +#endif +} + +} // namespace +} // namespace ruy + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/ruy/tune_tool.cc b/ruy/tune_tool.cc new file mode 100644 index 0000000..0b6e4ab --- /dev/null +++ b/ruy/tune_tool.cc @@ -0,0 +1,56 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Self-contained tool used to tune the tune code --- see the +// threshold ratios used in tune.cc. + +#include // NOLINT(build/c++11) +#include +#include // NOLINT(build/c++11) + +#include "ruy/tune.h" + +#ifdef _WIN32 +#define getpid() 0 +#else +#include +#endif + +namespace ruy { + +class TuneTool { + public: + static void Query(float* eval, float* threshold) { + TuningResolver resolver; + *eval = resolver.EvalRatio(); + *threshold = resolver.ThresholdRatio(); + } +}; + +} // namespace ruy + +int main() { + // Infinite loop: the user can hit Ctrl-C + while (true) { + float eval; + float threshold; + ruy::TuneTool::Query(&eval, &threshold); + printf("[%d] eval=%.3f %c threshold=%.3f ==> probably %s...\n", getpid(), + eval, eval < threshold ? '<' : '>', threshold, + eval < threshold ? "in-order" : "out-of-order"); + fflush(stdout); + std::this_thread::sleep_for(std::chrono::seconds(1)); + } +} diff --git a/ruy/wait.cc b/ruy/wait.cc new file mode 100644 index 0000000..d8156bc --- /dev/null +++ b/ruy/wait.cc @@ -0,0 +1,69 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ruy/wait.h" + +#include // NOLINT(build/c++11) + +namespace ruy { + +void Wait(const std::function& condition, const Duration& spin_duration, + std::condition_variable* condvar, std::mutex* mutex) { + // First, trivial case where the `condition` is already true; + if (condition()) { + return; + } + + // Then try busy-waiting. + const TimePoint wait_start = Now(); + while (Now() - wait_start < spin_duration) { + if (condition()) { + return; + } + } + + // Finally, do real passive waiting. 
+ std::unique_lock lock(*mutex); + condvar->wait(lock, condition); +} + +void Wait(const std::function& condition, + std::condition_variable* condvar, std::mutex* mutex) { + // This value was empirically derived with some microbenchmark, we don't have + // high confidence in it. + // + // TODO(b/135595069): make this value configurable at runtime. + // I almost wanted to file another bug to ask for experimenting in a more + // principled way to tune this value better, but this would have to be tuned + // on real end-to-end applications and we'd expect different applications to + // require different tunings. So the more important point is the need for + // this to be controllable by the application. + // + // That this value means that we may be sleeping substantially longer + // than a scheduler timeslice's duration is not necessarily surprising. The + // idea is to pick up quickly new work after having finished the previous + // workload. When it's new work within the same GEMM as the previous work, the + // time interval that we might be busy-waiting is very small, so for that + // purpose it would be more than enough to sleep for 1 ms. + // That is all what we would observe on a GEMM benchmark. However, in a real + // application, after having finished a GEMM, we might do unrelated work for + // a little while, then start on a new GEMM. In that case the wait interval + // may be a little longer. There may also not be another GEMM for a long time, + // in which case we'll end up passively waiting below. + const Duration spin_duration = DurationFromMilliseconds(2); + Wait(condition, spin_duration, condvar, mutex); +} + +} // namespace ruy diff --git a/ruy/wait.h b/ruy/wait.h new file mode 100644 index 0000000..900ec8d --- /dev/null +++ b/ruy/wait.h @@ -0,0 +1,73 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_WAIT_H_ + +#include // NOLINT(build/c++11) +#include +#include // NOLINT(build/c++11) + +#include "ruy/time.h" + +namespace ruy { + +// Waits until some evaluation of `condition` has returned true. +// +// There is no guarantee that calling `condition` again after this function +// has returned would still return true. The only +// contract is that at some point during the execution of that function, +// `condition` has returned true. +// +// First does some spin-waiting for the specified `spin_duration`, +// then falls back to passive waiting for the given condvar, guarded +// by the given mutex. At this point it will try to acquire the mutex lock, +// around the waiting on the condition variable. +// Therefore, this function expects that the calling thread hasn't already +// locked the mutex before calling it. +// This function will always release the mutex lock before returning. 
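A small usage sketch for the Wait contract spelled out above (illustrative, not part of the patch): the waiter supplies a condition closure capturing an atomic flag, plus the condvar/mutex pair, while the notifying thread sets the flag and then notifies under the mutex.

#include <atomic>
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>

#include "ruy/time.h"
#include "ruy/wait.h"

int main() {
  std::atomic<bool> done(false);
  std::condition_variable condvar;
  std::mutex mutex;

  std::thread worker([&] {
    // ... some work would happen here ...
    done.store(true);
    // Notify under the mutex so the waiter cannot miss the wake-up between
    // checking the condition and blocking on the condition variable.
    std::lock_guard<std::mutex> lock(mutex);
    condvar.notify_all();
  });

  // Spin for up to ~1 ms, then fall back to blocking on the condvar.
  const auto condition = [&done]() { return done.load(); };
  ruy::Wait(condition, ruy::DurationFromMilliseconds(1), &condvar, &mutex);
  std::printf("worker is done\n");
  worker.join();
  return 0;
}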
diff --git a/ruy/wait_test.cc b/ruy/wait_test.cc
new file mode 100644
index 0000000..f0548f9
--- /dev/null
+++ b/ruy/wait_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/wait.h"
+
+#include <atomic>
+#include <condition_variable>  // NOLINT(build/c++11)
+#include <mutex>  // NOLINT(build/c++11)
+#include <thread>  // NOLINT(build/c++11)
+
+#include "testing/base/public/gunit.h"
+#include "ruy/platform.h"
+
+namespace ruy {
+namespace {
+
+// Thread taking a `value` atomic counter and incrementing it until it equals
+// `end_value`, then notifying the condition variable as long as
+// `value == end_value`. If `end_value` is increased, it will then resume
+// incrementing `value`, etc. Terminates if `end_value == -1`.
+class ThreadCountingUpToValue {
+ public:
+  ThreadCountingUpToValue(const std::atomic<int>& end_value,
+                          std::atomic<int>* value,
+                          std::condition_variable* condvar, std::mutex* mutex)
+      : end_value_(end_value),
+        value_(value),
+        condvar_(condvar),
+        mutex_(mutex) {}
+  void operator()() {
+    // end_value_==-1 is how the master thread will tell us it's OK to terminate
+    while (end_value_.load() != -1) {
+      // wait until end_value is set to a higher value
+      while (value_->load() == end_value_.load()) {
+      }
+      // increment value as long as it's lower than end_value
+      while (value_->fetch_add(1) < end_value_.load() - 1) {
+      }
+      // when value has reached end_value, notify the master thread.
+      while (value_->load() == end_value_.load()) {
+        std::lock_guard<std::mutex> lock(*mutex_);
+        condvar_->notify_all();
+      }
+    }
+  }
+
+ private:
+  const std::atomic<int>& end_value_;
+  std::atomic<int>* value_;
+  std::condition_variable* condvar_;
+  std::mutex* mutex_;
+};
+
+void WaitTest(const Duration& spin_duration, const Duration& delay) {
+#if RUY_PLATFORM(EMSCRIPTEN)
+  // b/139927184, std::thread constructor raises exception
+  return;
+#endif
+  std::condition_variable condvar;
+  std::mutex mutex;
+  std::atomic<int> value(0);
+  std::atomic<int> end_value(0);
+  ThreadCountingUpToValue thread_callable(end_value, &value, &condvar, &mutex);
+  std::thread thread(thread_callable);
+  std::this_thread::sleep_for(delay);
+  for (int i = 1; i < 10; i++) {
+    end_value.store(1000 * i);
+    const auto& condition = [&value, &end_value]() {
+      return value.load() == end_value.load();
+    };
+    ruy::Wait(condition, spin_duration, &condvar, &mutex);
+    EXPECT_EQ(value.load(), end_value.load());
+  }
+  end_value.store(-1);
+  thread.join();
+}
+
+TEST(WaitTest, WaitTestNoSpin) {
+  WaitTest(DurationFromSeconds(0), DurationFromSeconds(0));
+}
+
+TEST(WaitTest, WaitTestSpinOneMicrosecond) {
+  WaitTest(DurationFromSeconds(1e-6), DurationFromSeconds(0));
+}
+
+TEST(WaitTest, WaitTestSpinOneMillisecond) {
+  WaitTest(DurationFromSeconds(1e-3), DurationFromSeconds(0));
+}
+
+TEST(WaitTest, WaitTestSpinOneSecond) {
+  WaitTest(DurationFromSeconds(1), DurationFromSeconds(0));
+}
+
+// Testcase to consistently reproduce the hang in b/139062384.
+TEST(WaitTest, WaitTestNoSpinWithDelayBug139062384) {
+  WaitTest(DurationFromSeconds(0), DurationFromSeconds(1));
+}
+
+}  // namespace
+}  // namespace ruy
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
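The WaitTest helper above parametrizes both the spin duration and an initial delay, so extra coverage composes directly. A hypothetical additional case (not part of this change) that combines a spin duration with the delayed start used by the regression test would look like this, assuming the surrounding file's helpers:

// Hypothetical extra case, reusing the WaitTest helper defined above:
// spin for 1 ms and also delay the waiter by 1 second before the handshake.
TEST(WaitTest, WaitTestSpinOneMillisecondWithDelay) {
  WaitTest(DurationFromSeconds(1e-3), DurationFromSeconds(1));
}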
diff --git a/ruy_advanced.h b/ruy_advanced.h
deleted file mode 100644
index 333e173..0000000
--- a/ruy_advanced.h
+++ /dev/null
@@ -1,69 +0,0 @@
-/* Copyright 2019 Google LLC. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ADVANCED_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ADVANCED_H_
-
-#include <cstddef>
-#include <functional>
-
-#include "context.h"
-#include "matrix.h"
-#include "path.h"
-#include "prepack.h"
-#include "side_pair.h"
-
-namespace ruy {
-
-// Low-level, explicit pre-packing API.
-//
-// The cost of packing an input matrix (either the LHS or RHS) is amortized
-// across the non-depth dimension of the opposite input matrix. Thus, when the
-// LHS has very few rows or the RHS has very few columns, the cost of packing
-// the opposite input matrix can become significant. See pack.h for further
-// information on packing.
-//
-// This file provides an API allowing a user to explicitly pack a matrix and
-// reuse the pre-packed matrix, avoiding that cost.
-//
-// See example_prepack.cc for example usage.
-
-template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
-          typename DstScalar, typename Spec>
-void PrePackForMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
-                   const Spec& spec, Context* context, Matrix<DstScalar>* dst,
-                   PrepackedMatrix* prepacked_lhs,
-                   PrepackedMatrix* prepacked_rhs,
-                   std::function<void*(std::size_t)> alloc_fn) {
-  SidePair<PrepackedMatrix*> prepacked(prepacked_lhs, prepacked_rhs);
-  PrePackForMulInternal<CompiledPaths>(lhs, rhs, spec, context, dst, prepacked,
-                                       alloc_fn);
-}
-
-template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
-          typename DstScalar, typename Spec>
-void MulWithPrepacked(const Matrix<LhsScalar>& lhs,
-                      const Matrix<RhsScalar>& rhs, const Spec& spec,
-                      Context* context, Matrix<DstScalar>* dst,
-                      PrepackedMatrix* prepacked_lhs,
-                      PrepackedMatrix* prepacked_rhs) {
-  SidePair<PrepackedMatrix*> prepacked(prepacked_lhs, prepacked_rhs);
-  MulWithPrepackedInternal<CompiledPaths>(lhs, rhs, spec, context, dst,
-                                          prepacked);
-}
-
-}  // namespace ruy
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_ADVANCED_H_
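This header is only relocated by the patch (the same declarations now live in ruy/ruy_advanced.h), so a usage sketch of the pre-packing API may help. It is illustrative only: the function name, the float/BasicSpec choices, and the storage/alloc_fn scheme are assumptions, and it assumes that passing nullptr for the side that is not pre-packed is supported; the maintained example lives in ruy/example_advanced.cc in this patch.

#include <cstddef>
#include <memory>
#include <vector>

#include "ruy/context.h"
#include "ruy/matrix.h"
#include "ruy/path.h"
#include "ruy/ruy_advanced.h"
#include "ruy/spec.h"

// Packs `lhs` once, then reuses the packed representation against several
// right-hand sides. The packed buffers are owned by `storage` via `alloc_fn`.
void MulManyRhsWithPrepackedLhs(const ruy::Matrix<float>& lhs,
                                const std::vector<ruy::Matrix<float>>& rhss,
                                ruy::Context* context,
                                std::vector<ruy::Matrix<float>>* dsts) {
  ruy::BasicSpec<float, float> spec;
  std::vector<std::unique_ptr<char[]>> storage;
  auto alloc_fn = [&storage](std::size_t num_bytes) -> void* {
    storage.emplace_back(new char[num_bytes]);
    return storage.back().get();
  };
  ruy::PrepackedMatrix prepacked_lhs;
  // Pack the LHS once; the RHS side stays unpacked (assumed nullptr is OK).
  ruy::PrePackForMul<ruy::kAllPaths>(lhs, rhss[0], spec, context, &(*dsts)[0],
                                     &prepacked_lhs, /*prepacked_rhs=*/nullptr,
                                     alloc_fn);
  // Reuse the packed LHS for every multiplication.
  for (std::size_t i = 0; i < rhss.size(); ++i) {
    ruy::MulWithPrepacked<ruy::kAllPaths>(lhs, rhss[i], spec, context,
                                          &(*dsts)[i], &prepacked_lhs,
                                          /*prepacked_rhs=*/nullptr);
  }
}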
diff --git a/ruy_test.bzl b/ruy_test.bzl
deleted file mode 100644
index ef7e8b1..0000000
--- a/ruy_test.bzl
+++ /dev/null
@@ -1,34 +0,0 @@
-# Provides the ruy_test macro for type-parametrized tests.
-"""ruy_test is a macro for building a test with multiple paths corresponding to tuples of types for LHS, RHS, accumulator and destination."""
-
-def ruy_test(name, srcs, lhs_rhs_accum_dst, copts, tags = [], deps = None):
-    for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst:
-        native.cc_test(
-            name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst),
-            srcs = srcs,
-            copts = copts + [
-                "-DRUY_TEST_LHSSCALAR=%s" % lhs,
-                "-DRUY_TEST_RHSSCALAR=%s" % rhs,
-                "-DRUY_TEST_ACCUMSCALAR=%s" % accum,
-                "-DRUY_TEST_DSTSCALAR=%s" % dst,
-            ],
-            deps = deps,
-            tags = tags,
-        )
-
-def ruy_benchmark(name, srcs, lhs_rhs_accum_dst, copts, deps = None):
-    tags = ["req_dep=//third_party/gemmlowp:profiler"]
-    for (lhs, rhs, accum, dst) in lhs_rhs_accum_dst:
-        native.cc_binary(
-            name = "%s_%s_%s_%s_%s" % (name, lhs, rhs, accum, dst),
-            testonly = True,
-            srcs = srcs,
-            copts = copts + [
-                "-DRUY_TEST_LHSSCALAR=%s" % lhs,
-                "-DRUY_TEST_RHSSCALAR=%s" % rhs,
-                "-DRUY_TEST_ACCUMSCALAR=%s" % accum,
-                "-DRUY_TEST_DSTSCALAR=%s" % dst,
-            ],
-            deps = deps,
-            tags = tags,
-        )
diff --git a/ruy_test_ext.bzl b/ruy_test_ext.bzl
deleted file mode 100644
index 5701fff..0000000
--- a/ruy_test_ext.bzl
+++ /dev/null
@@ -1,7 +0,0 @@
-"""Allows to specialize the ruy BUILD to availability of external libraries"""
-
-def ruy_test_ext_defines():
-    return []
-
-def ruy_test_ext_deps():
-    return []
diff --git a/side_pair.h b/side_pair.h
deleted file mode 100644
index 2951760..0000000
--- a/side_pair.h
+++ /dev/null
@@ -1,64 +0,0 @@
-/* Copyright 2019 Google LLC. All Rights Reserved.
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_SIDE_PAIR_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_SIDE_PAIR_H_ - -#include "check_macros.h" - -namespace ruy { - -// Enumeration of the sides, i.e. the operands 'slots', in a matrix -// multiplication. The numerical values of these enumeration constants matter -// because these will be used as indices into the array underlying a SidePair. -enum class Side { - // Left-hand side - kLhs = 0, - // Right-hand side - kRhs = 1 -}; - -// SidePair is a pair container where the two elements are indexed by a Side -// enum. -template -class SidePair final { - public: - SidePair() {} - SidePair(const T& a, const T& b) : elem_{a, b} {} - const T& operator[](Side side) const { - const int index = static_cast(side); - // Technically this check is vacuous, since other values would be - // out-of-range for enum Side. - RUY_DCHECK(index == 0 || index == 1); - return elem_[index]; - } - - T& operator[](Side side) { - const int index = static_cast(side); - // Technically this check is vacuous, since other values would be - // out-of-range for enum Side. - RUY_DCHECK(index == 0 || index == 1); - return elem_[index]; - } - - private: - static_assert(static_cast(Side::kLhs) == 0, ""); - static_assert(static_cast(Side::kRhs) == 1, ""); - T elem_[2]; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_SIDE_PAIR_H_ diff --git a/size_util.h b/size_util.h deleted file mode 100644 index e459c22..0000000 --- a/size_util.h +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_SIZE_UTIL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_SIZE_UTIL_H_ - -#include - -#include "check_macros.h" - -#ifdef _WIN32 -#include -#endif - -namespace ruy { - -template -inline Integer floor_log2(Integer n) { - static_assert(std::is_integral::value, ""); - static_assert(std::is_signed::value, ""); - static_assert(sizeof(Integer) == 4 || sizeof(Integer) == 8, ""); - - RUY_DCHECK_GE(n, 1); -#ifdef _WIN32 - unsigned long result; // NOLINT[runtime/int] - if (sizeof(Integer) == 4) { - _BitScanReverse(&result, n); - } else { - _BitScanReverse64(&result, n); - } - return result; -#else - if (sizeof(Integer) == 4) { - return 31 - __builtin_clz(n); - } else { - return 63 - __builtin_clzll(n); - } -#endif -} - -template -Integer ceil_log2(Integer n) { - RUY_DCHECK_GE(n, 1); - return n == 1 ? 0 : floor_log2(n - 1) + 1; -} - -template -bool is_pot(Integer value) { - return (value > 0) && ((value & (value - 1)) == 0); -} - -template -Integer pot_log2(Integer n) { - RUY_DCHECK(is_pot(n)); - return floor_log2(n); -} - -template -Integer round_down_pot(Integer value) { - return static_cast(1) << floor_log2(value); -} - -template -Integer round_up_pot(Integer value) { - return static_cast(1) << ceil_log2(value); -} - -template -Integer round_down_pot(Integer value, Modulo modulo) { - RUY_DCHECK_EQ(modulo & (modulo - 1), 0); - return value & ~(modulo - 1); -} - -template -Integer round_up_pot(Integer value, Modulo modulo) { - return round_down_pot(value + modulo - 1, modulo); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_SIZE_UTIL_H_ diff --git a/size_util_test.cc b/size_util_test.cc deleted file mode 100644 index 393f21e..0000000 --- a/size_util_test.cc +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "size_util.h" - -#include -#include -#include - -#include "testing/base/public/gunit.h" - -namespace ruy { -namespace { - -template -void SizeUtilTestValue(Integer value) { - if (value == 0) { - return; - } - - EXPECT_LE(0, floor_log2(value)); - EXPECT_LE(floor_log2(value), ceil_log2(value)); - EXPECT_LE(ceil_log2(value), 8 * sizeof(Integer)); - - if (is_pot(value)) { - EXPECT_EQ(floor_log2(value), ceil_log2(value)); - EXPECT_EQ(floor_log2(value), pot_log2(value)); - } else { - EXPECT_EQ(floor_log2(value) + 1, ceil_log2(value)); - } - EXPECT_EQ(value >> floor_log2(value), 1); - EXPECT_EQ(round_down_pot(value), static_cast(1) - << floor_log2(value)); - EXPECT_LE(round_down_pot(value), value); - EXPECT_GE(round_down_pot(value), value >> 1); - EXPECT_TRUE(is_pot(round_down_pot(value))); - - if (ceil_log2(value) < 8 * sizeof(Integer) - 1) { - EXPECT_EQ(value >> ceil_log2(value), is_pot(value) ? 
1 : 0); - EXPECT_EQ(round_up_pot(value), static_cast(1) << ceil_log2(value)); - EXPECT_GE(round_up_pot(value), value); - EXPECT_LE(round_up_pot(value) >> 1, value); - EXPECT_TRUE(is_pot(round_up_pot(value))); - } - - for (std::uint8_t modulo : {1, 2, 8, 32, 128}) { - EXPECT_GE(value, round_down_pot(value, modulo)); - EXPECT_EQ(round_down_pot(value, modulo) % modulo, 0); - - if (value <= std::numeric_limits::max() - modulo) { - EXPECT_LE(value, round_up_pot(value, modulo)); - EXPECT_EQ(round_up_pot(value, modulo) % modulo, 0); - } - } -} - -template -void SizeUtilTest() { - for (int exponent = 0; exponent < 8 * sizeof(Integer) - 1; exponent++) { - const Integer pot = static_cast(1) << exponent; - SizeUtilTestValue(pot - 1); - SizeUtilTestValue(pot); - SizeUtilTestValue(pot + 1); - SizeUtilTestValue(pot + 12); - SizeUtilTestValue(pot + 123); - } - SizeUtilTestValue(std::numeric_limits::max() - 1); - SizeUtilTestValue(std::numeric_limits::max()); -} - -TEST(SizeUtilTest, Int) { SizeUtilTest(); } - -TEST(SizeUtilTest, Long) { SizeUtilTest(); } // NOLINT - -TEST(SizeUtilTest, LongLong) { SizeUtilTest(); } // NOLINT - -TEST(SizeUtilTest, Int32) { SizeUtilTest(); } - -TEST(SizeUtilTest, Int64) { SizeUtilTest(); } - -TEST(SizeUtilTest, Ptrdiff) { SizeUtilTest(); } - -} // namespace -} // namespace ruy - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/spec.h b/spec.h deleted file mode 100644 index 178ff20..0000000 --- a/spec.h +++ /dev/null @@ -1,118 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_SPEC_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_SPEC_H_ - -#include -#include - -#include "cpu_cache_size.h" -#include "matrix.h" - -namespace ruy { - -// Our 'general' loop structure (the default) involves multi-threading and -// complicated loops aiming to optimize cache-friendliness. One may opt out of -// this and pick the 'simple' loop structure instead, which only performs well -// for small matrix sizes and only allows using one thread, in exchange for -// smaller code size. -enum class LoopStructure { kGeneral, kSimple, kAuto }; - -// In general we allow zero_point's to have any Scalar value. This is called -// 'asymmetric' quantization. We do take advantage of the optimization -// opportunities when zero_points happen at runtime to be 'symmetric' (e.g. the -// int8 value 0 or the uint8 value 128), but we still generate code to handle -// the general asymmetric case. By choosing kSymmetric here, one opts out of -// this and supports only the symmetric case, in exchange for smaller code size. -enum class ZeroPointSupport { kGeneral, kSymmetric }; - -// In general we allow all Layout's, even if we may use slow paths for some -// kinds of layouts. 
By choosing kRCC, one may opt out of this and -// only keep support for the simplest and most efficient combination of -// Layout's, in exchange for smaller code size. The case covered by -// kRCC is where the storage orders are exactly the following: -// - LHS is RowMajor -// - RHS is ColMajor -// - Destination is ColMajor -enum class LayoutSupport { kGeneral, kRCC }; - -// A Spec describes all about a matrix multiplication operation that isn't -// encoded in the LHS, RHS and destination matrices. Some of that information -// is encoded as compile-time constants and types (for instance, the choice -// of accumulator type, AccumScalar). Some of that information is encoded as -// runtime values (for instance, the optional bias vector). -template -struct BasicSpec { - // Accumulator type. The type of accumulators used to compute the dot-products - // before being ultimately casted to the destination type. - using AccumScalar = tAccumScalar; - // The destination scalar type. - using DstScalar = tDstScalar; - // The bias vector data, if not null. - const AccumScalar* bias = nullptr; - // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa) - // of the multiplier by which accumulators are multiplied before being casted - // to the destination type. - AccumScalar multiplier_fixedpoint = 0; - // Only for non-floating-point cases. The exponent part of the aforementioned - // multiplier. - int multiplier_exponent = 0; - // Per-channel variant of multiplier_fixedpoint. If not nullptr, this must - // point to a buffer of as many values as there are rows in the destination - // matrix. Each row of the destination matrix will use the corresponding - // buffer element instead of multiplier_fixedpoint. - const AccumScalar* multiplier_fixedpoint_perchannel = nullptr; - // Per-channel variant of multiplier_exponent. If not nullptr, this must - // point to a buffer of as many values as there are rows in the destination - // matrix. Each row of the destination matrix will use the corresponding - // buffer element instead of multiplier_exponent. - // - // Either none or both of multiplier_exponent_perchannel and - // multiplier_fixedpoint_perchannel must be nullptr. - const int* multiplier_exponent_perchannel = nullptr; - // min clamp bound of destination values. - DstScalar clamp_min = std::is_floating_point::value - ? -std::numeric_limits::infinity() - : std::numeric_limits::lowest(); - // max clamp bound of destination values. - DstScalar clamp_max = std::is_floating_point::value - ? std::numeric_limits::infinity() - : std::numeric_limits::max(); - // See above enum LoopStructure - static constexpr LoopStructure kLoopStructure = LoopStructure::kAuto; - // See above enum LayoutSupport - static constexpr LayoutSupport kLayoutSupport = LayoutSupport::kGeneral; - // See above enum ZeroPointSupport - static constexpr ZeroPointSupport kZeroPointSupport = - ZeroPointSupport::kGeneral; - // Testing-only, not meant to be used by actual users: - // Used for testing of various kernel layouts. - using StandardCppKernelLhsLayout = FixedKernelLayout; - using StandardCppKernelRhsLayout = FixedKernelLayout; - // Returns (a reasonable estimate of) the local CPU cache size. - // See ruy::LocalDataCacheSize() which returns some coarse, sane default for - // each CPU architecture. - // This may be overridden, either to provide more accurate/runtime values, - // or to test with other values to let testcases have more coverage. 
- static int local_data_cache_size() { return LocalDataCacheSize(); } - // Same as local_data_cache_size but for the total data cache size accessible - // to each CPU core. See ruy::SharedDataCacheSize(). - static int shared_data_cache_size() { return SharedDataCacheSize(); } -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_SPEC_H_ diff --git a/test.h b/test.h deleted file mode 100644 index 8c93a56..0000000 --- a/test.h +++ /dev/null @@ -1,2125 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TEST_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TEST_H_ - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "testing/base/public/gunit.h" // IWYU pragma: export -#include "matrix.h" // IWYU pragma: export -#include "platform.h" -#include "pmu.h" -#include "ruy.h" -#include "ruy_advanced.h" -#include "spec.h" // IWYU pragma: export -#include "time.h" - -#ifdef RUY_TEST_EXTERNAL_PATHS -#define EIGEN_USE_THREADS -#define EIGEN_USE_CUSTOM_THREAD_POOL -#include "third_party/eigen3/Eigen/Core" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "third_party/gemmlowp/public/gemmlowp.h" -#include "third_party/lapack/blas.h" -#endif - -#ifdef RUY_PROFILER -#include "profiler/profiler.h" -#endif - -namespace ruy { - -const float kClampRatio = 0.1f; - -enum class ExternalPath { kNone, kGemmlowp, kEigen, kEigenTensor, kOpenBlas }; - -inline std::vector* CoveredPaths() { - static std::vector covered_paths; - return &covered_paths; -} - -inline const char* PathName(Path path) { -#define RUY_PATHNAME_CASE(NAME) \ - case Path::NAME: \ - return #NAME; - switch (path) { - RUY_PATHNAME_CASE(kReference) - RUY_PATHNAME_CASE(kStandardCpp) -#if RUY_PLATFORM(NEON) - RUY_PATHNAME_CASE(kNeon) - RUY_PATHNAME_CASE(kNeonDotprod) -#elif RUY_PLATFORM(X86) - RUY_PATHNAME_CASE(kSse42) - RUY_PATHNAME_CASE(kAvx2) - RUY_PATHNAME_CASE(kAvx512) - RUY_PATHNAME_CASE(kAvxVnni) -#endif - default: - RUY_CHECK(false); - return nullptr; - } -#undef RUY_PATHNAME_CASE -} - -inline const char* TuningName(Tuning tuning) { -#define RUY_SUBPATHNAME_CASE(NAME) \ - case Tuning::NAME: \ - return #NAME; - switch (tuning) { - RUY_SUBPATHNAME_CASE(kInOrder) - RUY_SUBPATHNAME_CASE(kOutOfOrder) - default: - RUY_CHECK(false); - return nullptr; - } -#undef RUY_SUBPATHNAME_CASE -} - -inline const char* PathName(ExternalPath path) { -#define RUY_PATHNAME_CASE(NAME) \ - case ExternalPath::NAME: \ - return #NAME; - switch (path) { - RUY_PATHNAME_CASE(kGemmlowp) - RUY_PATHNAME_CASE(kEigen) - RUY_PATHNAME_CASE(kEigenTensor) - RUY_PATHNAME_CASE(kOpenBlas) - default: - RUY_CHECK(false); - return nullptr; - } -#undef RUY_PATHNAME_CASE -} - -inline std::ostream& operator<<(std::ostream& stream, Path path) { - return stream << 
PathName(path); -} - -inline std::ostream& operator<<(std::ostream& stream, - ExternalPath external_path) { - return stream << PathName(external_path); -} - -template -std::string Join(const ContainerType& container) { - if (container.empty()) { - return ""; - } - std::ostringstream stream; - auto it = container.begin(); - stream << *it++; - for (; it != container.end(); ++it) { - stream << ", "; - stream << *it; - } - return stream.str(); -} - -struct LogCoveredPathsOnDestruction final { - ~LogCoveredPathsOnDestruction() { - std::cerr << "Covered paths: " << Join(*CoveredPaths()) << std::endl; - - // When testing on ARM64 ChromiumOS emulator, make sure that we covered - // the dotprod path. We're getting such coverage at the moment thanks to - // using a sufficiently recent emulator, and we don't want to regress that. -#if RUY_PLATFORM(ARM_64) && defined RUY_TESTING_ON_CHROMIUMOS - bool found_dotprod = false; - for (const std::string& covered_path : *CoveredPaths()) { - if (covered_path == "kNeonDotprod") { - found_dotprod = true; - } - } - if (!found_dotprod) { - std::cerr - << "Error: we haven't tested the kNeonDotprod path as we should " - "have. At the moment, this is required on ChromiumOS as this is " - "what we run emulator tests in, that currently supports " - "dot-product " - "instructions, and we care very much about not regressing that. " - "If this test was run in an emulator, please upgrade to a newer " - "emulator version. If this test was run on an actual device, and " - "you need to be able to run ruy tests on devices not supporting " - "dot-product instructions, get in touch with us.\n" - << std::endl; - abort(); - } -#endif - } - static void Singleton() { static LogCoveredPathsOnDestruction singleton; } -}; - -enum class RandomRange { - kGeneral, - kAvoidMinValue, - kOffCenterAvoidMinValue, - kReasonableSrcZeroPoint, - kReasonableDstZeroPoint, - kBias -}; - -template ::value> -struct RandomRangeBounds {}; - -template -struct RandomRangeBounds { - static Scalar GetMinBound(RandomRange range) { - switch (range) { - case RandomRange::kGeneral: - return -1; - case RandomRange::kAvoidMinValue: - return -1; - case RandomRange::kOffCenterAvoidMinValue: - return -1; - case RandomRange::kReasonableSrcZeroPoint: - return 0; - case RandomRange::kReasonableDstZeroPoint: - return 0; - case RandomRange::kBias: - return -1; - default: - RUY_CHECK(false); - return 0; - } - } - static Scalar GetMaxBound(RandomRange range) { - switch (range) { - case RandomRange::kGeneral: - return 1; - case RandomRange::kAvoidMinValue: - return 1; - case RandomRange::kOffCenterAvoidMinValue: - return 1; - case RandomRange::kReasonableSrcZeroPoint: - return 0; - case RandomRange::kReasonableDstZeroPoint: - return 0; - case RandomRange::kBias: - return 1; - default: - RUY_CHECK(false); - return 0; - } - } -}; - -template -Scalar WeightedSum(Scalar s1, float weight1, Scalar s2, float weight2) { - float sum = s1 * weight1 + s2 * weight2; - float clamped = std::min( - std::numeric_limits::max(), - std::max(std::numeric_limits::lowest(), sum)); - return static_cast(clamped); -} - -template -Scalar Parametrized(float param) { - return WeightedSum(std::numeric_limits::max(), param, - std::numeric_limits::lowest(), 1 - param); -} - -template -struct RandomRangeBounds { - static Scalar GetMinBound(RandomRange range) { - static constexpr double offcenteredness = - 0.02; // Shift lower limit by about 5 for range of 255. 
- switch (range) { - case RandomRange::kGeneral: - return std::numeric_limits::lowest(); - case RandomRange::kAvoidMinValue: - return 1 + std::numeric_limits::lowest(); - case RandomRange::kOffCenterAvoidMinValue: - return 1 + std::numeric_limits::lowest() + - static_cast( - offcenteredness * std::numeric_limits::max() - - offcenteredness * - (std::numeric_limits::lowest() + 1)); - case RandomRange::kReasonableSrcZeroPoint: - return std::numeric_limits::lowest(); - case RandomRange::kReasonableDstZeroPoint: - return Parametrized(0.4); - case RandomRange::kBias: - return std::is_same::value - ? static_cast(-10000) - : 0; - default: - RUY_CHECK(false); - return 0; - } - } - static Scalar GetMaxBound(RandomRange range) { - switch (range) { - case RandomRange::kGeneral: - return std::numeric_limits::max(); - case RandomRange::kAvoidMinValue: - return std::numeric_limits::max(); - case RandomRange::kOffCenterAvoidMinValue: - return std::numeric_limits::max(); - case RandomRange::kReasonableSrcZeroPoint: - return std::numeric_limits::max(); - case RandomRange::kReasonableDstZeroPoint: - return Parametrized(0.6); - case RandomRange::kBias: - return std::is_same::value - ? static_cast(10000) - : 0; - default: - RUY_CHECK(false); - return 0; - } - } -}; - -inline std::default_random_engine& global_random_engine() { - static std::default_random_engine engine; - return engine; -} - -template -struct UniformRandomDistribution { - UniformRandomDistribution(RandomRange range) - : dist(RandomRangeBounds::GetMinBound(range), - RandomRangeBounds::GetMaxBound(range)) {} - Scalar Get() { return dist(global_random_engine()); } - // std::uniform_int_distribution is specified not to support char types, - // only short and wider types. MSVC actually generates an error on - // std::uniform_int_distribution. - using StdDistType = typename std::conditional< - std::is_floating_point::value, - std::uniform_real_distribution, - std::uniform_int_distribution>::type; - StdDistType dist; -}; - -template -void MakeRandomScalar(UniformRandomDistribution* uniform_dist, - Scalar* dst) { - *dst = uniform_dist->Get(); -} - -template -void MakeRandomVector(UniformRandomDistribution* uniform_dist, int size, - std::vector* dst) { - dst->resize(size); - for (auto& x : *dst) { - MakeRandomScalar(uniform_dist, &x); - } -} - -template -void MakeRandomScalar(RandomRange range, Scalar* dst) { - UniformRandomDistribution dist(range); - *dst = dist.Get(); - if (range == RandomRange::kReasonableDstZeroPoint || - range == RandomRange::kReasonableSrcZeroPoint) { - if (global_random_engine()() & 1) { - *dst = SymmetricZeroPoint(); - } - } -} - -template -void MakeRandomVector(RandomRange range, int size, std::vector* dst) { - UniformRandomDistribution dist(range); - dst->resize(size); - for (auto& x : *dst) { - MakeRandomScalar(&dist, &x); - } -} - -enum class LayoutStyle { kPackedLinear, kLinear }; - -inline void MakeLayout(int rows, int cols, Order order, - LayoutStyle layout_style, Layout* layout) { - layout->rows = rows; - layout->cols = cols; - layout->order = order; - - const int packed_stride = order == Order::kColMajor ? 
rows : cols; - - RUY_CHECK(layout_style == LayoutStyle::kPackedLinear || - layout_style == LayoutStyle::kLinear); - if (layout_style == LayoutStyle::kPackedLinear) { - layout->stride = packed_stride; - } else { - layout->stride = packed_stride + 1; - } -} - -template -struct StorageMatrix { - StorageMatrix() = default; - StorageMatrix(const StorageMatrix&) = delete; - void operator=(const StorageMatrix&) = delete; - std::vector data; - Matrix matrix; -}; - -template -void VerifyConsistentFields(const StorageMatrix& storage_matrix) { - if (storage_matrix.data.empty()) { - RUY_CHECK_EQ(storage_matrix.matrix.data.get(), nullptr); - RUY_CHECK_EQ(storage_matrix.matrix.layout.rows, 0); - RUY_CHECK_EQ(storage_matrix.matrix.layout.cols, 0); - } else { - RUY_CHECK_EQ(storage_matrix.matrix.data.get(), storage_matrix.data.data()); - RUY_CHECK_EQ(FlatSize(storage_matrix.matrix.layout), - storage_matrix.data.size()); - } -} - -template -void MakeRandom(int rows, int cols, Order order, Scalar zero_point, - LayoutStyle layout_style, RandomRange range, - StorageMatrix* storage_matrix) { - MakeLayout(rows, cols, order, layout_style, &storage_matrix->matrix.layout); - storage_matrix->matrix.zero_point = zero_point; - UniformRandomDistribution data_dist(range); - MakeRandomVector(&data_dist, FlatSize(storage_matrix->matrix.layout), - &storage_matrix->data); - storage_matrix->matrix.data = storage_matrix->data.data(); - VerifyConsistentFields(*storage_matrix); -} - -template -struct TestResult { - void operator=(const TestResult&) = delete; - void operator=(const TestResult&&) = delete; - StorageMatrix storage_matrix; - Path path = Path::kNone; - Tuning tuning = Tuning::kAuto; - ExternalPath external_path = ExternalPath::kNone; - float latency; - float l1_refill_rate; - float l2_refill_rate; - float l3_refill_rate; - float l1tlb_refill_rate; - float l2tlb_refill_rate; - float mispred_rate; - float frontend_stall_rate; - float backend_stall_rate; - - // Per-path data for pre-packing. - // This is not used by external paths or by Path::kReference. 
- Allocator allocator; - PrepackedMatrix prepacked_lhs; - PrepackedMatrix prepacked_rhs; - bool use_prepacked_lhs = false; - bool use_prepacked_rhs = false; -}; - -template -std::string PathName(const TestResult& result) { - std::string pathname; - if (result.path != Path::kNone) { - pathname.assign(PathName(result.path)); - } else if (result.external_path != ExternalPath::kNone) { - pathname.assign(PathName(result.external_path)); - } else { - RUY_CHECK(false); - } - if (result.tuning != Tuning::kAuto) { - pathname.append("/"); - pathname.append(TuningName(result.tuning)); - } - return pathname; -} - -enum class ExpectedOutcome { kSuccess, kDeath }; - -template -struct TestSet final { - using LhsScalar = tLhsScalar; - using RhsScalar = tRhsScalar; - using AccumScalar = typename SpecType::AccumScalar; - using DstScalar = typename SpecType::DstScalar; - using Spec = SpecType; - using TestResultType = TestResult; - - void Run() { - MakeZeroPoints(); - MakeLhsRhs(); - MakeSpec(); - MakeOtherParams(); - MakeResultPaths(); - MakePrepackedMatrices(); - Eval(); - Verify(); - } - - private: - void MakeZeroPoints(); - void MakeLhsRhs(); - void MakeSpec(); - void MakeResultPaths(); - void MakePrepackedMatrices(); - void MakeOtherParams(); - void EvalAndVerify(); - void Eval(); - void Verify(); - - void EvalResult(TestResultType* result); - void EvalRuy(TestResultType* result); - void DoMul(TestResultType* result); - void Benchmark(TestResultType* result); - void VerifyTestResults() const; - - public: - enum class LifeStage { - kInitial, - kHasZeroPoints, - kHasLhsRhs, - kHasSpec, - kHasOtherParams, - kHasResultPaths, - kHasPrepackedMatrices, - kEvaluated, - kFinal - }; - - ~TestSet() { - RUY_CHECK_EQ(life_stage, LifeStage::kFinal); - LogCoveredPathsOnDestruction::Singleton(); - } - - LifeStage life_stage = LifeStage::kInitial; - - int rows = 0; - int cols = 0; - int depth = 0; - Order lhs_order = Order::kRowMajor; - Order rhs_order = Order::kColMajor; - Order dst_order = Order::kColMajor; - LayoutStyle layout_style = LayoutStyle::kPackedLinear; - ExpectedOutcome expected_outcome = ExpectedOutcome::kSuccess; - - bool use_specified_zero_points = false; - LhsScalar lhs_zero_point = 0; - RhsScalar rhs_zero_point = 0; - DstScalar dst_zero_point = 0; - - std::vector per_channel_multiplier_fixedpoint; - std::vector per_channel_multiplier_exponent; - - StorageMatrix lhs; - StorageMatrix rhs; - Spec spec; - std::vector bias_data; - std::vector> results; - - std::vector paths; - std::vector external_paths; - - bool benchmark = false; - bool perchannel = false; - int max_num_threads = 0; - bool benchmark_prepack_lhs = false; - bool benchmark_prepack_rhs = false; -}; - -inline PmuEvents& GlobalPmuEvents() { - static PmuEvents pmu; - return pmu; -} - -inline Context& GlobalContext() { - // Ensure that GlobalPmuEvents is constructed before we create any context. - // This ensures that pmu counters are opened before we create any worker - // thread, which is necessary to count events from worker threads. 
- GlobalPmuEvents(); - - static Context context; - return context; -} - -#if defined(__has_feature) -#if __has_feature(thread_sanitizer) -#define RUY_TSAN -#endif -#if __has_feature(address_sanitizer) -#define RUY_ASAN -#endif -#endif // defined(__has_feature) - -template -void TestSet::DoMul(TestResultType* result) { - Context* context = &GlobalContext(); - - if (!result->use_prepacked_lhs && !result->use_prepacked_rhs) { - Mul(lhs.matrix, rhs.matrix, spec, context, - &result->storage_matrix.matrix); - return; - } - - // If we prepacked an input matrix, null out its data pointer to check - // that we don't access any data through it. - Matrix null_data_lhs = lhs.matrix; - Matrix null_data_rhs = rhs.matrix; - if (result->use_prepacked_lhs) { - null_data_lhs.data = nullptr; - } - if (result->use_prepacked_rhs) { - null_data_rhs.data = nullptr; - } - - // Do the multiplication with pre-packed matrices. - PrepackedMatrix* prepacked_lhs_ptr = - result->use_prepacked_lhs ? &result->prepacked_lhs : nullptr; - PrepackedMatrix* prepacked_rhs_ptr = - result->use_prepacked_rhs ? &result->prepacked_rhs : nullptr; - MulWithPrepacked(null_data_lhs, null_data_rhs, spec, context, - &result->storage_matrix.matrix, prepacked_lhs_ptr, - prepacked_rhs_ptr); -} - -// When building for WAsm, ASSERT_DEATH is not defined. -#ifdef ASSERT_DEATH -#define RUY_ASSERT_DEATH(CONDITION, MESSAGE) ASSERT_DEATH(CONDITION, MESSAGE) -#else -#define RUY_ASSERT_DEATH(CONDITION, MESSAGE) -#endif - -template -void TestSet::EvalRuy(TestResultType* result) { - GlobalContext().explicit_tuning = result->tuning; - if (max_num_threads) { - GlobalContext().max_num_threads = max_num_threads; - } else if (benchmark) { - GlobalContext().max_num_threads = 1; - } else { - GlobalContext().max_num_threads = 1 + global_random_engine()() % 8; - } - GlobalContext().SetRuntimeEnabledPaths(result->path); - if (expected_outcome == ExpectedOutcome::kSuccess) { - DoMul(result); - RUY_CHECK_EQ(GlobalContext().last_taken_path, result->path); - } else if (expected_outcome == ExpectedOutcome::kDeath) { - // TODO(benoitjacob) TSan and ASan seem to be breaking ASSERT_DEATH. - // Report a bug? -#if (!defined NDEBUG) && (!defined RUY_ASAN) && (!defined RUY_TSAN) - RUY_ASSERT_DEATH(DoMul(result), ""); -#endif - } else { - RUY_CHECK(false); - } - GlobalContext().explicit_tuning = Tuning::kAuto; - GlobalContext().max_num_threads = 1; -} - -#ifdef RUY_TEST_EXTERNAL_PATHS - -template -void WrapGemmlowp(const Matrix& src, - gemmlowp::MatrixMap* dst) { - RUY_CHECK(src.layout.order == (tOrder == gemmlowp::MapOrder::ColMajor - ? Order::kColMajor - : Order::kRowMajor)); - *dst = gemmlowp::MatrixMap( - src.data.get(), src.layout.rows, src.layout.cols, src.layout.stride); -} - -template -void WrapGemmlowpMutable(Matrix* src, - gemmlowp::MatrixMap* dst) { - RUY_CHECK(src->layout.order == (tOrder == gemmlowp::MapOrder::ColMajor - ? 
Order::kColMajor - : Order::kRowMajor)); - *dst = gemmlowp::MatrixMap( - src->data.get(), src->layout.rows, src->layout.cols, src->layout.stride); -} - -template -struct GemmlowpOrder {}; - -template <> -struct GemmlowpOrder { - static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::ColMajor; -}; - -template <> -struct GemmlowpOrder { - static constexpr gemmlowp::MapOrder kValue = gemmlowp::MapOrder::RowMajor; -}; - -inline gemmlowp::GemmContext& GlobalGemmlowpContext() { - static gemmlowp::GemmContext context; - return context; -} - -template -void EvalGemmlowp(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, - Matrix* dst) { - static constexpr gemmlowp::MapOrder kGemmlowpLhsOrder = - GemmlowpOrder::kValue; - static constexpr gemmlowp::MapOrder kGemmlowpRhsOrder = - GemmlowpOrder::kValue; - static constexpr gemmlowp::MapOrder kGemmlowpDstOrder = - GemmlowpOrder::kValue; - gemmlowp::MatrixMap gemmlowp_lhs; - gemmlowp::MatrixMap gemmlowp_rhs; - gemmlowp::MatrixMap gemmlowp_dst; - WrapGemmlowp(lhs, &gemmlowp_lhs); - WrapGemmlowp(rhs, &gemmlowp_rhs); - WrapGemmlowpMutable(dst, &gemmlowp_dst); - - gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent quantize_down_stage; - quantize_down_stage.result_offset_after_shift = dst->zero_point; - quantize_down_stage.result_fixedpoint_multiplier = spec.multiplier_fixedpoint; - quantize_down_stage.result_exponent = spec.multiplier_exponent; - gemmlowp::OutputStageScaleInt32ByFixedPointAndExponentPC< - gemmlowp::VectorShape::Col> - quantize_down_stage_pc; - quantize_down_stage_pc.result_offset_after_shift = dst->zero_point; - using ColVectorMap = - gemmlowp::VectorMap; - quantize_down_stage_pc.result_fixedpoint_multiplier = - ColVectorMap(spec.multiplier_fixedpoint_perchannel, lhs.layout.rows); - quantize_down_stage_pc.result_exponent = - ColVectorMap(spec.multiplier_exponent_perchannel, lhs.layout.rows); - - gemmlowp::OutputStageClamp clamp_stage; - clamp_stage.min = spec.clamp_min; - clamp_stage.max = spec.clamp_max; - using OutputStageSaturatingCast = typename std::conditional< - std::is_same::value, - gemmlowp::OutputStageSaturatingCastToUint8, - gemmlowp::OutputStageSaturatingCastToInt16>::type; - OutputStageSaturatingCast saturating_cast_stage; - - GlobalGemmlowpContext().set_max_num_threads(max_num_threads ? 
max_num_threads - : 1); - if (spec.bias) { - using ColVectorMap = - gemmlowp::VectorMap; - gemmlowp::OutputStageBiasAddition bias_add_stage; - bias_add_stage.bias_vector = ColVectorMap(spec.bias, dst->layout.rows); -#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE - if (spec.multiplier_exponent_perchannel) { - const auto& output_pipeline = - std::make_tuple(bias_add_stage, quantize_down_stage_pc, clamp_stage, - saturating_cast_stage); - gemmlowp::GemmWithOutputPipeline< - LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( - &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, - -lhs.zero_point, -rhs.zero_point, output_pipeline); - } else // NOLINT[readability/braces] -#endif - { - const auto& output_pipeline = - std::make_tuple(bias_add_stage, quantize_down_stage, clamp_stage, - saturating_cast_stage); - gemmlowp::GemmWithOutputPipeline< - LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( - &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, - -lhs.zero_point, -rhs.zero_point, output_pipeline); - } - } else { -#ifndef GEMMLOWP_SSE4 // gemmlowp perchannel stuff does not build on SSE - if (spec.multiplier_exponent_perchannel) { - const auto& output_pipeline = std::make_tuple( - quantize_down_stage_pc, clamp_stage, saturating_cast_stage); - gemmlowp::GemmWithOutputPipeline< - LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( - &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, - -lhs.zero_point, -rhs.zero_point, output_pipeline); - } else // NOLINT[readability/braces] -#endif - { - const auto& output_pipeline = std::make_tuple( - quantize_down_stage, clamp_stage, saturating_cast_stage); - gemmlowp::GemmWithOutputPipeline< - LhsScalar, DstScalar, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( - &GlobalGemmlowpContext(), gemmlowp_lhs, gemmlowp_rhs, &gemmlowp_dst, - -lhs.zero_point, -rhs.zero_point, output_pipeline); - } - } -} - -inline constexpr int Mash(Order LhsOrder, Order RhsOrder, Order DstOrder) { - return (LhsOrder == Order::kRowMajor ? 4 : 0) + - (RhsOrder == Order::kRowMajor ? 2 : 0) + - (DstOrder == Order::kRowMajor ? 
1 : 0); -} - -template -void EvalGemmlowp(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, - Matrix* dst) { - int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order); - switch (index) { -#define EVALGEMMLOWP_CASE3(LHS, RHS, DST) \ - case Mash(LHS, RHS, DST): \ - return EvalGemmlowp(lhs, rhs, spec, max_num_threads, dst); -#define EVALGEMMLOWP_CASE2(LHS, RHS) \ - EVALGEMMLOWP_CASE3(LHS, RHS, Order::kColMajor) \ - EVALGEMMLOWP_CASE3(LHS, RHS, Order::kRowMajor) -#define EVALGEMMLOWP_CASE1(LHS) \ - EVALGEMMLOWP_CASE2(LHS, Order::kColMajor) \ - EVALGEMMLOWP_CASE2(LHS, Order::kRowMajor) - - EVALGEMMLOWP_CASE1(Order::kColMajor) - EVALGEMMLOWP_CASE1(Order::kRowMajor) - -#undef EVALGEMMLOWP_CASE1 -#undef EVALGEMMLOWP_CASE2 -#undef EVALGEMMLOWP_CASE3 - - default: - RUY_CHECK(false); - } -} - -template -struct EigenOrder {}; - -template <> -struct EigenOrder { - static constexpr int kValue = Eigen::ColMajor; -}; - -template <> -struct EigenOrder { - static constexpr int kValue = Eigen::RowMajor; -}; - -template -void EvalEigen(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, Matrix* dst) { - RUY_CHECK_EQ(lhs.zero_point, 0); - RUY_CHECK_EQ(rhs.zero_point, 0); - RUY_CHECK_EQ(dst->zero_point, 0); - RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0); - RUY_CHECK_EQ(spec.multiplier_exponent, 0); - - static constexpr int kEigenLhsOrder = EigenOrder::kValue; - static constexpr int kEigenRhsOrder = EigenOrder::kValue; - static constexpr int kEigenDstOrder = EigenOrder::kValue; - - using EigenLhsType = typename Eigen::Matrix:: - template StridedConstMapType>::type; - using EigenRhsType = typename Eigen::Matrix:: - template StridedConstMapType>::type; - using EigenDstType = typename Eigen::Matrix:: - template StridedMapType>::type; - using EigenBiasType = - typename Eigen::Matrix::ConstMapType; - - EigenLhsType eigen_lhs(lhs.data.get(), lhs.layout.rows, lhs.layout.cols, - Eigen::OuterStride(lhs.layout.stride)); - EigenRhsType eigen_rhs(rhs.data.get(), rhs.layout.rows, rhs.layout.cols, - Eigen::OuterStride(rhs.layout.stride)); - EigenDstType eigen_dst( - dst->data.get(), dst->layout.rows, dst->layout.cols, - Eigen::OuterStride(dst->layout.stride)); - Eigen::setNbThreads(max_num_threads ? 
max_num_threads : 1); - - if (spec.bias) { - EigenBiasType eigen_bias(spec.bias, dst->layout.rows); - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - eigen_dst.noalias() = (eigen_lhs * eigen_rhs).colwise() + eigen_bias; - } else { - eigen_dst.noalias() = ((eigen_lhs * eigen_rhs).colwise() + eigen_bias) - .cwiseMin(spec.clamp_max) - .cwiseMax(spec.clamp_min); - } - } else { - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - eigen_dst.noalias() = eigen_lhs * eigen_rhs; - } else { - eigen_dst.noalias() = (eigen_lhs * eigen_rhs) - .cwiseMin(spec.clamp_max) - .cwiseMax(spec.clamp_min); - } - } -} - -template -void EvalEigen(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, Matrix* dst) { - int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order); - switch (index) { -#define EVALEIGEN_CASE3(LHS, RHS, DST) \ - case Mash(LHS, RHS, DST): \ - return EvalEigen(lhs, rhs, spec, max_num_threads, dst); -#define EVALEIGEN_CASE2(LHS, RHS) \ - EVALEIGEN_CASE3(LHS, RHS, Order::kColMajor) \ - EVALEIGEN_CASE3(LHS, RHS, Order::kRowMajor) -#define EVALEIGEN_CASE1(LHS) \ - EVALEIGEN_CASE2(LHS, Order::kColMajor) \ - EVALEIGEN_CASE2(LHS, Order::kRowMajor) - - EVALEIGEN_CASE1(Order::kColMajor) - EVALEIGEN_CASE1(Order::kRowMajor) - -#undef EVALEIGEN_CASE1 -#undef EVALEIGEN_CASE2 -#undef EVALEIGEN_CASE3 - - default: - RUY_CHECK(false); - } -} - -template -void EvalEigenTensor(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, - Matrix* dst) { - RUY_CHECK_EQ(lhs.zero_point, 0); - RUY_CHECK_EQ(rhs.zero_point, 0); - RUY_CHECK_EQ(dst->zero_point, 0); - RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0); - RUY_CHECK_EQ(spec.multiplier_exponent, 0); - - // Eigen::TensorMap only supports packed layouts - RUY_CHECK(IsPacked(lhs.layout)); - RUY_CHECK(IsPacked(rhs.layout)); - RUY_CHECK(IsPacked(dst->layout)); - - using TensorLhsType = - Eigen::TensorMap>; - using TensorRhsType = - Eigen::TensorMap>; - using TensorDstType = - Eigen::TensorMap>; - using TensorBiasType = - Eigen::TensorMap>; - - const bool tr = DstOrder == Order::kRowMajor; - const auto& contract_lhs = tr ? rhs : lhs; - const auto& contract_rhs = tr ? lhs : rhs; - - TensorLhsType tensor_lhs( - contract_lhs.data.get(), - LhsOrder == Order::kColMajor ? contract_lhs.layout.rows - : contract_lhs.layout.cols, - LhsOrder == Order::kColMajor ? contract_lhs.layout.cols - : contract_lhs.layout.rows); - TensorRhsType tensor_rhs( - contract_rhs.data.get(), - RhsOrder == Order::kColMajor ? contract_rhs.layout.rows - : contract_rhs.layout.cols, - RhsOrder == Order::kColMajor ? contract_rhs.layout.cols - : contract_rhs.layout.rows); - TensorDstType tensor_dst( - dst->data.get(), - DstOrder == Order::kColMajor ? dst->layout.rows : dst->layout.cols, - DstOrder == Order::kColMajor ? dst->layout.cols : dst->layout.rows); - using DimPair = - typename Eigen::Tensor::DimensionPair; - Eigen::array contract_dims( - {DimPair((LhsOrder == Order::kColMajor) ? 1 : 0, - (RhsOrder == Order::kColMajor) ? 0 : 1)}); - Eigen::array shuffle(DstOrder == Order::kColMajor ? 0 : 1, - DstOrder == Order::kColMajor ? 1 : 0); - static Eigen::ThreadPool pool(max_num_threads ? max_num_threads : 1); - static Eigen::ThreadPoolDevice device(&pool, pool.NumThreads()); - if (spec.bias) { - TensorBiasType tensor_bias(spec.bias, dst->layout.rows); - Eigen::array bias_2d_shape(tr ? 1 : dst->layout.rows, - tr ? 
dst->layout.rows : 1); - Eigen::array bcast(tr ? dst->layout.cols : 1, - tr ? 1 : dst->layout.cols); - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - tensor_dst.device(device) = - tensor_lhs.contract(tensor_rhs, contract_dims); - } else { - tensor_dst.device(device) = - (tensor_lhs.contract(tensor_rhs, contract_dims) + - tensor_bias.reshape(bias_2d_shape).broadcast(bcast)) - .cwiseMin(spec.clamp_max) - .cwiseMax(spec.clamp_min); - } - } else { - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - tensor_dst.device(device) = - tensor_lhs.contract(tensor_rhs, contract_dims); - } else { - tensor_dst.device(device) = tensor_lhs.contract(tensor_rhs, contract_dims) - .cwiseMin(spec.clamp_max) - .cwiseMax(spec.clamp_min); - } - } -} - -template -void EvalEigenTensor(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, - Matrix* dst) { - int index = Mash(lhs.layout.order, rhs.layout.order, dst->layout.order); - switch (index) { -#define EVALEIGENTENSOR_CASE3(LHS, RHS, DST) \ - case Mash(LHS, RHS, DST): \ - return EvalEigenTensor(lhs, rhs, spec, max_num_threads, dst); -#define EVALEIGENTENSOR_CASE2(LHS, RHS) \ - EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kColMajor) \ - EVALEIGENTENSOR_CASE3(LHS, RHS, Order::kRowMajor) -#define EVALEIGENTENSOR_CASE1(LHS) \ - EVALEIGENTENSOR_CASE2(LHS, Order::kColMajor) \ - EVALEIGENTENSOR_CASE2(LHS, Order::kRowMajor) - - EVALEIGENTENSOR_CASE1(Order::kColMajor) - EVALEIGENTENSOR_CASE1(Order::kRowMajor) - -#undef EVALEIGENTENSOR_CASE1 -#undef EVALEIGENTENSOR_CASE2 -#undef EVALEIGENTENSOR_CASE3 - - default: - RUY_CHECK(false); - } -} - -template -struct GenericBlasGemm {}; - -template <> -struct GenericBlasGemm { - static void Run(char* transa, char* transb, lapack::integer* m, - lapack::integer* n, lapack::integer* k, - lapack::doublereal* alpha, lapack::doublereal* a, - lapack::integer* lda, lapack::doublereal* b, - lapack::integer* ldb, lapack::doublereal* beta, - lapack::doublereal* c, lapack::integer* ldc) { - dgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - } -}; - -template <> -struct GenericBlasGemm { - static void Run(char* transa, char* transb, lapack::integer* m, - lapack::integer* n, lapack::integer* k, lapack::real* alpha, - lapack::real* a, lapack::integer* lda, lapack::real* b, - lapack::integer* ldb, lapack::real* beta, lapack::real* c, - lapack::integer* ldc) { - sgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - } -}; - -template -void EvalOpenBlas(const Matrix& lhs, const Matrix& rhs, - const Spec& spec, int max_num_threads, Matrix* dst) { - RUY_CHECK_EQ(lhs.zero_point, 0); - RUY_CHECK_EQ(rhs.zero_point, 0); - RUY_CHECK_EQ(dst->zero_point, 0); - RUY_CHECK_EQ(spec.multiplier_fixedpoint, 0); - RUY_CHECK_EQ(spec.multiplier_exponent, 0); - - Matrix gemm_lhs; - Matrix gemm_rhs; - Matrix gemm_dst; - gemm_dst = *dst; - - // Use Transpose to reduce to the all-column-major case. - // Notice that ruy::Matrix merely holds a pointer, does not own data, - // so Transpose is cheap -- no actual matrix data is being transposed here. 
- if (dst->layout.order == Order::kColMajor) { - gemm_lhs = lhs; - gemm_rhs = rhs; - } else { - gemm_lhs = rhs; - gemm_rhs = lhs; - Transpose(&gemm_lhs); - Transpose(&gemm_rhs); - Transpose(&gemm_dst); - } - bool transposed_lhs = false; - bool transposed_rhs = false; - - if (gemm_lhs.layout.order == Order::kRowMajor) { - Transpose(&gemm_lhs); - transposed_lhs = true; - } - if (gemm_rhs.layout.order == Order::kRowMajor) { - Transpose(&gemm_rhs); - transposed_rhs = true; - } - - RUY_CHECK_EQ(gemm_lhs.layout.order, Order::kColMajor); - RUY_CHECK_EQ(gemm_rhs.layout.order, Order::kColMajor); - RUY_CHECK_EQ(gemm_dst.layout.order, Order::kColMajor); - - char transa = transposed_lhs ? 'T' : 'N'; - char transb = transposed_rhs ? 'T' : 'N'; - int m = gemm_lhs.layout.rows; - int n = gemm_rhs.layout.cols; - int k = gemm_lhs.layout.cols; - float alpha = 1; - Scalar* a = gemm_lhs.data.get(); - int lda = gemm_lhs.layout.stride; - Scalar* b = gemm_rhs.data.get(); - int ldb = gemm_rhs.layout.stride; - float beta = 0; - Scalar* c = gemm_dst.data.get(); - int ldc = gemm_dst.layout.stride; - GenericBlasGemm::Run(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, - &ldb, &beta, c, &ldc); - - // BLAS does not allow us to express the bias-addition and clamping, so - // we use Eigen for that. - - using EigenDstType = - typename Eigen::Matrix:: - template StridedMapType>::type; - using EigenBiasType = - typename Eigen::Matrix::ConstMapType; - - EigenDstType eigen_dst( - gemm_dst.data.get(), gemm_dst.layout.rows, gemm_dst.layout.cols, - Eigen::OuterStride(gemm_dst.layout.stride)); - Eigen::setNbThreads(max_num_threads ? max_num_threads : 1); - - if (spec.bias) { - EigenBiasType eigen_bias(spec.bias, dst->layout.rows); - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - eigen_dst.noalias() = eigen_dst.colwise() + eigen_bias; - } else { - eigen_dst.noalias() = (eigen_dst.colwise() + eigen_bias) - .cwiseMin(spec.clamp_max) - .cwiseMax(spec.clamp_min); - } - } else { - if (spec.clamp_max == std::numeric_limits::infinity() && - spec.clamp_min == -std::numeric_limits::infinity()) { - } else { - eigen_dst.noalias() = - eigen_dst.cwiseMin(spec.clamp_max).cwiseMax(spec.clamp_min); - } - } -} - -template -struct SupportsGemmlowp { - static constexpr bool kValue = - std::is_same::value && - std::is_same::value; -}; - -template -struct UsesSingleScalarType { - static constexpr bool kValue = - std::is_same::value && - std::is_same::value && - std::is_same::value; -}; - -template ::value, - bool EnableGemmlowp = SupportsGemmlowp::kValue, - bool SingleScalarType = UsesSingleScalarType::kValue> -struct EvalExternalPathImpl { - using DstScalar = typename TestSetType::DstScalar; - static void Run(TestSetType*, TestResult*) { RUY_CHECK(false); } -}; - -template -struct EvalExternalPathImpl { - using DstScalar = typename TestSetType::DstScalar; - static void Run(TestSetType* test_set, TestResult* test_result) { - if (test_result->external_path == ExternalPath::kEigen) { - EvalEigen(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec, - test_set->max_num_threads, &test_result->storage_matrix.matrix); - } else if (test_result->external_path == ExternalPath::kEigenTensor) { - EvalEigenTensor(test_set->lhs.matrix, test_set->rhs.matrix, - test_set->spec, test_set->max_num_threads, - &test_result->storage_matrix.matrix); - } else if (test_result->external_path == ExternalPath::kOpenBlas) { - EvalOpenBlas(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec, - 
test_set->max_num_threads, - &test_result->storage_matrix.matrix); - } else { - RUY_CHECK(false); - } - } -}; - -template -struct EvalExternalPathImpl { - using DstScalar = typename TestSetType::DstScalar; - static void Run(TestSetType* test_set, TestResult* test_result) { - if (test_result->external_path == ExternalPath::kGemmlowp) { - EvalGemmlowp(test_set->lhs.matrix, test_set->rhs.matrix, test_set->spec, - test_set->max_num_threads, - &test_result->storage_matrix.matrix); - } else { - RUY_CHECK(false); - } - } -}; - -template -void EvalExternalPath( - TestSetType* test_set, - TestResult* test_result) { - EvalExternalPathImpl::Run(test_set, test_result); -} - -#endif // RUY_TEST_EXTERNAL_PATHS - -template -bool Agree(const Matrix& matrix1, const Matrix& matrix2, - int depth) { - RUY_CHECK_EQ(matrix1.layout.rows, matrix2.layout.rows); - RUY_CHECK_EQ(matrix1.layout.cols, matrix2.layout.cols); - RUY_CHECK_EQ(matrix1.zero_point, matrix2.zero_point); - const int size = matrix1.layout.rows * matrix1.layout.cols; - double tolerated_max_diff = 0; - double tolerated_mean_diff = 0; - if (std::is_floating_point::value) { - // TODO: replace hardcoded 100 by something more sensible, probably - // roughly sqrt(depth) based on central limit theorem. - double max_abs_val = 0; - for (int row = 0; row < matrix1.layout.rows; row++) { - for (int col = 0; col < matrix1.layout.cols; col++) { - max_abs_val = - std::max(max_abs_val, - std::abs(static_cast(Element(matrix1, row, col)))); - max_abs_val = - std::max(max_abs_val, - std::abs(static_cast(Element(matrix2, row, col)))); - } - } - tolerated_max_diff = max_abs_val * std::numeric_limits::epsilon() * - 64 * std::sqrt(static_cast(depth)); - tolerated_mean_diff = tolerated_max_diff / std::sqrt(size); - } else if (RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)) { - tolerated_max_diff = 1; - // totally empirical - tolerated_mean_diff = std::min(1.0, 2.0 * std::pow(size, -0.2)); - } - double sum_diff = 0; - for (int row = 0; row < matrix1.layout.rows; row++) { - for (int col = 0; col < matrix1.layout.cols; col++) { - double elem1 = Element(matrix1, row, col); - double elem2 = Element(matrix2, row, col); - double diff = elem1 - elem2; - - sum_diff += diff; - // Test (std::abs(diff) > tolerated_max_diff), but also true if diff is - // NaN. 
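// The negated comparison just below is deliberate: every ordered comparison
// involving NaN is false, so `!(std::abs(diff) <= tolerated_max_diff)` is true
// both when the difference exceeds the tolerance and when it is NaN, whereas
// the direct form `std::abs(diff) > tolerated_max_diff` would silently accept
// a NaN result. A self-contained illustration of that idiom:
#include <cassert>
#include <cmath>

inline void NanSafeComparisonSketch() {
  const double nan_diff = std::nan("");
  assert(!(std::abs(nan_diff) <= 1.0));  // negated form flags NaN as an error
  assert(!(std::abs(nan_diff) > 1.0));   // direct form would let NaN through
}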
- if (!(std::abs(diff) <= tolerated_max_diff)) { - return false; - } - } - } - double mean_diff = sum_diff / size; - if (std::abs(mean_diff) > tolerated_mean_diff) { - return false; - } - return true; -} - -template -bool Agree(const StorageMatrix& storage_matrix1, - const StorageMatrix& storage_matrix2, int depth) { - VerifyConsistentFields(storage_matrix1); - VerifyConsistentFields(storage_matrix2); - return Agree(storage_matrix1.matrix, storage_matrix2.matrix, depth); -} - -template -bool Agree(const TestResult& result1, const TestResult& result2, - int depth) { - return Agree(result1.storage_matrix, result2.storage_matrix, depth); -} - -struct Stats { - double median; - double mean; - double min; - double max; -}; - -inline std::string StatsAsString(const Stats& stats) { - char buf[256]; - snprintf(buf, sizeof(buf), "(median = %g, mean = %g, min = %g, max = %g)", - stats.median, stats.mean, stats.min, stats.max); - return std::string(buf); -} - -template -void GetMatrixStats(const Matrix& matrix, Stats* stats) { - double min = std::numeric_limits::infinity(); - double max = -std::numeric_limits::infinity(); - double sum = 0; - std::vector allvals; - for (int row = 0; row < matrix.layout.rows; row++) { - for (int col = 0; col < matrix.layout.cols; col++) { - double val = Element(matrix, row, col); - min = std::min(min, val); - max = std::max(max, val); - sum += val; - allvals.push_back(val); - } - } - std::sort(allvals.begin(), allvals.end()); - stats->min = min; - stats->max = max; - stats->mean = sum / allvals.size(); - stats->median = allvals[allvals.size() / 2]; -} - -struct ErrorAnalysis { - Stats stats_good; - Stats stats_bad; - // The below is to help document departure from bit exactness. It's probably - // not going to be relevant to floating-point. 
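// (Context: per Agree() above, integer results across ruy paths are compared
//  either exactly or within +/-1 depending on RUY_OPT_NATIVE_ROUNDING, so
//  recording which rows and columns contain mismatches mainly helps show
//  whether an error is confined to a localized region of the result matrix or
//  spread over all of it.)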
- std::set error_rows; - std::set error_cols; - int row_of_first_error = 0; - int col_of_first_error = 0; - double first_error_good_value = 0; - double first_error_bad_value = 0; -}; - -template -void AnalyzeTestError(const TestSetType& test_set, int first_bad_result_index, - ErrorAnalysis* error_analysis) { - const auto& good_matrix = test_set.results[0]->storage_matrix.matrix; - const auto& bad_matrix = - test_set.results[first_bad_result_index]->storage_matrix.matrix; - GetMatrixStats(good_matrix, &error_analysis->stats_good); - GetMatrixStats(bad_matrix, &error_analysis->stats_bad); - bool found_first_error = false; - for (int row = 0; row < good_matrix.layout.rows; row++) { - for (int col = 0; col < good_matrix.layout.cols; col++) { - if (Element(good_matrix, row, col) != Element(bad_matrix, row, col)) { - if (!found_first_error) { - found_first_error = true; - error_analysis->row_of_first_error = row; - error_analysis->col_of_first_error = col; - error_analysis->first_error_good_value = - Element(good_matrix, row, col); - error_analysis->first_error_bad_value = Element(bad_matrix, row, col); - } - error_analysis->error_rows.insert(row); - error_analysis->error_cols.insert(col); - } - } - } -} - -template -void ComputeReasonableMultiplier( - const Matrix& lhs, - const Matrix& rhs, double* multiplier) { - using LhsScalar = typename TestSetType::LhsScalar; - using RhsScalar = typename TestSetType::RhsScalar; - using DstScalar = typename TestSetType::DstScalar; - if (std::is_floating_point::value || - std::is_same::value) { - *multiplier = 0; - return; - } - *multiplier = static_cast(std::numeric_limits::max()) / - (static_cast(lhs.layout.cols) * - std::numeric_limits::max() * - std::numeric_limits::max()); -} - -inline void QuantizeMultiplier(double multiplier_double, - std::int32_t* multiplier_fixedpoint, - int* multiplier_exponent) { - RUY_CHECK_GT(multiplier_double, 0); - if (multiplier_double == 0.) { - *multiplier_fixedpoint = 0; - *multiplier_exponent = 0; - return; - } - const double q = std::frexp(multiplier_double, multiplier_exponent); - auto q_fixed = static_cast(std::round(q * (1ll << 31))); - RUY_CHECK_LE(q_fixed, (1ll << 31)); - if (q_fixed == (1ll << 31)) { - q_fixed /= 2; - ++*multiplier_exponent; - } - RUY_CHECK_LE(q_fixed, std::numeric_limits::max()); - *multiplier_fixedpoint = static_cast(q_fixed); -} - -template -void SwitchMultiplierToPerChannel(TestSetType* test_set) { - test_set->per_channel_multiplier_fixedpoint.resize(test_set->rows); - test_set->per_channel_multiplier_exponent.resize(test_set->rows); - for (int i = 0; i < test_set->rows; i++) { - // multipliers typically range in [2^30 ; 2^31 - 1]. - // Values in [0, 2^30 - 1] are normally unused, but harmless. - // Thus a good way to randomize multipliers is to subtract from them - // a random value smaller than 2^30 but still significant compared to it. 
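// Orders of magnitude behind the comment above: a multiplier produced by
// QuantizeMultiplier above lands in roughly [2^30, 2^31 - 1] (about 1.07e9 to
// 2.15e9), while the random nudge subtracted below is taken modulo 2^26
// (about 6.7e7), i.e. a perturbation of at most about 6% -- enough to make
// each channel's multiplier distinct, small enough to keep it in a sensible
// range.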
- std::int32_t nudged_multiplier = test_set->spec.multiplier_fixedpoint - - (global_random_engine()() % (1 << 26)); - int nudged_exponent = - test_set->spec.multiplier_exponent - 1 + (global_random_engine()() % 4); - test_set->per_channel_multiplier_fixedpoint[i] = nudged_multiplier; - test_set->per_channel_multiplier_exponent[i] = nudged_exponent; - } - test_set->spec.multiplier_fixedpoint_perchannel = - test_set->per_channel_multiplier_fixedpoint.data(); - test_set->spec.multiplier_exponent_perchannel = - test_set->per_channel_multiplier_exponent.data(); - test_set->spec.multiplier_fixedpoint = 0; - test_set->spec.multiplier_exponent = 0; -} - -template < - typename TestSetType, - bool IsApplicable = - std::is_same::value && - !std::is_same::value> -struct MakeSpecMultiplierFieldsImpl {}; - -template -struct MakeSpecMultiplierFieldsImpl { - static void Run(TestSetType* test_set) { - double multiplier; - ComputeReasonableMultiplier(test_set->lhs.matrix, - test_set->rhs.matrix, &multiplier); - QuantizeMultiplier(multiplier, &test_set->spec.multiplier_fixedpoint, - &test_set->spec.multiplier_exponent); - if (!test_set->benchmark) { - test_set->perchannel = global_random_engine()() & 1; - } - if (test_set->perchannel) { - SwitchMultiplierToPerChannel(test_set); - } - } -}; - -template -struct MakeSpecMultiplierFieldsImpl { - static void Run(TestSetType* test_set) { - test_set->spec.multiplier_fixedpoint = 0; - test_set->spec.multiplier_exponent = 0; - } -}; - -template -void MakeSpecClampFields(Spec* spec) { - using AccumScalar = typename Spec::AccumScalar; - using DstScalar = typename Spec::DstScalar; - - if (std::is_same::value) { - // Returning raw accumulators, clamping is not supported. - spec->clamp_min = std::numeric_limits::lowest(); - spec->clamp_max = std::numeric_limits::max(); - return; - } - - if (getenv("BENCHMARK_ONLY_MATMUL")) { - if (std::is_floating_point::value) { - spec->clamp_min = -std::numeric_limits::infinity(); - spec->clamp_max = std::numeric_limits::infinity(); - } else { - spec->clamp_min = std::numeric_limits::lowest(); - spec->clamp_max = std::numeric_limits::max(); - } - return; - } - - spec->clamp_min = std::numeric_limits::lowest() + 1; - spec->clamp_max = std::numeric_limits::max() - 1; -} - -template -void TestSet::MakeZeroPoints() { - RUY_CHECK_EQ(life_stage, LifeStage::kInitial); - if (!benchmark && !use_specified_zero_points) { - MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &lhs_zero_point); - MakeRandomScalar(RandomRange::kReasonableSrcZeroPoint, &rhs_zero_point); - // If destination is std::int32_t, no dst_zero_point is necessary. 
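// Background for the zero-point choices in this function: a "symmetric" zero
// point is the midpoint of the storage type's range, so 128 for an unsigned
// 8-bit type, 0 for a signed 8-bit type, and 0 for raw std::int32_t
// accumulators -- which is why no dst_zero_point is needed in the int32 case
// handled next. A minimal sketch of that midpoint rule, consistent with the
// sanity checks in the ZeroPointSupport test later in this patch
// (SymmetricZeroPointSketch is an illustrative name; ruy's actual
// SymmetricZeroPoint helper may be defined differently):
#include <cstdint>
#include <limits>

template <typename Scalar>
constexpr int SymmetricZeroPointSketch() {
  return (static_cast<int>(std::numeric_limits<Scalar>::lowest()) +
          static_cast<int>(std::numeric_limits<Scalar>::max()) + 1) /
         2;
}
static_assert(SymmetricZeroPointSketch<std::uint8_t>() == 128,
              "midpoint of [0, 255]");
static_assert(SymmetricZeroPointSketch<std::int8_t>() == 0,
              "midpoint of [-128, 127]");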
- if (std::is_same::value) { - dst_zero_point = 0; - } else { - MakeRandomScalar(RandomRange::kReasonableDstZeroPoint, &dst_zero_point); - } - } - life_stage = LifeStage::kHasZeroPoints; -} - -template -void TestSet::MakeLhsRhs() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasZeroPoints); - MakeRandom(rows, depth, lhs_order, lhs_zero_point, layout_style, - RandomRange::kOffCenterAvoidMinValue, &lhs); - MakeRandom(depth, cols, rhs_order, rhs_zero_point, layout_style, - RandomRange::kGeneral, &rhs); - life_stage = LifeStage::kHasLhsRhs; -} - -template -void TestSet::MakeSpec() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasLhsRhs); - - if (!getenv("BENCHMARK_ONLY_MATMUL") && - (benchmark || (global_random_engine()() & 1))) { - MakeRandomVector(RandomRange::kBias, rows, &bias_data); - spec.bias = bias_data.data(); - } - if (lhs.matrix.zero_point == std::numeric_limits::lowest() && - rhs.matrix.zero_point == std::numeric_limits::lowest()) { - lhs.matrix.zero_point += 1; - } - MakeSpecMultiplierFieldsImpl::Run(this); - MakeSpecClampFields(&spec); - life_stage = LifeStage::kHasSpec; -} - -inline int GetIntEnvVarOrZero(const char* name) { - const char* val = getenv(name); - if (!val) { - return 0; - } - return std::stoi(val); -} - -inline float GetFloatEnvVarOrZero(const char* name) { - const char* val = getenv(name); - if (!val) { - return 0; - } - return std::stof(val); -} - -inline int GetHexIntEnvVarOrZero(const char* name) { - const char* val = getenv(name); - if (!val) { - return 0; - } - return std::stoi(val, nullptr, 16); -} - -inline bool GetBoolEnvVarOrFalse(const char* name) { - return static_cast(GetIntEnvVarOrZero(name)); -} - -template -void TestSet::MakeOtherParams() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasSpec); - if (max_num_threads == 0) { - max_num_threads = GetIntEnvVarOrZero("THREADS"); - } - life_stage = LifeStage::kHasOtherParams; -} - -inline std::vector PathsBitfieldAsVector(Path paths_bitfield) { - std::vector result; - std::uint32_t remaining_paths = static_cast(paths_bitfield); - std::uint32_t test_bit = 1; - while (remaining_paths) { - if (remaining_paths & test_bit) { - result.push_back(static_cast(test_bit)); - } - remaining_paths &= ~test_bit; - test_bit <<= 1; - } - return result; -} - -inline std::vector EnumerateTuningsForPath(Path path, bool benchmark) { - if (benchmark) { - return {Tuning::kAuto}; - } -#if RUY_PLATFORM(ARM) - if (path == Path::kNeon || path == Path::kNeonDotprod) { - return {Tuning::kInOrder, Tuning::kOutOfOrder, Tuning::kAuto}; - } -#endif - return {Tuning::kAuto}; -} - -template -void TestSet::MakePrepackedMatrices() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasResultPaths); - - // Prepacked matrices are Path-dependent, so create them for each test result. - for (auto& result : results) { - // If this result uses an external path, then skip this entirely. - if (result->path == Path::kNone) { - continue; - } - // Pre-packing doesn't make sense for Path::kReference. - // TODO(silvasean): Make Path::kReference an ExternalPath? - if (result->path == Path::kReference) { - continue; - } - - // Determine whether we should create/use prepacked matrices. - if (benchmark) { - // For benchmarking, do as requested. - result->use_prepacked_lhs = benchmark_prepack_lhs; - result->use_prepacked_rhs = benchmark_prepack_rhs; - } else { - // When testing, randomly pre-pack sometimes. But don't do it too often. 
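// The masks used just below give a 1-in-8 probability: `x & 7` keeps the low
// three bits of the random engine output, and exactly one of the eight
// effectively equally likely 3-bit patterns is zero.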
- result->use_prepacked_lhs = (global_random_engine()() & 7) == 0; - result->use_prepacked_rhs = (global_random_engine()() & 7) == 0; - } - - // Create the pre-packed matrices. - PrepackedMatrix* prepacked_lhs_ptr = - result->use_prepacked_lhs ? &result->prepacked_lhs : nullptr; - PrepackedMatrix* prepacked_rhs_ptr = - result->use_prepacked_rhs ? &result->prepacked_rhs : nullptr; - auto alloc_fn = [&result](std::size_t num_bytes) { - return result->allocator.AllocateBytes(num_bytes); - }; - // Use a dst with a null data pointer to check that the pre-packing - // invocation doesn't write into it. - Matrix null_data_dst = result->storage_matrix.matrix; - null_data_dst.data = nullptr; - GlobalContext().SetRuntimeEnabledPaths(result->path); - PrePackForMul(lhs.matrix, rhs.matrix, spec, &GlobalContext(), - &null_data_dst, prepacked_lhs_ptr, - prepacked_rhs_ptr, alloc_fn); - RUY_CHECK_EQ(GlobalContext().last_taken_path, result->path); - } - - life_stage = LifeStage::kHasPrepackedMatrices; -} - -template -void TestSet::MakeResultPaths() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasOtherParams); - - Path paths_bitfield = static_cast(GetHexIntEnvVarOrZero("PATHS")); - - if (paths_bitfield == Path::kNone) { - // Use a dummy Context just to perform the resolution of specific runtime - // enabled paths. - Context context; - paths_bitfield = context.GetRuntimeEnabledPaths(); - } - - // Trim bits that don't correspond to a compiled path, - // to allow specifying e.g. ffff to mean 'all paths' regardless of whether all - // those bits exist as actual paths. - paths_bitfield = paths_bitfield & kAllPaths; - RUY_CHECK_NE(paths_bitfield, Path::kNone); - paths = PathsBitfieldAsVector(paths_bitfield); - -#ifdef RUY_TEST_EXTERNAL_PATHS - - using TestSetType = TestSet; - - if (!GetBoolEnvVarOrFalse("NOEXT")) { - if (SupportsGemmlowp::kValue) { -#ifdef GEMMLOWP_SSE4 - const bool gemmlowp_supported = !spec.multiplier_fixedpoint_perchannel; -#else - const bool gemmlowp_supported = true; -#endif - if (gemmlowp_supported) { - external_paths.push_back(ExternalPath::kGemmlowp); - } - } - if (UsesSingleScalarType::kValue && - std::is_floating_point::value) { - external_paths.push_back(ExternalPath::kEigen); - if (layout_style == LayoutStyle::kPackedLinear) { - external_paths.push_back(ExternalPath::kEigenTensor); - } -// We link against a generic BLAS target that only maps to OpenBLAS on specific -// architectures. -#if RUY_PLATFORM(ARM_32) || RUY_PLATFORM(ARM_64) - // OpenBLAS multi-threading is disabled, so avoid mixing single-threaded - // and multi-threaded benchmark results. 
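// For reference, path selection in this function is driven by environment
// variables that appear in the code above and below: PATHS (a hex bitfield
// restricting which ruy paths are tested), NOEXT (skip all external
// comparison paths), and NO_OPENBLAS (skip the OpenBLAS comparison even in
// the single-threaded case handled next).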
- if (max_num_threads == 1 && !getenv("NO_OPENBLAS")) { - external_paths.push_back(ExternalPath::kOpenBlas); - } -#endif - } - } - -#endif // RUY_TEST_EXTERNAL_PATHS - - for (Path path : paths) { - for (Tuning tuning : EnumerateTuningsForPath(path, benchmark)) { - results.emplace_back(new TestResultType); - TestResultType& result = *results.back(); - result.path = path; - result.tuning = tuning; - MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style, - RandomRange::kGeneral, &result.storage_matrix); - } - } - - for (ExternalPath external_path : external_paths) { - results.emplace_back(new TestResultType); - TestResultType& result = *results.back(); - result.external_path = external_path; - MakeRandom(rows, cols, dst_order, dst_zero_point, layout_style, - RandomRange::kGeneral, &result.storage_matrix); - } - - life_stage = LifeStage::kHasResultPaths; -} - -template -void TestSet::EvalResult( - TestResult* result) { - RUY_CHECK(result->path != Path::kNone || - result->external_path != ExternalPath::kNone); - if (result->path != Path::kNone) { - EvalRuy(result); - } else { -#ifdef RUY_TEST_EXTERNAL_PATHS - using TestSetType = TestSet; - EvalExternalPath(this, result); -#endif - } - const std::string& pathname = PathName(*result); - if (std::find(CoveredPaths()->begin(), CoveredPaths()->end(), pathname) == - CoveredPaths()->end()) { - CoveredPaths()->push_back(pathname); - } -} - -using f32 = float; -using f64 = double; -using u8 = std::uint8_t; -using i8 = std::int8_t; -using u16 = std::uint16_t; -using i16 = std::int16_t; -using u32 = std::uint32_t; -using i32 = std::int32_t; -using u64 = std::uint64_t; -using i64 = std::int64_t; - -template -const char* TypeName() { - return nullptr; -} - -#define RUY_TYPENAME(TYPE) \ - template <> \ - const char* TypeName() { \ - return #TYPE; \ - } - -RUY_TYPENAME(f32) -RUY_TYPENAME(f64) -RUY_TYPENAME(u8) -RUY_TYPENAME(i8) -RUY_TYPENAME(u16) -RUY_TYPENAME(i16) -RUY_TYPENAME(u32) -RUY_TYPENAME(i32) -RUY_TYPENAME(u64) -RUY_TYPENAME(i64) - -#undef RUY_TYPENAME - -template -const char* SymmetryName(const Matrix& matrix) { - if (matrix.zero_point == SymmetricZeroPoint()) { - return "symm"; - } else { - return "asymm"; - } -} - -template -int StorageSize(const Matrix& matrix) { - return sizeof(Scalar) * FlatSize(matrix.layout); -} - -// Helper that replicates a buffer and gives out pointers to the replicas. -// This is useful when one wants to traverse data so that it is cold in cache. -// By having a sufficiently large value of num_repeats, one can ensure that the -// working set covered by the replicas is greater than the cache size. 
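// A usage sketch for the class defined below, mirroring how the Benchmark
// function further down sizes its cold buffers against a 100 MiB working set
// (NumColdRepeatsSketch and the float element type are illustrative only):
#include <cstddef>

template <typename T>
int NumColdRepeatsSketch(std::size_t num_elems) {
  const std::size_t bytes_per_replica = num_elems * sizeof(T);
  const std::size_t working_set_bytes = std::size_t{100} << 20;  // 100 MiB
  // Round up so that all replicas together exceed the target working set.
  return static_cast<int>((working_set_bytes + bytes_per_replica - 1) /
                          bytes_per_replica);
}
// Intended use, e.g. with a float source buffer `src` of `n` elements:
//   RepeatedBuffer<float> cold;
//   cold.Init(src, n, NumColdRepeatsSketch<float>(n));
//   float* replica = cold.Next();  // a different, likely-cold replica per call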
-template -class RepeatedBuffer { - public: - RepeatedBuffer() = default; - void Init(const T* elems, std::size_t num_elems, int num_repeats) { - buffers_.clear(); - allocator_.FreeAll(); - for (int i = 0; i < num_repeats; i++) { - T* p; - allocator_.Allocate(num_elems, &p); - memcpy(p, elems, num_elems * sizeof(T)); - buffers_.push_back(p); - } - } - T* Next() { - T* ret = buffers_[current_]; - current_ = (current_ + 1) % buffers_.size(); - return ret; - } - - private: - Allocator allocator_; - std::vector buffers_; - int current_ = 0; -}; - -template -void TestSet::Benchmark( - TestResult* result) { - using DstScalar = typename SpecType::DstScalar; - - const bool cold = getenv("RUY_BENCHMARK_COLD"); - LhsScalar* orig_lhs_data = lhs.matrix.data.get(); - RhsScalar* orig_rhs_data = rhs.matrix.data.get(); - DstScalar* orig_dst_data = result->storage_matrix.matrix.data.get(); - void* orig_prepacked_lhs_data = result->prepacked_lhs.data; - void* orig_prepacked_rhs_data = result->prepacked_rhs.data; - - int num_matmul_sets = 0; - - RepeatedBuffer cold_lhs; - RepeatedBuffer cold_rhs; - RepeatedBuffer cold_dst; - RepeatedBuffer cold_prepacked_lhs; - RepeatedBuffer cold_prepacked_rhs; - - if (cold) { - const int kWorkingSetSize = 100 << 20; - const int each_matmul_set_size = StorageSize(lhs.matrix) + - StorageSize(rhs.matrix) + - StorageSize(result->storage_matrix.matrix); - num_matmul_sets = - (kWorkingSetSize + each_matmul_set_size - 1) / each_matmul_set_size; - - cold_lhs.Init(lhs.matrix.data.get(), FlatSize(lhs.matrix.layout), - num_matmul_sets); - cold_rhs.Init(rhs.matrix.data.get(), FlatSize(rhs.matrix.layout), - num_matmul_sets); - cold_dst.Init(result->storage_matrix.matrix.data.get(), - FlatSize(result->storage_matrix.matrix.layout), - num_matmul_sets); - if (benchmark_prepack_lhs) { - cold_prepacked_lhs.Init(static_cast(result->prepacked_lhs.data), - result->prepacked_lhs.data_size, num_matmul_sets); - } - if (benchmark_prepack_rhs) { - cold_prepacked_rhs.Init(static_cast(result->prepacked_rhs.data), - result->prepacked_rhs.data_size, num_matmul_sets); - } - } - const bool record_pmu = GetBoolEnvVarOrFalse("RUY_BENCHMARK_PMU"); - int repeats = GetIntEnvVarOrZero("RUY_BENCHMARK_REPEATS"); - if (!repeats) { - repeats = 4; - } - float benchmark_min_secs = GetFloatEnvVarOrZero("RUY_BENCHMARK_MIN_SECS"); - if (!benchmark_min_secs) { - benchmark_min_secs = 0.5; - } -#ifdef RUY_PROFILER - { - const char* lhstype = TypeName(); - const char* lhssymm = SymmetryName(lhs.matrix); - const char* rhstype = TypeName(); - const char* rhssymm = SymmetryName(rhs.matrix); - - printf("Profiling path=%s shape=(%dx%dx%d) lhs=(%s,%s) rhs=(%s,%s)\n", - PathName(*result).c_str(), rows, depth, cols, lhstype, lhssymm, - rhstype, rhssymm); - ruy::profiler::ScopeProfile profile; -#endif - - float latency = std::numeric_limits::infinity(); - float l1_refill_rate = std::numeric_limits::infinity(); - float l2_refill_rate = std::numeric_limits::infinity(); - float l3_refill_rate = std::numeric_limits::infinity(); - float l1tlb_refill_rate = std::numeric_limits::infinity(); - float l2tlb_refill_rate = std::numeric_limits::infinity(); - float mispred_rate = std::numeric_limits::infinity(); - float frontend_stall_rate = std::numeric_limits::infinity(); - float backend_stall_rate = std::numeric_limits::infinity(); - - for (int repeat = 0; repeat < repeats; repeat++) { - auto& pmu_events = GlobalPmuEvents(); - if (record_pmu) { - pmu_events.StartRecording(); - } - TimePoint time_start = Now(); - TimePoint t = time_start; - 
int iters = 0; - int iters_at_a_time = 1; - while (ToFloatSeconds(t - time_start) < benchmark_min_secs) { - for (int i = 0; i < iters_at_a_time; i++) { - if (cold) { - lhs.matrix.data = cold_lhs.Next(); - rhs.matrix.data = cold_rhs.Next(); - result->storage_matrix.matrix.data = cold_dst.Next(); - if (benchmark_prepack_lhs) { - result->prepacked_lhs.data = cold_prepacked_lhs.Next(); - } - if (benchmark_prepack_rhs) { - result->prepacked_rhs.data = cold_prepacked_rhs.Next(); - } - } - EvalResult(result); - iters++; - } - iters_at_a_time *= 2; - t = Now(); - } - latency = std::min( - latency, static_cast(ToFloatSeconds(t - time_start) / iters)); - if (record_pmu) { - pmu_events.StopRecording(); - const float normalization_factor = - 1.0f / (static_cast(iters) * rows * cols * depth); - l1_refill_rate = std::min( - l1_refill_rate, pmu_events.L1RefillCount() * normalization_factor); - l2_refill_rate = std::min( - l2_refill_rate, pmu_events.L2RefillCount() * normalization_factor); - l3_refill_rate = std::min( - l3_refill_rate, pmu_events.L3RefillCount() * normalization_factor); - l1tlb_refill_rate = - std::min(l1tlb_refill_rate, - pmu_events.L1TLBRefillCount() * normalization_factor); - l2tlb_refill_rate = - std::min(l2tlb_refill_rate, - pmu_events.L2TLBRefillCount() * normalization_factor); - mispred_rate = - std::min(mispred_rate, pmu_events.BranchMispredictionCount() * - normalization_factor); - frontend_stall_rate = - std::min(frontend_stall_rate, - pmu_events.FrontendStallCount() * normalization_factor); - backend_stall_rate = - std::min(backend_stall_rate, - pmu_events.BackendStallCount() * normalization_factor); - } - } - result->latency = latency; - if (record_pmu) { - result->l1_refill_rate = l1_refill_rate; - result->l2_refill_rate = l2_refill_rate; - result->l3_refill_rate = l3_refill_rate; - result->l1tlb_refill_rate = l1tlb_refill_rate; - result->l2tlb_refill_rate = l2tlb_refill_rate; - result->mispred_rate = mispred_rate; - result->frontend_stall_rate = frontend_stall_rate; - result->backend_stall_rate = backend_stall_rate; - } - -#ifdef RUY_PROFILER - } - fflush(stdout); -#endif - - if (cold) { - lhs.matrix.data = orig_lhs_data; - rhs.matrix.data = orig_rhs_data; - memcpy(orig_dst_data, result->storage_matrix.matrix.data.get(), - StorageSize(result->storage_matrix.matrix)); - result->storage_matrix.matrix.data = orig_dst_data; - result->prepacked_lhs.data = orig_prepacked_lhs_data; - result->prepacked_rhs.data = orig_prepacked_rhs_data; - } -} - -template -void TestSet::Eval() { - RUY_CHECK_EQ(life_stage, LifeStage::kHasPrepackedMatrices); - for (auto& result : results) { - if (benchmark) { - Benchmark(result.get()); - } else { - EvalResult(result.get()); - } - } - life_stage = LifeStage::kEvaluated; -} - -template -std::string DumpRegion(const Matrix& matrix, int center_row, - int center_col) { - static constexpr int kRadius = 20; - int first_row = std::max(0, center_row - kRadius); - int last_row = std::min(matrix.layout.rows - 1, center_row + kRadius); - int first_col = std::max(0, center_col - kRadius); - int last_col = std::min(matrix.layout.cols - 1, center_col + kRadius); - std::ostringstream stream; - for (int row = first_row; row <= last_row; row++) { - for (int col = first_col; col <= last_col; col++) { - stream << static_cast(Element(matrix, row, col)) << " "; - } - stream << "\n"; - } - return stream.str(); -} - -template -void TestSet::VerifyTestResults() const { - const int depth = lhs.matrix.layout.cols; - for (int i = 0; i < results.size() - 1; i++) { - if 
(!Agree(*results[i], *results[i + 1], depth)) { - std::string paths_in_agreement; - paths_in_agreement.append(PathName(*results[0])); - for (int j = 1; j <= i; j++) { - paths_in_agreement.append(", "); - paths_in_agreement.append(PathName(*results[j])); - } - ErrorAnalysis error_analysis; - AnalyzeTestError(*this, i + 1, &error_analysis); - std::cerr << "Error: path (" << PathName(*results[i + 1]) - << ") disagrees with the other paths (" << paths_in_agreement - << "), which agree with each other." << std::endl; - std::cerr << "Shape: rows = " << rows << ", cols = " << cols - << ", depth = " << depth << std::endl; - std::cerr << "Stats of the good result matrix: " - << StatsAsString(error_analysis.stats_good) << std::endl; - std::cerr << "Stats of the bad result matrix: " - << StatsAsString(error_analysis.stats_bad) << std::endl; - if (error_analysis.error_rows.size() < rows) { - std::cerr << "Rows containing errors: " - << Join(error_analysis.error_rows) << std::endl; - } else { - std::cerr << "Errors found in ALL rows." << std::endl; - } - if (error_analysis.error_cols.size() < cols) { - std::cerr << "Cols containing errors: " - << Join(error_analysis.error_cols) << std::endl; - } else { - std::cerr << "Errors found in ALL cols." << std::endl; - } - std::cerr << "The first error occurs at row " - << error_analysis.row_of_first_error << ", col " - << error_analysis.col_of_first_error << std::endl; - std::cerr << "Good value: " << error_analysis.first_error_good_value - << std::endl; - std::cerr << "Bad value : " << error_analysis.first_error_bad_value - << std::endl; - std::cerr << "Region of Good result matrix around first error:\n\n" - << DumpRegion(results[0]->storage_matrix.matrix, - error_analysis.row_of_first_error, - error_analysis.col_of_first_error) - << std::endl; - std::cerr << "Region of Bad result matrix around first error:\n\n" - << DumpRegion(results[i + 1]->storage_matrix.matrix, - error_analysis.row_of_first_error, - error_analysis.col_of_first_error) - << std::endl; - RUY_CHECK(false); - } - } -} - -template -void TestSet::Verify() { - RUY_CHECK_EQ(life_stage, LifeStage::kEvaluated); - if (expected_outcome == ExpectedOutcome::kSuccess) { - VerifyTestResults(); - } - life_stage = LifeStage::kFinal; -} - -template -void TestRCC(int rows, int depth, int cols, ExpectedOutcome expected_outcome) { - TestSetType test_set; - test_set.rows = rows; - test_set.depth = depth; - test_set.cols = cols; - test_set.lhs_order = Order::kRowMajor; - test_set.rhs_order = Order::kColMajor; - test_set.dst_order = Order::kColMajor; - test_set.layout_style = LayoutStyle::kPackedLinear; - test_set.expected_outcome = expected_outcome; - test_set.Run(); -} - -template -void TestRCC(int rows, int depth, int cols) { - TestRCC(rows, depth, cols, ExpectedOutcome::kSuccess); -} - -template -void TestNonRCC(int rows, int depth, int cols, - ExpectedOutcome expected_outcome) { - TestSetType test_set; - test_set.rows = rows; - test_set.depth = depth; - test_set.cols = cols; - test_set.lhs_order = Order::kColMajor; - test_set.rhs_order = Order::kColMajor; - test_set.dst_order = Order::kColMajor; - test_set.layout_style = LayoutStyle::kPackedLinear; - test_set.expected_outcome = expected_outcome; - test_set.Run(); -} - -template -void TestLinearAllOrders(int rows, int depth, int cols, - ExpectedOutcome expected_outcome) { - const std::vector orders{Order::kColMajor, Order::kRowMajor}; - - for (Order lhs_order : orders) { - for (Order rhs_order : orders) { - for (Order dst_order : orders) { - TestSetType 
test_set; - test_set.rows = rows; - test_set.depth = depth; - test_set.cols = cols; - test_set.lhs_order = lhs_order; - test_set.rhs_order = rhs_order; - test_set.dst_order = dst_order; - test_set.layout_style = LayoutStyle::kLinear; - test_set.expected_outcome = expected_outcome; - test_set.Run(); - } - } - } -} - -template -void TestLinearAllOrders(int rows, int depth, int cols) { - TestLinearAllOrders(rows, depth, cols, - ExpectedOutcome::kSuccess); -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_TEST_H_ diff --git a/test_fast.cc b/test_fast.cc deleted file mode 100644 index 610fc1b..0000000 --- a/test_fast.cc +++ /dev/null @@ -1,110 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This test contains cheap test cases, completes in a few seconds. - -#include - -#include "test.h" - -namespace ruy { - -using LhsScalar = RUY_TEST_LHSSCALAR; -using RhsScalar = RUY_TEST_RHSSCALAR; -using AccumScalar = RUY_TEST_ACCUMSCALAR; -using DstScalar = RUY_TEST_DSTSCALAR; - -using TestSetType = - TestSet>; - -TEST(RuyTest, TestSquareMuls) { - const std::vector sizes{ - // small sizes - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - // multiplies of 16 - 16, - 32, - 48, - 64, - // pot-minus-1 sizes - 15, - 31, - 63, - // pot-plus-1 sizes - 17, - 33, - 65, - }; - - for (int size : sizes) { - TestRCC(size, size, size); - TestLinearAllOrders(size, size, size); - } -} - -TEST(RuyTest, TestMiscMuls) { - const int shapes[][3] = { - {2, 3, 4}, {7, 6, 5}, {12, 23, 6}, {19, 3, 11}, {3, 10, 17}, - {30, 21, 43}, {7, 57, 9}, {49, 69, 71}, {38, 111, 29}, {87, 98, 76}, - {16, 96, 16}, {16, 88, 16}, {16, 84, 16}, {16, 92, 16}, {16, 82, 16}, - {16, 81, 16}, {16, 95, 16}, {3, 128, 5}}; - for (const auto& shape : shapes) { - TestLinearAllOrders(shape[0], shape[1], shape[2]); - } -} - -TEST(RuyTest, TestDeepMuls) { - // TODO(b/137649322): clarify what's the max allowed matrix size. - TestRCC(1, 32767, 1); - TestLinearAllOrders(5, 5001, 4); - TestLinearAllOrders(9, 1025, 10); -} - -TEST(RuyTest, TestShallowMuls) { - TestLinearAllOrders(101, 1, 103); - TestLinearAllOrders(71, 2, 53); - TestLinearAllOrders(51, 3, 73); - TestLinearAllOrders(51, 4, 43); -} - -TEST(RuyTest, TestNarrowMuls) { - for (int width : {1, 2, 3, 4, 5, 8}) { - TestLinearAllOrders(width, 12, 13); - TestLinearAllOrders(15, 19, width); - TestLinearAllOrders(width, 123, 137); - TestLinearAllOrders(158, 119, width); - } -} - -TEST(RuyTest, TestGEMV) { - for (int size = 1; size < 1024; size *= 2) { - for (int depth = 1; depth < 500; depth += 47) { - TestLinearAllOrders(size, depth, 1); - } - } - TestLinearAllOrders(5, 5001, 1); - TestLinearAllOrders(8193, 17, 1); -} - -} // namespace ruy diff --git a/test_slow.cc b/test_slow.cc deleted file mode 100644 index 1f3c6bf..0000000 --- a/test_slow.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This test contains more expensive test cases. - -#include "test.h" - -namespace ruy { - -using LhsScalar = RUY_TEST_LHSSCALAR; -using RhsScalar = RUY_TEST_RHSSCALAR; -using AccumScalar = RUY_TEST_ACCUMSCALAR; -using DstScalar = RUY_TEST_DSTSCALAR; - -using TestSetType = - TestSet>; - -TEST(RuyTest, TestBigNarrowMuls) { - for (int width : {1, 2, 3, 4, 5, 8}) { - TestRCC(width, 401, 601); - TestRCC(587, 443, width); - } - TestRCC(7, 45984, - 5); // Large enough to trigger row-sum overflows. - TestRCC(512, 256, 16); -} - -TEST(RuyTest, TestBigShallowMuls) { - TestLinearAllOrders(501, 1, 321); - TestLinearAllOrders(301, 5, 403); - TestLinearAllOrders(256, 32, 512); -} - -TEST(RuyTest, TestBigMuls) { - TestRCC(225, 303, 199); - TestLinearAllOrders(256, 192, 128); -} - -TEST(RuyTest, TestBigPowerOfTwoDepthWithAvoidAliasing) { - // Important to test some power-of-two depths: that's when the - // RUY_AVOID_ALIASING optimization kicks in and makes packed matrices - // strided, exposing bugs in kernels mixing up size and stride. - // Moreover, it's important that the test matrices be sufficiently wide - // that they will result in multiple blocks, exposing bugs in the - // computation of the base address of each block. - TestLinearAllOrders(70, 1024, 80); - TestLinearAllOrders(60, 2048, 70); - TestLinearAllOrders(40, 4096, 50); -} - -TEST(RuyTest, TestGEMV) { - for (int size = 1025; size <= 1409; size += 384) { - for (int depth = 350; depth < 500; depth += 47) { - TestLinearAllOrders(size, depth, 1); - } - } -} - -} // namespace ruy diff --git a/test_special_specs.cc b/test_special_specs.cc deleted file mode 100644 index 41e6e51..0000000 --- a/test_special_specs.cc +++ /dev/null @@ -1,163 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This test covers non-basic specs. 
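// Concretely, the specs exercised in this file override BasicSpec's
// customization points one at a time: the loop structure (kLoopStructure),
// zero-point support (kZeroPointSupport), layout support
// (LayoutSupport::kRCC), and the standard C++ kernel's LHS/RHS kernel layouts
// together with the data-cache-size hints.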
- -#include "test.h" - -namespace ruy { - -template -struct LoopStructureSpec : BasicSpec { - static constexpr LoopStructure kLoopStructure = tLoopStructure; -}; - -template -struct ZeroPointSupportSpec : BasicSpec { - static constexpr ZeroPointSupport kZeroPointSupport = tZeroPointSupport; -}; - -template -struct RCCSpec : BasicSpec { - static constexpr LayoutSupport kLayoutSupport = LayoutSupport::kRCC; -}; - -template -struct StandardCppKernelLayoutSpec : BasicSpec { - using StandardCppKernelLhsLayout = LhsKernelLayout; - using StandardCppKernelRhsLayout = RhsKernelLayout; - static int local_data_cache_size() { return 1; } - static int shared_data_cache_size() { return 1; } -}; - -using LhsScalar = RUY_TEST_LHSSCALAR; -using RhsScalar = RUY_TEST_RHSSCALAR; -using AccumScalar = RUY_TEST_ACCUMSCALAR; -using DstScalar = RUY_TEST_DSTSCALAR; - -template -void TestLoopStructure() { - using SpecType = LoopStructureSpec; - using TestSetType = TestSet; - for (int size = 1; size < 10; size++) { - TestLinearAllOrders(size, size, size); - } - TestLinearAllOrders(3, 5, 78); - TestLinearAllOrders(19, 91, 7); - TestLinearAllOrders(71, 26, 44); - TestLinearAllOrders(81, 93, 72); -} - -TEST(TestSpecialSpecs, LoopStructure) { - static_assert(BasicSpec::kLoopStructure == - LoopStructure::kAuto, - ""); - static_assert(BasicSpec::kLoopStructure == LoopStructure::kAuto, - ""); - TestLoopStructure(); - TestLoopStructure(); -} - -template -void TestZeroPointSupport(LhsScalar lhs_zero_point, RhsScalar rhs_zero_point, - DstScalar dst_zero_point, - ExpectedOutcome expected_outcome) { - using SpecType = - ZeroPointSupportSpec; - using TestSetType = TestSet; - TestSetType test_set; - test_set.rows = 11; - test_set.depth = 12; - test_set.cols = 13; - test_set.lhs_order = Order::kRowMajor; - test_set.rhs_order = Order::kColMajor; - test_set.dst_order = Order::kColMajor; - test_set.layout_style = LayoutStyle::kPackedLinear; - test_set.expected_outcome = expected_outcome; - test_set.lhs_zero_point = lhs_zero_point; - test_set.rhs_zero_point = rhs_zero_point; - test_set.dst_zero_point = dst_zero_point; - test_set.use_specified_zero_points = true; - test_set.Run(); -} - -TEST(TestSpecialSpecs, ZeroPointSupport) { - // Sanity check - RUY_CHECK_EQ(SymmetricZeroPoint(), 128); - RUY_CHECK_EQ(SymmetricZeroPoint(), 0); - - if (std::is_floating_point::value) { - return; - } - - TestZeroPointSupport( - SymmetricZeroPoint(), SymmetricZeroPoint(), - SymmetricZeroPoint(), ExpectedOutcome::kSuccess); - TestZeroPointSupport( - SymmetricZeroPoint() - 1, SymmetricZeroPoint(), - SymmetricZeroPoint(), ExpectedOutcome::kSuccess); - TestZeroPointSupport( - SymmetricZeroPoint(), SymmetricZeroPoint(), - SymmetricZeroPoint(), ExpectedOutcome::kSuccess); - TestZeroPointSupport( - SymmetricZeroPoint() + 1, SymmetricZeroPoint(), - SymmetricZeroPoint(), ExpectedOutcome::kDeath); - TestZeroPointSupport( - SymmetricZeroPoint(), SymmetricZeroPoint() + 1, - SymmetricZeroPoint(), ExpectedOutcome::kDeath); - TestZeroPointSupport( - SymmetricZeroPoint(), SymmetricZeroPoint(), - SymmetricZeroPoint() - 1, ExpectedOutcome::kDeath); -} - -TEST(TestSpecialSpecs, RCC) { - using RCCSpec = RCCSpec; - using RCCTestSet = TestSet; - TestRCC(81, 93, 72); - TestNonRCC(81, 93, 72, ExpectedOutcome::kDeath); -} - -template -void TestStandardCppKernelLayout() { - using SpecType = - StandardCppKernelLayoutSpec; - using TestSetType = TestSet; - for (int size = 1; size < 10; size++) { - TestLinearAllOrders(size, size, size); - } - TestLinearAllOrders(87, 34, 56); - 
TestLinearAllOrders(123, 234, 78); -} - -TEST(TestSpecialSpecs, StandardCppKernelLayoutTrivial1x1) { - TestStandardCppKernelLayout, - FixedKernelLayout>(); -} - -TEST(TestSpecialSpecs, StandardCppKernelLayoutSquare4x4) { - TestStandardCppKernelLayout, - FixedKernelLayout>(); -} - -TEST(TestSpecialSpecs, StandardCppKernelLayoutRectangular4x8) { - TestStandardCppKernelLayout, - FixedKernelLayout>(); -} - -} // namespace ruy diff --git a/thread_pool.cc b/thread_pool.cc deleted file mode 100644 index f5c53dd..0000000 --- a/thread_pool.cc +++ /dev/null @@ -1,200 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "thread_pool.h" - -#include -#include // NOLINT(build/c++11) -#include // NOLINT(build/c++11) -#include -#include -#include -#include // NOLINT(build/c++11) -#include // NOLINT(build/c++11) - -#include "check_macros.h" -#include "wait.h" - -namespace ruy { - -// A worker thread. -class Thread { - public: - enum class State { - Startup, // The initial state before the thread main loop runs. - Ready, // Is not working, has not yet received new work to do. - HasWork, // Has work to do. - ExitAsSoonAsPossible // Should exit at earliest convenience. - }; - - explicit Thread(BlockingCounter* counter_to_decrement_when_ready) - : task_(nullptr), - state_(State::Startup), - counter_to_decrement_when_ready_(counter_to_decrement_when_ready) { - thread_.reset(new std::thread(ThreadFunc, this)); - } - - ~Thread() { - ChangeState(State::ExitAsSoonAsPossible); - thread_->join(); - } - - // Changes State; may be called from either the worker thread - // or the master thread; however, not all state transitions are legal, - // which is guarded by assertions. - // - // The Task argument is to be used only with new_state==HasWork. - // It specifies the Task being handed to this Thread. - void ChangeState(State new_state, Task* task = nullptr) { - state_mutex_.lock(); - State old_state = state_.load(std::memory_order_relaxed); - RUY_DCHECK_NE(old_state, new_state); - switch (old_state) { - case State::Startup: - RUY_DCHECK_EQ(new_state, State::Ready); - break; - case State::Ready: - RUY_DCHECK(new_state == State::HasWork || - new_state == State::ExitAsSoonAsPossible); - break; - case State::HasWork: - RUY_DCHECK(new_state == State::Ready || - new_state == State::ExitAsSoonAsPossible); - break; - default: - abort(); - } - switch (new_state) { - case State::Ready: - if (task_) { - // Doing work is part of reverting to 'ready' state. 
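// (In other words, there is no separate "working" state: the HasWork -> Ready
//  transition is where the task actually runs, and only once the thread is
//  back in the Ready state does it decrement the master's BlockingCounter at
//  the end of this function.)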
- task_->Run(); - task_ = nullptr; - } - break; - case State::HasWork: - RUY_DCHECK(!task_); - task_ = task; - break; - default: - break; - } - state_.store(new_state, std::memory_order_relaxed); - state_cond_.notify_all(); - state_mutex_.unlock(); - if (new_state == State::Ready) { - counter_to_decrement_when_ready_->DecrementCount(); - } - } - - static void ThreadFunc(Thread* arg) { arg->ThreadFuncImpl(); } - - // Called by the master thead to give this thread work to do. - void StartWork(Task* task) { ChangeState(State::HasWork, task); } - - private: - // Thread entry point. - void ThreadFuncImpl() { - ChangeState(State::Ready); - - // Thread main loop - while (true) { - // In the 'Ready' state, we have nothing to do but to wait until - // we switch to another state. - const auto& condition = [this]() { - return state_.load(std::memory_order_acquire) != State::Ready; - }; - Wait(condition, &state_cond_, &state_mutex_); - - // Act on new state. - switch (state_.load(std::memory_order_acquire)) { - case State::HasWork: - // Got work to do! So do it, and then revert to 'Ready' state. - ChangeState(State::Ready); - break; - case State::ExitAsSoonAsPossible: - return; - default: - abort(); - } - } - } - - // The underlying thread. - std::unique_ptr thread_; - - // The task to be worked on. - Task* task_; - - // The condition variable and mutex guarding state changes. - std::condition_variable state_cond_; - std::mutex state_mutex_; - - // The state enum tells if we're currently working, waiting for work, etc. - // Its concurrent accesses by the thread and main threads are guarded by - // state_mutex_, and can thus use memory_order_relaxed. This still needs - // to be a std::atomic because we use WaitForVariableChange. - std::atomic state_; - - // pointer to the master's thread BlockingCounter object, to notify the - // master thread of when this thread switches to the 'Ready' state. - BlockingCounter* const counter_to_decrement_when_ready_; -}; - -void ThreadPool::ExecuteImpl(int task_count, int stride, Task* tasks) { - RUY_DCHECK_GE(task_count, 1); - - // Case of 1 thread: just run the single task on the current thread. - if (task_count == 1) { - (tasks + 0)->Run(); - return; - } - - // Task #0 will be run on the current thread. - CreateThreads(task_count - 1); - counter_to_decrement_when_ready_.Reset(task_count - 1); - for (int i = 1; i < task_count; i++) { - auto task_address = reinterpret_cast(tasks) + i * stride; - threads_[i - 1]->StartWork(reinterpret_cast(task_address)); - } - - // Execute task #0 immediately on the current thread. - (tasks + 0)->Run(); - - // Wait for the threads submitted above to finish. - counter_to_decrement_when_ready_.Wait(); -} - -// Ensures that the pool has at least the given count of threads. -// If any new thread has to be created, this function waits for it to -// be ready. -void ThreadPool::CreateThreads(int threads_count) { - if (threads_.size() >= threads_count) { - return; - } - counter_to_decrement_when_ready_.Reset(threads_count - threads_.size()); - while (threads_.size() < threads_count) { - threads_.push_back(new Thread(&counter_to_decrement_when_ready_)); - } - counter_to_decrement_when_ready_.Wait(); -} - -ThreadPool::~ThreadPool() { - for (auto w : threads_) { - delete w; - } -} - -} // end namespace ruy diff --git a/thread_pool.h b/thread_pool.h deleted file mode 100644 index 8e2d141..0000000 --- a/thread_pool.h +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file is a fork of gemmlowp's multi_thread_gemm.h, under Apache 2.0 -// license. - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_THREAD_POOL_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_THREAD_POOL_H_ - -#include - -#include "blocking_counter.h" - -namespace ruy { - -// A workload for a thread. -struct Task { - virtual ~Task() {} - virtual void Run() = 0; -}; - -class Thread; - -// A simple pool of threads, that only allows the very -// specific parallelization pattern that we use here: -// One thread, which we call the 'main thread', calls Execute, distributing -// a Task each to N threads, being N-1 'worker threads' and the main thread -// itself. After the main thread has completed its own Task, it waits for -// the worker threads to have all completed. That is the only synchronization -// performed by this ThreadPool. -// -// In particular, there is a naive 1:1 mapping of Tasks to threads. -// This ThreadPool considers it outside of its own scope to try to work -// with fewer threads than there are Tasks. The idea is that such N:M mappings -// of tasks to threads can be implemented as a higher-level feature on top of -// the present low-level 1:1 threadpool. For example, a user might have a -// Task subclass referencing a shared atomic counter indexing into a vector of -// finer-granularity subtasks. Different threads would then concurrently -// increment this atomic counter, getting each their own subtasks to work on. -// That approach is the one used in ruy's multi-thread matrix multiplication -// implementation --- see ruy's TrMulTask. -class ThreadPool { - public: - ThreadPool() {} - - ~ThreadPool(); - - // Executes task_count tasks on task_count threads. - // Grows the threadpool as needed to have at least (task_count-1) threads. - // The 0-th task is run on the thread on which Execute is called: that - // is by definition what we call the "main thread". Synchronization of all - // threads is performed before this function returns. - // - // As explained in the class comment, there is a 1:1 mapping of tasks to - // threads. If you need something smarter than that, for instance if you - // want to run an unbounded number of tasks on a bounded number of threads, - // then you need something higher-level than this ThreadPool, that can - // be layered on top of it by appropriately subclassing Tasks. - // - // TaskType must be a subclass of ruy::Task. That is implicitly guarded by - // the static_cast in this inline implementation. - template - void Execute(int task_count, TaskType* tasks) { - ExecuteImpl(task_count, sizeof(TaskType), static_cast(tasks)); - } - - private: - // Ensures that the pool has at least the given count of threads. - // If any new thread has to be created, this function waits for it to - // be ready. - void CreateThreads(int threads_count); - - // Non-templatized implementation of the public Execute method. 
- // See the inline implementation of Execute for how this is used. - void ExecuteImpl(int task_count, int stride, Task* tasks); - - // copy construction disallowed - ThreadPool(const ThreadPool&) = delete; - - // The threads in this pool. They are owned by the pool: - // the pool creates threads and destroys them in its destructor. - std::vector threads_; - - // The BlockingCounter used to wait for the threads. - BlockingCounter counter_to_decrement_when_ready_; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_THREAD_POOL_H_ diff --git a/time.h b/time.h deleted file mode 100644 index d96ed34..0000000 --- a/time.h +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TIME_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TIME_H_ - -#include // NOLINT(build/c++11) -#include // IWYU pragma: keep -#include // NOLINT(build/c++11) - -#ifdef __linux__ -#include -// IWYU pragma: no_include - -#include -#endif - -namespace ruy { - -using InternalDefaultClock = std::chrono::steady_clock; - -using TimePoint = InternalDefaultClock::time_point; -using Duration = InternalDefaultClock::duration; - -template -Duration DurationFromSeconds(RepresentationType representation) { - return std::chrono::duration_cast( - std::chrono::duration(representation)); -} - -template -Duration DurationFromMilliseconds(RepresentationType representation) { - return std::chrono::duration_cast( - std::chrono::duration(representation)); -} - -template -Duration DurationFromNanoseconds(RepresentationType representation) { - return std::chrono::duration_cast( - std::chrono::duration(representation)); -} - -inline float ToFloatSeconds(const Duration& duration) { - return std::chrono::duration_cast>(duration) - .count(); -} - -inline std::int64_t ToInt64Nanoseconds(const Duration& duration) { - return std::chrono::duration_cast< - std::chrono::duration>(duration) - .count(); -} - -inline TimePoint Now() { return InternalDefaultClock::now(); } - -inline TimePoint CoarseNow() { -#ifdef __linux__ - timespec t; - clock_gettime(CLOCK_MONOTONIC_COARSE, &t); - return TimePoint( - DurationFromNanoseconds(1000000000LL * t.tv_sec + t.tv_nsec)); -#else - return Now(); -#endif -} - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_TIME_H_ diff --git a/trace.cc b/trace.cc deleted file mode 100644 index c11fe9b..0000000 --- a/trace.cc +++ /dev/null @@ -1,325 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "trace.h" - -#include -#include // IWYU pragma: keep -#include -#include -#include -#include - -#include "check_macros.h" -#include "side_pair.h" -#include "time.h" - -namespace ruy { - -#ifdef RUY_TRACE - -enum class TraceEvent : std::uint8_t { - kNone, - kThreadStart, - kThreadLoopStart, - kThreadEnd, - kBlockReserved, - kBlockPackedLhs, - kBlockPackedRhs, - kBlockFinished -}; - -struct TraceEntry { - TimePoint time_point; - TraceEvent event; - // ruy-internal thread id i.e. contiguous index into array of threads, - // with 0 designating the main thread. - std::uint16_t thread_id = 0; - // Additional parameters whose meaning depends on the 'event' type. - std::uint32_t params[1]; -}; - -struct Trace { - BlockMap block_map; - // During recording, to avoid having to use locks or atomics, we let - // each thread append to its own specific vector. - std::vector> thread_specific_entries; - // Global vector of entries into which we coalesce thread_specific_entries - // after recording is finished, when dumping a trace. See - // AggregateThreadSpecificEntries. - std::vector entries; - TimePoint time_start; - TimePoint time_execute; - TimePoint time_end; -}; - -namespace { - -// Coalesce Trace::thread_specific_entries into Trace::entries. -void AggregateThreadSpecificEntries(Trace* trace) { - RUY_CHECK(trace->entries.empty()); - for (auto& thread_specific_entries_vector : trace->thread_specific_entries) { - for (const TraceEntry& entry : thread_specific_entries_vector) { - trace->entries.push_back(entry); - } - thread_specific_entries_vector.clear(); - } -} - -// Sort Trace::entries by ascending time. In case of equal timepoints, -// sort by some semi-arbitrary ordering of event types. -void Sort(Trace* trace) { - std::sort(std::begin(trace->entries), std::end(trace->entries), - [](const TraceEntry& a, const TraceEntry& b) -> bool { - return a.time_point < b.time_point || - (a.time_point == b.time_point && - static_cast(a.event) < static_cast(b.event)); - }); -} - -// Dump a trace. Assumes that AggregateThreadSpecificEntries and Sort have -// already been called on it. -// -// On some architectures long long ints are not same as std::int64_t, and -// time is printed as %lld, so static_casts are necessary. -void Dump(const Trace& trace) { - const char* trace_filename = getenv("RUY_TRACE_FILE"); - FILE* trace_file = trace_filename ? 
fopen(trace_filename, "w") : stderr; - if (!trace_file) { - fprintf(stderr, "Failed to open %s for write, errno=%d\n", trace_filename, - errno); - RUY_CHECK(false); - } - fprintf(trace_file, "thread_count:%d\n", trace.block_map.thread_count); - fprintf(trace_file, "rows:%d\n", trace.block_map.dims[Side::kLhs]); - fprintf(trace_file, "cols:%d\n", trace.block_map.dims[Side::kRhs]); - fprintf(trace_file, "Execute: %lld\n", - static_cast( - ToInt64Nanoseconds(trace.time_execute - trace.time_start))); - for (const TraceEntry& entry : trace.entries) { - long long int time = static_cast( - ToInt64Nanoseconds(entry.time_point - trace.time_start)); - switch (entry.event) { - case TraceEvent::kThreadStart: - fprintf(trace_file, "ThreadStart: %lld, %d\n", time, entry.thread_id); - break; - case TraceEvent::kThreadLoopStart: - fprintf(trace_file, "ThreadLoopStart: %lld, %d\n", time, - entry.thread_id); - break; - case TraceEvent::kThreadEnd: - fprintf(trace_file, "ThreadEnd: %lld, %d\n", time, entry.thread_id); - break; - case TraceEvent::kBlockReserved: { - std::uint32_t block_id = entry.params[0]; - SidePair block; - GetBlockByIndex(trace.block_map, block_id, &block); - SidePair start, end; - GetBlockMatrixCoords(trace.block_map, block, &start, &end); - fprintf(trace_file, - "BlockReserved: %lld, %d, %d, %d, %d, %d, %d, %d, %d\n", time, - entry.thread_id, block_id, block[Side::kLhs], block[Side::kRhs], - start[Side::kLhs], start[Side::kRhs], end[Side::kLhs], - end[Side::kRhs]); - break; - } - case TraceEvent::kBlockPackedLhs: { - std::uint32_t block = entry.params[0]; - int start, end; - GetBlockMatrixCoords(Side::kLhs, trace.block_map, block, &start, &end); - fprintf(trace_file, "BlockPackedLhs: %lld, %d, %d, %d, %d\n", time, - entry.thread_id, block, start, end); - break; - } - case TraceEvent::kBlockPackedRhs: { - std::uint32_t block = entry.params[0]; - int start, end; - GetBlockMatrixCoords(Side::kRhs, trace.block_map, block, &start, &end); - fprintf(trace_file, "BlockPackedRhs: %lld, %d, %d, %d, %d\n", time, - entry.thread_id, block, start, end); - break; - } - case TraceEvent::kBlockFinished: { - std::uint32_t block_id = entry.params[0]; - SidePair block; - GetBlockByIndex(trace.block_map, block_id, &block); - fprintf(trace_file, "BlockFinished: %lld, %d, %d, %d, %d\n", time, - entry.thread_id, block_id, block[Side::kLhs], - block[Side::kRhs]); - break; - } - default: - RUY_CHECK(false); - } - } - fprintf(trace_file, "End: %lld\n", - static_cast( - ToInt64Nanoseconds(trace.time_end - trace.time_start))); - if (trace_filename) { - fclose(trace_file); - } -} - -} // anonymous namespace - -// Get a Trace object to record to, or null of tracing is not enabled. 
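// For reference, the tracing hooks in this file are controlled entirely by
// environment variables read below and in Dump() above: RUY_TRACE enables
// recording, RUY_TRACE_FILE redirects the dump from stderr to a file, and
// RUY_TRACE_FILTER_ROWS / RUY_TRACE_FILTER_DEPTH / RUY_TRACE_FILTER_COLS
// restrict tracing to matrix multiplications of one specific shape -- e.g.
// setting RUY_TRACE=1 RUY_TRACE_FILE=/tmp/ruy_trace.txt in the environment of
// whatever binary runs the multiplications.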
-Trace* NewTraceOrNull(TracingContext* tracing, int rows, int depth, int cols) { - if (!tracing->initialized) { - tracing->initialized = true; - tracing->enabled = getenv("RUY_TRACE"); - if (!tracing->enabled) { - return nullptr; - } - if (getenv("RUY_TRACE_FILTER_ROWS")) { - tracing->filter_shape_rows = std::stoi(getenv("RUY_TRACE_FILTER_ROWS")); - } - if (getenv("RUY_TRACE_FILTER_DEPTH")) { - tracing->filter_shape_depth = std::stoi(getenv("RUY_TRACE_FILTER_DEPTH")); - } - if (getenv("RUY_TRACE_FILTER_COLS")) { - tracing->filter_shape_cols = std::stoi(getenv("RUY_TRACE_FILTER_COLS")); - } - } - if (!tracing->enabled) { - return nullptr; - } - if (tracing->filter_shape_rows && rows != tracing->filter_shape_rows) { - return nullptr; - } - if (tracing->filter_shape_depth && depth != tracing->filter_shape_depth) { - return nullptr; - } - if (tracing->filter_shape_cols && cols != tracing->filter_shape_cols) { - return nullptr; - } - // Delete any existing trace. - delete tracing->trace; - // Create a new one. - tracing->trace = new Trace; - return tracing->trace; -} - -// The trace recorded on a context is finalized and dumped by -// this TracingContext destructor. -// -// The idea of dumping on context destructor is that typically one wants to -// run many matrix multiplications, e.g. to hit a steady state in terms of -// performance characteristics, but only trace the last repetition of the -// workload, when that steady state was attained. -TracingContext::~TracingContext() { - if (trace) { - AggregateThreadSpecificEntries(trace); - Sort(trace); - Dump(*trace); - } - delete trace; -} - -void TraceRecordStart(Trace* trace) { - if (trace) { - trace->time_start = Now(); - } -} - -void TraceRecordExecute(const BlockMap& block_map, Trace* trace) { - if (trace) { - trace->time_execute = Now(); - trace->block_map = block_map; - trace->thread_specific_entries.resize(block_map.thread_count); - for (int thread = 0; thread < block_map.thread_count; thread++) { - trace->thread_specific_entries[thread].clear(); - // Reserve some large size to avoid frequent heap allocations - // affecting the recorded timings. - trace->thread_specific_entries[thread].reserve(16384); - } - } -} - -void TraceRecordEnd(Trace* trace) { - if (trace) { - trace->time_end = Now(); - } -} - -void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = TraceEvent::kThreadStart; - entry.time_point = Now(); - entry.thread_id = thread_id; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = TraceEvent::kThreadLoopStart; - entry.time_point = Now(); - entry.thread_id = thread_id; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id, - Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = TraceEvent::kBlockReserved; - entry.time_point = Now(); - entry.thread_id = thread_id; - entry.params[0] = block_id; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block, - Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = side == Side::kLhs ? 
TraceEvent::kBlockPackedLhs - : TraceEvent::kBlockPackedRhs; - entry.time_point = Now(); - entry.thread_id = thread_id; - entry.params[0] = block; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id, - Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = TraceEvent::kBlockFinished; - entry.time_point = Now(); - entry.thread_id = thread_id; - entry.params[0] = block_id; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace) { - if (trace) { - TraceEntry entry; - entry.event = TraceEvent::kThreadEnd; - entry.time_point = Now(); - entry.thread_id = thread_id; - trace->thread_specific_entries[thread_id].push_back(entry); - } -} - -#endif - -} // namespace ruy diff --git a/trace.h b/trace.h deleted file mode 100644 index 144065c..0000000 --- a/trace.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRACE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRACE_H_ - -#include - -#include "block_map.h" -#include "side_pair.h" - -namespace ruy { - -struct Trace; - -#ifdef RUY_TRACE - -struct TracingContext { - bool initialized = false; - bool enabled = false; - int filter_shape_rows = 0; - int filter_shape_cols = 0; - int filter_shape_depth = 0; - Trace* trace = nullptr; - ~TracingContext(); -}; - -Trace* NewTraceOrNull(TracingContext* context, int rows, int depth, int cols); -void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace); -void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace); -void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id, - Trace* trace); -void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block, - Trace* trace); -void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id, - Trace* trace); -void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace); -void TraceRecordStart(Trace* trace); -void TraceRecordExecute(const BlockMap& block_map, Trace* trace); -void TraceRecordEnd(Trace* trace); - -#else - -struct TracingContext {}; - -inline Trace* NewTraceOrNull(TracingContext*, int, int, int) { return nullptr; } -inline void TraceRecordThreadStart(std::uint32_t, Trace*) {} -inline void TraceRecordThreadLoopStart(std::uint32_t, Trace*) {} -inline void TraceRecordBlockReserved(std::uint32_t, std::uint32_t, Trace*) {} -inline void TraceRecordBlockPacked(std::uint32_t, Side, int, Trace*) {} -inline void TraceRecordBlockFinished(std::uint32_t, std::uint32_t, Trace*) {} -inline void TraceRecordThreadEnd(std::uint32_t, Trace*) {} -inline void TraceRecordStart(Trace*) {} -inline void TraceRecordExecute(const BlockMap&, Trace*) {} -inline void TraceRecordEnd(Trace*) {} - -#endif - -} // namespace ruy - -#endif // 
TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRACE_H_ diff --git a/trmul.cc b/trmul.cc deleted file mode 100644 index 48ac44f..0000000 --- a/trmul.cc +++ /dev/null @@ -1,401 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "trmul.h" - -#include -#include -#include -#include -#include - -#include "allocator.h" -#include "block_map.h" -#include "check_macros.h" -#include "common.h" -#include "internal_matrix.h" -#include "matrix.h" -#include "opt_set.h" -#include "profiler/instrumentation.h" -#include "side_pair.h" -#include "size_util.h" -#include "spec.h" -#include "thread_pool.h" -#include "trace.h" -#include "tune.h" - -namespace ruy { - -namespace { - -enum class PackingStatus : std::uint8_t { kNotStarted, kInProgress, kFinished }; - -struct TrMulTask final : Task { - TrMulTask(TrMulParams* params_, const BlockMap& block_map_, - std::atomic* atomic_block_id_, int thread_id_, - bool need_atomics_, - SidePair*> packing_status_, - TuningResolver* tuning_resolver_, Allocator* local_allocator_, - Trace* trace_) - : params(params_), - block_map(block_map_), - atomic_block_id(atomic_block_id_), - thread_id(thread_id_), - need_atomics(need_atomics_), - packing_status(packing_status_), - tuning_resolver(tuning_resolver_), - local_allocator(local_allocator_), - trace(trace_), - local_packed{nullptr, nullptr} {} - - void Run() override { - TraceRecordThreadStart(thread_id, trace); - - for (Side side : {Side::kLhs, Side::kRhs}) { - if (!params->is_prepacked[side]) { - const int size = NumBlocksPerSide(side, block_map); - local_allocator->Allocate(size, &local_packed[side]); - memset(local_packed[side], 0, size * sizeof(bool)); - } - } - - const int num_blocks = NumBlocks(block_map); - - const Tuning tuning = tuning_resolver->Resolve(); - - TraceRecordThreadLoopStart(thread_id, trace); - - SidePair block; - SidePair start; - SidePair end; - - // Each thread starts by initially reserving the block whose id - // is the thread id. - int block_id = thread_id; - TraceRecordBlockReserved(thread_id, block_id, trace); - - while (block_id < num_blocks) { - // Reserve the next block to handle. In order to hide the latency - // (typically comparable to an access to the level of data cache that - // is shared among CPU cores, e.g. 60 cycles on an ARM CPU as of 2019) - // of this atomic operation, we structure this code so as to avoid - // immediately depending on the `next_n` result. - const int next_block_id = - atomic_block_id->fetch_add(1, std::memory_order_relaxed); - TraceRecordBlockReserved(thread_id, next_block_id, trace); - // Get coordinates of the current block to handle, in "block space". - GetBlockByIndex(block_map, block_id, &block); - // Get coordinates of the current block to handle, in matrix space. - GetBlockMatrixCoords(block_map, block, &start, &end); - // Maybe pack the current LHS/RHS block, if not already packed. 
- EnsurePacked(block, start, end, tuning); - // Actually do matrix multiplication work - params->RunKernel(tuning, start, end); - TraceRecordBlockFinished(thread_id, block_id, trace); - // Move on to the next block as obtained by the atomic increment - // at the start of this while loop iteration. - block_id = next_block_id; - } - - local_allocator->FreeAll(); - - TraceRecordThreadEnd(thread_id, trace); - } - - private: - // Tries to pack a block, without blocking. - // If the block was already packed, returns true. - // If the block was not started packing, packs it and returns true. - // If the block was being packed by another thread, returns false. - bool TryPack(Side side, int block, int start, int end, Tuning tuning) { - if (params->is_prepacked[side]) { - return true; - } - if (!local_packed[side][block]) { - if (need_atomics) { - // Explanation of this compare_exchange_strong operation: - // This atomically performs all of the following: - // 1. Read `status` with "acquire" memory order. - // * That this read uses "acquire" is because both memory orders - // specified have "acquire" as their read-component. - // 2. Compare (bitwise) with `exchanged_status`. - // 3. If equal, stores the value kInProgress to `status` with "release" - // memory order, and returns true, so we take this 'if' branch. - // * That this store uses "release" is because of the _rel part in - // memory_order_acq_rel passed as the first memory order argument. - // 4. If not equal, stores the loaded value of `status` to - // `exchanged_status` with "relaxed" semantics, and returns false, - // so we take the 'else' branch. - // * That this store uses "relaxed" is because the second memory - // order argument, memory_order_acquire, implies no particular - // store semantics. "relaxed" is acceptable here because this - // stores to a local stack variable. - // - // Rationale for compare_exchange_strong as opposed to - // compare_exchange_weak: - // The spurious-failure case with compare_exchange_weak will actually - // happen a lot here, because the atomic 'status' bytes are stored - // contiguously in arrays and neighboring values will be accessed - // by multiple threads concurrently. On a typical ARM CPU, an exclusives - // reservation granule is 64 bytes, so a lot of false-sharing may - // happen. Using compare_exchange_weak would thus result in often having - // TryPack return 'false' when it could instead have done the packing - // work and returned 'true'. Heuristically, that is not a good thing. - // Moreover, this changes the TryPack contract, loosening it and making - // it harder for the caller to reason about. Finally, the overhead of - // atomic operations is mitigated by the enclosing check on - // local_packed, so maybe the overhead of compare_exchange_strong isn't - // such a problem. But we don't really know for sure, that would be - // interesting to experiment more with. - PackingStatus exchanged_status = PackingStatus::kNotStarted; - std::atomic& status = packing_status[side][block]; - if (status.compare_exchange_strong( - exchanged_status, PackingStatus::kInProgress, - std::memory_order_acq_rel, std::memory_order_acquire)) { - // In this branch, the status was kNotStarted and we just atomically - // changed it to kInProgress as we are about to handle the packing - // ourselves. 
- params->RunPack(side, tuning, start, end); - TraceRecordBlockPacked(thread_id, side, block, trace); - status.store(PackingStatus::kFinished, std::memory_order_release); - } else if (exchanged_status == PackingStatus::kInProgress) { - // Another thread is currently packing this block. - return false; - } - RUY_DCHECK(status.load(std::memory_order_acquire) == - PackingStatus::kFinished); - } else { - // Single-threaded case: no need for expensive atomics, local_packed - // is the truth already. - params->RunPack(side, tuning, start, end); - TraceRecordBlockPacked(thread_id, side, block, trace); - } - local_packed[side][block] = true; - } - return true; - } - - // Ensures that both the LHS and RHS blocks required by the specified block - // are packed. In the event that they are already being packed on another - // threads, this function may perform the packing of some other block while - // waiting for that other thread to finish packing the requested block. - void EnsurePacked(const SidePair& block, const SidePair& start, - const SidePair& end, Tuning tuning) { -#if RUY_OPT_ENABLED(RUY_OPT_PACK_AHEAD) - SidePair next_runahead_block{block[Side::kLhs] + 1, - block[Side::kRhs] + 1}; - Side next_runahead_side = Side::kLhs; -#endif - while (true) { - bool both_sides_packed = true; - for (Side side : {Side::kLhs, Side::kRhs}) { - both_sides_packed &= - TryPack(side, block[side], start[side], end[side], tuning); - } - if (both_sides_packed) { - break; - } -#if RUY_OPT_ENABLED(RUY_OPT_PACK_AHEAD) - const Side runahead_side = next_runahead_side; - const int runahead_block = next_runahead_block[runahead_side]; - next_runahead_side = - next_runahead_side == Side::kLhs ? Side::kRhs : Side::kLhs; - if (runahead_block >= NumBlocksPerSide(runahead_side, block_map)) { - continue; - } - int runahead_block_start, runahead_block_end; - GetBlockMatrixCoords(runahead_side, block_map, runahead_block, - &runahead_block_start, &runahead_block_end); - TryPack(runahead_side, runahead_block, runahead_block_start, - runahead_block_end, tuning); - next_runahead_block[runahead_side] = runahead_block + 1; -#endif - } - } - - TrMulParams* params; - const BlockMap& block_map; - std::atomic* atomic_block_id; - int thread_id; - bool need_atomics; - SidePair*> packing_status; - TuningResolver* tuning_resolver; - Allocator* local_allocator; - Trace* trace; - - // Local indicators of packedness to avoid the overhead of atomic ops. - SidePair local_packed; -}; - -void AllocatePMatrix(Allocator* allocator, PMatrix* packed) { - packed->data = allocator->AllocateBytes(DataSize(*packed)); - packed->sums = allocator->AllocateBytes(SumsSize(*packed)); -} - -int GetThreadCount(Context* context, int rows, int cols, int depth) { -#if RUY_PLATFORM(EMSCRIPTEN) - // b/139927184, std::thread constructor raises exception - return 1; -#endif - // Empirically determined rule for reasonable number of - // threads to use. This is proportional to the number of arithmetic ops - // in this Mul (product of the 3 sizes). 
- static constexpr int kDivisorLog2 = 15; - const int guess_log2 = std::max( - 0, ceil_log2(rows) + ceil_log2(cols) + ceil_log2(depth) - kDivisorLog2); - return std::min(1 << guess_log2, context->max_num_threads); -} - -LoopStructure GetLoopStructure(int tentative_thread_count, int rows, int cols, - int depth, int lhs_scalar_size, - int rhs_scalar_size, int local_data_cache_size, - int shared_data_cache_size) { - if (tentative_thread_count == 1) { - const BlockMapTraversalOrder traversal_order = - GetTraversalOrder(rows, cols, depth, lhs_scalar_size, rhs_scalar_size, - local_data_cache_size, shared_data_cache_size); - // If we are in the GEMV case or the block_map would be using linear - // traversal anyway, use the simple loop. - if ((cols == 1) || traversal_order == BlockMapTraversalOrder::kLinear) { - return LoopStructure::kSimple; - } - } - return LoopStructure::kGeneral; -} - -} // namespace - -void TrMul(TrMulParams* params, Context* context) { - profiler::ScopeLabel label( - "TrMul (Path=0x%x, max_num_threads=%d, is_prepacked=(%d,%d))", - static_cast(params->path), context->max_num_threads, - params->is_prepacked[Side::kLhs], params->is_prepacked[Side::kRhs]); - - PMatrix& packed_lhs = params->packed[Side::kLhs]; - PMatrix& packed_rhs = params->packed[Side::kRhs]; - DMatrix& lhs = params->src[Side::kLhs]; - DMatrix& rhs = params->src[Side::kRhs]; - - const int rows = lhs.layout.cols; - const int cols = rhs.layout.cols; - const int depth = lhs.layout.rows; - - const int tentative_thread_count = GetThreadCount(context, rows, cols, depth); - const auto loop_structure = GetLoopStructure( - tentative_thread_count, rows, cols, depth, lhs.data_type.size, - rhs.data_type.size, params->local_data_cache_size, - params->shared_data_cache_size); - Allocator* allocator = context->GetMainAllocator(); - - // Allocate packed matrices - for (Side side : {Side::kLhs, Side::kRhs}) { - if (!params->is_prepacked[side]) { - AllocatePMatrix(allocator, ¶ms->packed[side]); - } - } - - // Case of running this TrMul as a simple loop. - // This is a good place to start reading this function: all the rest - // of this function is just an optimized, but functionally equivalent, - // version of that. - if (loop_structure == LoopStructure::kSimple) { - profiler::ScopeLabel label_simple("TrMulImpl, simple loop"); - Tuning tuning = context->GetMainThreadTuning(); - - const SidePair origin{0, 0}; - const SidePair rounded_dims{packed_lhs.layout.cols, - packed_rhs.layout.cols}; - for (Side side : {Side::kLhs, Side::kRhs}) { - if (!params->is_prepacked[side]) { - params->RunPack(side, tuning, origin[side], rounded_dims[side]); - } - } - params->RunKernel(tuning, origin, rounded_dims); - - allocator->FreeAll(); - return; - } - - profiler::ScopeLabel label_general("TrMulImpl, general case"); - - auto* trace = NewTraceOrNull(&context->tracing, rows, depth, cols); - TraceRecordStart(trace); - - // Initialize block map. - BlockMap block_map; - MakeBlockMap(packed_lhs.layout.cols, packed_rhs.layout.cols, depth, - packed_lhs.layout.kernel.cols, packed_rhs.layout.kernel.cols, - packed_lhs.data_type.size, packed_rhs.data_type.size, - tentative_thread_count, params->path, - params->local_data_cache_size, params->shared_data_cache_size, - &block_map); - - // Initialize per-thread state. 
- const int thread_count = block_map.thread_count; - const bool need_atomics = thread_count > 1; - context->EnsureNPerThreadStates(thread_count); - for (auto& per_thread_state : context->per_thread_states) { - per_thread_state->tuning_resolver.SetTuning(context->explicit_tuning); - } - - // In the need_atomics case, allocate and initialize atomic values tracking - // the packing status of blocks. - SidePair*> packing_status{nullptr, nullptr}; - if (need_atomics) { - for (Side side : {Side::kLhs, Side::kRhs}) { - if (!params->is_prepacked[side]) { - const int size = NumBlocksPerSide(side, block_map); - allocator->Allocate(size, &packing_status[side]); - for (int i = 0; i < size; i++) { - packing_status[side][i].store(PackingStatus::kNotStarted, - std::memory_order_relaxed); - } - } - } - } - - // Create the atomic block id, allocate it using Allocator so that - // we get the alignment ensuring that it sits alone in its exclusives - // reservation granule. - std::atomic* atomic_block_id; - allocator->Allocate(1, &atomic_block_id); - - // Create task objects. - TrMulTask* tasks; - allocator->Allocate(thread_count, &tasks); - - atomic_block_id->store(thread_count); - - for (int i = 0; i < thread_count; i++) { - new (tasks + i) TrMulTask(params, block_map, atomic_block_id, i, - need_atomics, packing_status, - &context->per_thread_states[i]->tuning_resolver, - &context->per_thread_states[i]->allocator, trace); - } - - // Do the computation. - TraceRecordExecute(block_map, trace); - context->workers_pool.Execute(thread_count, tasks); - - // Finish up. - for (int i = 0; i < thread_count; i++) { - tasks[i].~TrMulTask(); - } - - allocator->FreeAll(); - TraceRecordEnd(trace); -} - -} // namespace ruy diff --git a/trmul.h b/trmul.h deleted file mode 100644 index adb6cb3..0000000 --- a/trmul.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// As a matrix multiplication library, Ruy offers a Mul entry point, performing -// matrix multiplication. For implementation purposes, it is much nicer to -// be dealing with the transpose-and-multiply operation, doing -// Destination = Transpose(LHS) * RHS -// Indeed, the latter is performing dot-products between the *columns* of LHS -// and the columns of RHS, whereas a plain matrix multiplication is performing -// dot-products between the *rows* of LHS and the columns of RHS. -// That is why TrMul is nicer to implement, allowing for a more symmetric -// treatment of LHS and RHS. 
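The transpose-and-multiply formulation described in the comment above can be summarized by a few lines of scalar reference code. The sketch below is purely illustrative and is not part of ruy: the names RefMul and RefTrMul are made up, both operate on plain column-major float buffers, and none of ruy's packing, blocking or threading is shown. Its only point is that in TrMul the LHS and RHS are indexed the same way, which is the symmetry the comment refers to.

  // dst = lhs * rhs, with lhs of shape rows x depth and rhs of shape depth x cols.
  // All buffers are column-major: element (r, c) of an M-row matrix is at r + c * M.
  void RefMul(const float* lhs, const float* rhs, float* dst,
              int rows, int depth, int cols) {
    for (int r = 0; r < rows; r++) {
      for (int c = 0; c < cols; c++) {
        float sum = 0;
        for (int k = 0; k < depth; k++) {
          // Rows of lhs against columns of rhs: the two operands are indexed
          // differently along k (stride `rows` vs contiguous).
          sum += lhs[r + k * rows] * rhs[k + c * depth];
        }
        dst[r + c * rows] = sum;
      }
    }
  }

  // dst = Transpose(lhs) * rhs, with lhs of shape depth x rows.
  void RefTrMul(const float* lhs, const float* rhs, float* dst,
                int rows, int depth, int cols) {
    for (int r = 0; r < rows; r++) {
      for (int c = 0; c < cols; c++) {
        float sum = 0;
        for (int k = 0; k < depth; k++) {
          // Columns of lhs against columns of rhs: both operands are now
          // walked contiguously along k, so LHS and RHS are treated symmetrically.
          sum += lhs[k + r * depth] * rhs[k + c * depth];
        }
        dst[r + c * rows] = sum;
      }
    }
  }
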
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_H_
-
-#include "context.h"
-#include "trmul_params.h"
-
-namespace ruy {
-
-void TrMul(TrMulParams* params, Context* context);
-
-}  // namespace ruy
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_H_
diff --git a/trmul_params.h b/trmul_params.h
deleted file mode 100644
index fc7970e..0000000
--- a/trmul_params.h
+++ /dev/null
@@ -1,67 +0,0 @@
-/* Copyright 2019 Google LLC. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_PARAMS_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_PARAMS_H_
-
-#include "internal_matrix.h"
-#include "side_pair.h"
-#include "tune.h"
-
-namespace ruy {
-
-using RunKernelFn = void(Tuning, const SidePair<PMatrix>&, void*,
-                         const SidePair<int>&, const SidePair<int>&, DMatrix*);
-
-using RunPackFn = void(Tuning, const DMatrix&, PMatrix*, int, int);
-
-// Type-erased data needed for implementing TrMul.
-struct TrMulParams {
-  TrMulParams() : run_pack{nullptr, nullptr}, is_prepacked{false, false} {}
-  // Helper functions for invoking the function pointers.
-  void RunPack(Side side, Tuning tuning, int start, int end) {
-    run_pack[side](tuning, src[side], &packed[side], start, end);
-  }
-  void RunKernel(Tuning tuning, const SidePair<int>& start,
-                 const SidePair<int>& end) {
-    run_kernel(tuning, packed, spec, start, end, &dst);
-  }
-
-  // Path id, can be useful info for some fine-tuning, e.g. to guess reasonable
-  // cache sizes when not runtime-detectable.
-  Path path;
-
-  // See Spec::local_data_cache_size().
-  int local_data_cache_size = 0;
-  // See Spec::shared_data_cache_size().
-  int shared_data_cache_size = 0;
-
-  // Function pointers to type-erased entry points for kernels and packers.
-  SidePair<RunPackFn*> run_pack;
-  RunKernelFn* run_kernel = nullptr;
-
-  // Matrices and packed matrices.
-  SidePair<DMatrix> src;
-  DMatrix dst;
-  SidePair<PMatrix> packed;
-  SidePair<bool> is_prepacked;
-
-  // Type-erased Spec.
-  void* spec = nullptr;
-};
-
-}  // namespace ruy
-
-#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_TRMUL_PARAMS_H_
diff --git a/tune.cc b/tune.cc
deleted file mode 100644
index cb615d3..0000000
--- a/tune.cc
+++ /dev/null
@@ -1,161 +0,0 @@
-/* Copyright 2019 Google LLC. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/ - -#include "tune.h" - -#include -#include - -namespace ruy { - -#ifdef RUY_IMPLEMENT_TUNING - -namespace { - -void PoorlyOrderedKernel(int iters) { - asm volatile( - "mov w0, %w[iters]\n" - "1:\n" - "subs w0, w0, #1\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "bne 1b\n" ::[iters] "r"(iters) - : "cc", "x0", "v0", "v1", "v2", "v3"); -} - -void NicelyOrderedKernel(int iters) { - asm volatile( - "mov w0, %w[iters]\n" - "1:\n" - "subs w0, w0, #1\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "mul v0.4s, v0.4s, v0.4s\n" - "mul v1.4s, v1.4s, v1.4s\n" - "mul v2.4s, v2.4s, v2.4s\n" - "mul v3.4s, v3.4s, v3.4s\n" - "bne 1b\n" ::[iters] "r"(iters) - : "cc", "x0", "v0", "v1", "v2", "v3"); -} - -} // namespace - -float TuningResolver::EvalRatio() { - // With the current settings, 400 iterations and 4 repeats, this test has - // a latency of roughly 80 microseconds on a Cortex-A53 at 1.4 GHz. - static constexpr int kLoopIters = 400; - static constexpr int kRepeats = 4; - - Duration timing_poorly_ordered = Duration::max(); - Duration timing_nicely_ordered = Duration::max(); - - for (int r = 0; r < kRepeats; r++) { - TimePoint t0 = Now(); - PoorlyOrderedKernel(kLoopIters); - TimePoint t1 = Now(); - NicelyOrderedKernel(kLoopIters); - TimePoint t2 = Now(); - timing_poorly_ordered = std::min(timing_poorly_ordered, t1 - t0); - timing_nicely_ordered = std::min(timing_nicely_ordered, t2 - t1); - } - - return ToFloatSeconds(timing_nicely_ordered) / - ToFloatSeconds(timing_poorly_ordered); -} - -float TuningResolver::ThresholdRatio() { - // Empirically (see :tune_tool) determined threshold to distinguish in-order - // Cortex-A53/A55 cores from out-of-order Cortex-A57/A73/A75/A76 cores. Based - // on these experimental results, which were obtained with much lower - // (kLoopIters=1000, kRepeats=1) so as to make them resilient to noise, we - // have: - // - // CPU core type | in/out of order | observed ratio - // --------------+-----------------+----------------------------------------- - // Cortex-A53 | in-order | 0.32 -- 0.329 - // Cortex-A55 | in-order | 0.319 -- 0.325 - // Cortex-A55r1 | in-order | 0.319 -- 0.325 - // Cortex-A57 | out-of-order | 0.99 -- 1.01 - // Cortex-A73 | out-of-order | 0.922 -- 0.927 - // Cortex-A75 | out-of-order | 0.921 -- 0.93 - // Cortex-A76 | out-of-order | 1 - // Kryo (pixel1) | out-of-order | 0.73 -- 0.76 - // - // Thus the allowable range for the threshold is [0.35 .. 0.70]. - // We pick a value closer to the upper bound because really any out-of-order - // CPU should by definition produce a ratio close to 1. 
- return 0.65f; -} - -Tuning TuningResolver::ResolveNow() { - const bool is_probably_inorder = EvalRatio() < ThresholdRatio(); - return is_probably_inorder ? Tuning::kInOrder : Tuning::kOutOfOrder; -} - -#else // not defined RUY_IMPLEMENT_TUNING - -float TuningResolver::EvalRatio() { return 0; } -float TuningResolver::ThresholdRatio() { return 0; } - -Tuning TuningResolver::ResolveNow() { return Tuning::kOutOfOrder; } - -#endif - -TuningResolver::TuningResolver() - : expiry_duration_(DurationFromMilliseconds(250)) {} - -Tuning TuningResolver::Resolve() { -#ifdef RUY_IMPLEMENT_TUNING - if (unresolved_tuning_ != Tuning::kAuto) { - return unresolved_tuning_; - } - TimePoint new_timepoint = CoarseNow(); - if (last_resolved_tuning_ != Tuning::kAuto && - (new_timepoint - last_resolved_timepoint_) < expiry_duration_) { - return last_resolved_tuning_; - } - last_resolved_timepoint_ = new_timepoint; - last_resolved_tuning_ = ResolveNow(); - return last_resolved_tuning_; -#else - return Tuning::kOutOfOrder; -#endif -} - -} // namespace ruy diff --git a/tune.h b/tune.h deleted file mode 100644 index db321fd..0000000 --- a/tune.h +++ /dev/null @@ -1,163 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Library doing minimal CPU detection to decide what to tune asm code for. -// -// # Tuning vs Path -// -// Tunings are merely local variations of optimized code paths, that are -// drop-in replacements for each other --- the input and output data layouts -// are identical. By contrast, what ruy calls a Path dictates its own -// data layouts. For example, Path::kNeonDotprod will use different -// layouts compared to Path::kNeon; but within each, different tunings -// will share that same layout. -// -// # Tuning is for now only based on 1 bit: OutOfOrder / InOrder -// -// In practice, each of our asm code paths only needs one bit information to -// decide on tuning: whether the CPU is out-of-order or in-order. -// That is because out-of-order CPUs are by definition relatively insensitive -// to small-scale asm details (which is what "tuning" is about); and for each -// asm code path, there tends to be one main in-order CPU architecture that -// we focus our tuning effort on. Examples: -// * For Path::kNeon, the main in-order CPU is Cortex-A53/A55 (pre-dotprod) -// * For Path::kNeonDotprod, the main in-order CPU is Cortex-A55r1 (dotprod) -// -// Because having tuned code paths is a compromise of efficiency gains -// versus implementation effort and code size, we are happy to stop at just this -// single bit of information, OutOfOrder/InOrder, at least in the current CPU -// landscape. This could change in the future. -// -// # Implementation notes and alternatives. -// -// The current implementation uses a nano-benchmark, see tune.cc. -// That is why it's quite expensive, making caching / -// statefulness necessary (see TuningResolver class comment). 
-// -// An interesting alternative, which was explained to us by Marat Dukhan -// (maratek@) after this was implemented, would be to use the -// getcpu(2) system call on Linux. This returns a -// numeric CPU identifier that could be mapped to a OutOfOrder/InOrder -// classification given additional information about the CPU. Such -// additional information could be obtained by the cpuinfo library, -// https://github.com/pytorch/cpuinfo -// which obtains this information mainly from parsing /proc/cpuinfo. -// Pros: -// * Would remove the need for the relatively expensive nano-benchmark -// (dozens of microseconds, which have to be reevaluated again several -// times per second). -// * Would conceivably be more reliable. -// Cons: -// * Linux-specific. -// * Modest binary size increase (Marat mentioned the cpuinfo lib is 20k). -// * Won't support exactly 100% of devices (nonstandard /proc/cpuinfo etc). -// -// We could also have both: -// * Maybe by trying getcpu first if supported, then falling back to a -// nano-benchmark. -// * Maybe using getcpu in conjunction with the nano-benchmark to cache -// per-CPU-id nano-benchmark results. -#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_TUNE_H_ -#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_TUNE_H_ - -#include "opt_set.h" -#include "platform.h" -#include "time.h" - -// Tuning only implemented on NEON_64 at the moment (see assembly code -// in the nano-benchmark) and not on Apple (some Apple CPUs produce incorrect -// results on in-order-tuned kernels combining ARM and NEON load instructions -// and NEON `ins` instructions). -// -// When tuning is not implemented, we simply always use Tuning::kOutOfOrder. -#if RUY_OPT_ENABLED(RUY_OPT_TUNING) && RUY_PLATFORM(NEON_64) && \ - !RUY_PLATFORM(APPLE) -#define RUY_IMPLEMENT_TUNING -#endif - -namespace ruy { - -enum class Tuning { - // kAuto means please use auto-detection. It's the default in the - // user-visible parts (see Context). It's meant to be resolved to an - // actual tuning at some point by means of TuningResolver. - kAuto, - // Target an out-order CPU. Example: ARM Cortex-A75. - kOutOfOrder, - // Target an in-order CPU. Example: ARM Cortex-A55. - kInOrder -}; - -// Why a TuningResolver class? -// -// Ideally, this Library would offer a single function, -// Tuning GetCurrentCPUTuning(); -// -// However, determining information about the current CPU is not necessarily, -// cheap, so we currently cache that and only invalidate/reevaluate after -// a fixed amount of time. This need to store state is why this library -// has to expose a class, TuningResolver, not just a function. -class TuningResolver { - public: - TuningResolver(); - - // Allows the user to specify an explicit Tuning value, bypassing auto - // detection; or to specify Tuning::kAuto, reverting to auto detection. - void SetTuning(Tuning tuning) { unresolved_tuning_ = tuning; } - - // Get an actual tuning --- that is the function that this class wanted to be. - Tuning Resolve(); - - private: - TuningResolver(const TuningResolver&) = delete; - - // TuningTool is a demo/tool used to tweak the tuning implementation to - // specific devices. It needs to access some finer granularity information - // than just the Tuning returned by Resolve. Nothing else should need - // access to that. - friend class TuneTool; - // Actually runs a nano-benchmark, producing a real number called 'ratio' - // whose meaning is generally opaque / implementation defined. 
Typically, - // this would be the ratio between the latencies of two different - // pieces of asm code differing only by the ordering of instructions, - // revealing whether the CPU cares about such ordering details. - // An implementation may just return a dummy value if it is not based on - // such nanobenchmarking / ratio evaluation. - float EvalRatio(); - // Empirically determined threshold on ratio values delineating - // out-of-order (ratios closer to 1) from in-order (ratios farther from 1). - // An implementation may just return a dummy value if it is not based on - // such nanobenchmarking / ratio evaluation. - float ThresholdRatio(); - // Perform the tuning resolution now. That may typically use EvalRatio and - // ThresholdRatio, but an implementation may use a different approach instead. - Tuning ResolveNow(); - - // The tuning as specified by the user, before actual resolution happens - // i.e. before querying any specifics of the current CPU. - // The default value kAuto means try to auto-detect. Other values mean - // bypass auto-detect, use explicit value instead. See SetTuning(). - Tuning unresolved_tuning_ = Tuning::kAuto; - // Cached last resolved tuning. - Tuning last_resolved_tuning_ = Tuning::kAuto; - // Timepoint of cached last resolved tuning, for invalidation purposes. - TimePoint last_resolved_timepoint_; - // Cached last resolved tunings that are older than this age are invalid. - const Duration expiry_duration_; -}; - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_TUNE_H_ diff --git a/tune_test.cc b/tune_test.cc deleted file mode 100644 index 1c09dba..0000000 --- a/tune_test.cc +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tune.h" - -#include // NOLINT(build/c++11) -#include // NOLINT(build/c++11) - -#include "testing/base/public/gunit.h" - -namespace ruy { -namespace { - -TEST(TuneTest, TuneTest) { - TuningResolver tuning_resolver; - ASSERT_FALSE(tuning_resolver.Resolve() == Tuning::kAuto); - // 1 second is likely higher than TuningResolver's internal cache expiry, - // exercising the logic invalidating earlier tuning resolutions. - std::this_thread::sleep_for(std::chrono::seconds(1)); - ASSERT_FALSE(tuning_resolver.Resolve() == Tuning::kAuto); - - tuning_resolver.SetTuning(Tuning::kAuto); - -#ifdef RUY_IMPLEMENT_TUNING - for (auto tuning : {Tuning::kOutOfOrder, Tuning::kInOrder}) { - tuning_resolver.SetTuning(tuning); - ASSERT_TRUE(tuning_resolver.Resolve() == tuning); - // See above comment about 1 second. 
- std::this_thread::sleep_for(std::chrono::seconds(1)); - ASSERT_TRUE(tuning_resolver.Resolve() == tuning); - } -#endif -} - -} // namespace -} // namespace ruy - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tune_tool.cc b/tune_tool.cc deleted file mode 100644 index 749e4ae..0000000 --- a/tune_tool.cc +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Self-contained tool used to tune the tune code --- see the -// threshold ratios used in tune.cc. - -#include // NOLINT(build/c++11) -#include -#include // NOLINT(build/c++11) - -#include "tune.h" - -#ifdef _WIN32 -#define getpid() 0 -#else -#include -#endif - -namespace ruy { - -class TuneTool { - public: - static void Query(float* eval, float* threshold) { - TuningResolver resolver; - *eval = resolver.EvalRatio(); - *threshold = resolver.ThresholdRatio(); - } -}; - -} // namespace ruy - -int main() { - // Infinite loop: the user can hit Ctrl-C - while (true) { - float eval; - float threshold; - ruy::TuneTool::Query(&eval, &threshold); - printf("[%d] eval=%.3f %c threshold=%.3f ==> probably %s...\n", getpid(), - eval, eval < threshold ? '<' : '>', threshold, - eval < threshold ? "in-order" : "out-of-order"); - fflush(stdout); - std::this_thread::sleep_for(std::chrono::seconds(1)); - } -} diff --git a/wait.cc b/wait.cc deleted file mode 100644 index 330b7dd..0000000 --- a/wait.cc +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "wait.h" - -#include // NOLINT(build/c++11) - -namespace ruy { - -void Wait(const std::function& condition, const Duration& spin_duration, - std::condition_variable* condvar, std::mutex* mutex) { - // First, trivial case where the `condition` is already true; - if (condition()) { - return; - } - - // Then try busy-waiting. - const TimePoint wait_start = Now(); - while (Now() - wait_start < spin_duration) { - if (condition()) { - return; - } - } - - // Finally, do real passive waiting. 
-  std::unique_lock<std::mutex> lock(*mutex);
-  condvar->wait(lock, condition);
-}
-
-void Wait(const std::function<bool()>& condition,
-          std::condition_variable* condvar, std::mutex* mutex) {
-  // This value was empirically derived with some microbenchmark; we don't have
-  // high confidence in it.
-  //
-  // TODO(b/135595069): make this value configurable at runtime.
-  // I almost wanted to file another bug to ask for experimenting in a more
-  // principled way to tune this value better, but this would have to be tuned
-  // on real end-to-end applications and we'd expect different applications to
-  // require different tunings. So the more important point is the need for
-  // this to be controllable by the application.
-  //
-  // That this value means that we may be sleeping substantially longer
-  // than a scheduler timeslice's duration is not necessarily surprising. The
-  // idea is to quickly pick up new work after having finished the previous
-  // workload. When it's new work within the same GEMM as the previous work, the
-  // time interval that we might be busy-waiting is very small, so for that
-  // purpose it would be more than enough to sleep for 1 ms.
-  // That is all we would observe on a GEMM benchmark. However, in a real
-  // application, after having finished a GEMM, we might do unrelated work for
-  // a little while, then start on a new GEMM. In that case the wait interval
-  // may be a little longer. There may also not be another GEMM for a long time,
-  // in which case we'll end up passively waiting below.
-  const Duration spin_duration = DurationFromMilliseconds(2);
-  Wait(condition, spin_duration, condvar, mutex);
-}
-
-}  // namespace ruy
diff --git a/wait.h b/wait.h
deleted file mode 100644
index 67378ff..0000000
--- a/wait.h
+++ /dev/null
@@ -1,73 +0,0 @@
-/* Copyright 2019 Google LLC. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_WAIT_H_
-#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_WAIT_H_
-
-#include <condition_variable>  // NOLINT(build/c++11)
-#include <functional>
-#include <mutex>  // NOLINT(build/c++11)
-
-#include "time.h"
-
-namespace ruy {
-
-// Waits until some evaluation of `condition` has returned true.
-//
-// There is no guarantee that calling `condition` again after this function
-// has returned would still return true. The only
-// contract is that at some point during the execution of that function,
-// `condition` has returned true.
-//
-// First does some spin-waiting for the specified `spin_duration`,
-// then falls back to passive waiting for the given condvar, guarded
-// by the given mutex. At this point it will try to acquire the mutex lock
-// around the waiting on the condition variable.
-// Therefore, this function expects that the calling thread hasn't already
-// locked the mutex before calling it.
-// This function will always release the mutex lock before returning.
-// -// The idea of doing some initial spin-waiting is to help get -// better and more consistent multithreading benefits for small GEMM sizes. -// Spin-waiting help ensuring that if we need to wake up soon after having -// started waiting, then we can wake up quickly (as opposed to, say, -// having to wait to be scheduled again by the OS). On the other hand, -// we must still eventually revert to passive waiting for longer waits -// (e.g. worker threads having finished a GEMM and waiting until the next GEMM) -// so as to avoid permanently spinning. -// -// In situations where other threads might have more useful things to do with -// these CPU cores than our spin-waiting, it may be best to reduce the value -// of `spin_duration`. Setting it to zero disables the spin-waiting entirely. -// -// There is a risk that the std::function used here might use a heap allocation -// to store its context. The expected usage pattern is that these functions' -// contexts will consist of a single pointer value (typically capturing only -// [this]), and that in this case the std::function implementation will use -// inline storage, avoiding a heap allocation. However, we can't effectively -// guard that assumption, and that's not a big concern anyway because the -// latency of a small heap allocation is probably low compared to the intrinsic -// latency of what this Wait function does. -void Wait(const std::function& condition, const Duration& spin_duration, - std::condition_variable* condvar, std::mutex* mutex); - -// Convenience overload using a default `spin_duration`. -// TODO(benoitjacob): let this be controlled from the ruy API. -void Wait(const std::function& condition, - std::condition_variable* condvar, std::mutex* mutex); - -} // namespace ruy - -#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_WAIT_H_ diff --git a/wait_test.cc b/wait_test.cc deleted file mode 100644 index 41816c4..0000000 --- a/wait_test.cc +++ /dev/null @@ -1,117 +0,0 @@ -/* Copyright 2019 Google LLC. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "wait.h" - -#include -#include // NOLINT(build/c++11) -#include // NOLINT(build/c++11) -#include // NOLINT(build/c++11) - -#include "testing/base/public/gunit.h" -#include "platform.h" - -namespace ruy { -namespace { - -// Thread taking a `value` atomic counter and incrementing it until it equals -// `end_value`, then notifying the condition variable as long as -// `value == end_value`. If `end_value` is increased, it will then resume -// incrementing `value`, etc. Terminates if `end_value == -1`. 
-class ThreadCountingUpToValue {
- public:
-  ThreadCountingUpToValue(const std::atomic<int>& end_value,
-                          std::atomic<int>* value,
-                          std::condition_variable* condvar, std::mutex* mutex)
-      : end_value_(end_value),
-        value_(value),
-        condvar_(condvar),
-        mutex_(mutex) {}
-  void operator()() {
-    // end_value_==-1 is how the master thread will tell us it's OK to terminate
-    while (end_value_.load() != -1) {
-      // wait until end_value is set to a higher value
-      while (value_->load() == end_value_.load()) {
-      }
-      // increment value as long as it's lower than end_value
-      while (value_->fetch_add(1) < end_value_.load() - 1) {
-      }
-      // when value has reached end_value, notify the master thread.
-      while (value_->load() == end_value_.load()) {
-        std::lock_guard<std::mutex> lock(*mutex_);
-        condvar_->notify_all();
-      }
-    }
-  }
-
- private:
-  const std::atomic<int>& end_value_;
-  std::atomic<int>* value_;
-  std::condition_variable* condvar_;
-  std::mutex* mutex_;
-};
-
-void WaitTest(const Duration& spin_duration, const Duration& delay) {
-#if RUY_PLATFORM(EMSCRIPTEN)
-  // b/139927184, std::thread constructor raises exception
-  return;
-#endif
-  std::condition_variable condvar;
-  std::mutex mutex;
-  std::atomic<int> value(0);
-  std::atomic<int> end_value(0);
-  ThreadCountingUpToValue thread_callable(end_value, &value, &condvar, &mutex);
-  std::thread thread(thread_callable);
-  std::this_thread::sleep_for(delay);
-  for (int i = 1; i < 10; i++) {
-    end_value.store(1000 * i);
-    const auto& condition = [&value, &end_value]() {
-      return value.load() == end_value.load();
-    };
-    ruy::Wait(condition, spin_duration, &condvar, &mutex);
-    EXPECT_EQ(value.load(), end_value.load());
-  }
-  end_value.store(-1);
-  thread.join();
-}
-
-TEST(WaitTest, WaitTestNoSpin) {
-  WaitTest(DurationFromSeconds(0), DurationFromSeconds(0));
-}
-
-TEST(WaitTest, WaitTestSpinOneMicrosecond) {
-  WaitTest(DurationFromSeconds(1e-6), DurationFromSeconds(0));
-}
-
-TEST(WaitTest, WaitTestSpinOneMillisecond) {
-  WaitTest(DurationFromSeconds(1e-3), DurationFromSeconds(0));
-}
-
-TEST(WaitTest, WaitTestSpinOneSecond) {
-  WaitTest(DurationFromSeconds(1), DurationFromSeconds(0));
-}
-
-// Testcase to consistently reproduce the hang in b/139062384.
-TEST(WaitTest, WaitTestNoSpinWithDelayBug139062384) {
-  WaitTest(DurationFromSeconds(0), DurationFromSeconds(1));
-}
-
-}  // namespace
-}  // namespace ruy
-
-int main(int argc, char** argv) {
-  ::testing::InitGoogleTest(&argc, argv);
-  return RUN_ALL_TESTS();
-}
--
cgit v1.2.3
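For reference, the calling pattern that wait.h describes (a cheap, lock-free condition, plus a condvar/mutex pair that the waiting thread does not already hold, with the signalling side notifying under that mutex) can be sketched as follows. This is an illustrative example, not ruy code: the `done` flag, the worker lambda and the 1 ms spin duration are arbitrary choices made here for the sketch, and wait_test.cc above exercises the same contract more thoroughly.

  #include <atomic>
  #include <condition_variable>
  #include <mutex>
  #include <thread>

  #include "wait.h"  // becomes "ruy/wait.h" after this move

  int main() {
    std::atomic<bool> done(false);
    std::condition_variable condvar;
    std::mutex mutex;

    std::thread worker([&] {
      // ... do some work, then publish the result ...
      done.store(true);
      // Notify while holding the mutex, so the waiter cannot miss the wakeup
      // between evaluating the condition and blocking on the condvar.
      std::lock_guard<std::mutex> lock(mutex);
      condvar.notify_all();
    });

    // The caller must not hold `mutex` here: Wait() spin-waits first, then
    // acquires the lock itself around the condition-variable wait.
    ruy::Wait([&done] { return done.load(); },
              ruy::DurationFromMilliseconds(1), &condvar, &mutex);

    worker.join();
    return 0;
  }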