Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'intgemm/callbacks/implementations.inl')
-rw-r--r--intgemm/callbacks/implementations.inl197
1 files changed, 197 insertions, 0 deletions
diff --git a/intgemm/callbacks/implementations.inl b/intgemm/callbacks/implementations.inl
new file mode 100644
index 0000000..47d2aa4
--- /dev/null
+++ b/intgemm/callbacks/implementations.inl
@@ -0,0 +1,197 @@
+/* This file is included multiple times, once per architecture. */
+#if defined(CALLBACKS_THIS_IS_SSE2)
+ #define CPU_NAME SSE2
+ #define CPU_ATTR INTGEMM_SSE2
+#elif defined(CALLBACKS_THIS_IS_AVX2)
+ #define CPU_NAME AVX2
+ #define CPU_ATTR INTGEMM_AVX2
+#elif defined(CALLBACKS_THIS_IS_AVX512BW)
+ #define CPU_NAME AVX512BW
+ #define CPU_ATTR INTGEMM_AVX512BW
+#else
+ #error "Only SSE2, AVX2 and AVX512BW are supported"
+#endif
+
+#if defined(CALLBACKS_THIS_IS_SSE2)
+ #define vi vector_t<CPUType::SSE2, int>
+ #define vf vector_t<CPUType::SSE2, float>
+ #define vd vector_t<CPUType::SSE2, double>
+#else
+ #define vi vector_t<CPUType::AVX2, int>
+ #define vf vector_t<CPUType::AVX2, float>
+ #define vd vector_t<CPUType::AVX2, double>
+#endif
+
+namespace intgemm {
+namespace callbacks {
+
+template <CPUType CpuType, typename CallbackConfig>
+class CallbackImpl;
+
+}}
+
+/*
+ * Callbacks implementations....
+ */
+namespace intgemm {
+namespace callbacks {
+
+/*
+ * Sequence
+ */
+template <typename... Configs>
+class CallbackImpl<CPUType::CPU_NAME, std::tuple<Configs...>> {
+public:
+ CPU_ATTR CallbackImpl(const std::tuple<Configs...>& configs) : callbacks(init_callbacks(configs, make_sequence<sizeof...(Configs)>())) {}
+
+ CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
+ run_callbacks(input, info, callbacks, make_sequence<sizeof...(Configs)>());
+ }
+
+private:
+ using CallbacksTupleType = std::tuple<CallbackImpl<CPUType::CPU_NAME, Configs>...>;
+
+ CallbacksTupleType callbacks;
+
+ template <unsigned... Indices>
+ CallbacksTupleType init_callbacks(const std::tuple<Configs...>& configs, sequence<Indices...>) {
+ return std::make_tuple(CallbackImpl<CPUType::CPU_NAME, typename std::tuple_element<Indices, std::tuple<Configs...>>::type>(std::get<Indices>(configs))...);
+ }
+
+#define RUN_CALLBACKS_PIPELINE_IMPL(vtype) \
+ template <unsigned FirstIndex> \
+ CPU_ATTR static inline void run_callbacks(vtype input, const OutputBufferInfo& info, CallbacksTupleType& tuple, sequence<FirstIndex>) { \
+ std::get<FirstIndex>(tuple)(input, info); \
+ } \
+ template <unsigned FirstIndex, unsigned SecondIndex, unsigned... RestIndices> \
+ CPU_ATTR static inline void run_callbacks(vtype input, const OutputBufferInfo& info, CallbacksTupleType& tuple, sequence<FirstIndex, SecondIndex, RestIndices...>) { \
+ auto output = std::get<FirstIndex>(tuple)(input, info); \
+ run_callbacks(output, info, tuple, sequence<SecondIndex, RestIndices...>()); \
+ }
+
+ RUN_CALLBACKS_PIPELINE_IMPL(vi)
+ RUN_CALLBACKS_PIPELINE_IMPL(vf)
+ RUN_CALLBACKS_PIPELINE_IMPL(vd)
+
+#undef RUN_CALLBACKS_PIPELINE_IMPL
+};
+
+/*
+ * Dummy
+ */
+template <> class CallbackImpl<CPUType::CPU_NAME, Dummy> {
+public:
+ CPU_ATTR CallbackImpl(const Dummy&) {}
+ CPU_ATTR void operator()(vi, const OutputBufferInfo&) {}
+};
+
+/*
+ * Write
+ */
+template <typename Type>
+class CallbackImpl<CPUType::CPU_NAME, Write<Type>> {
+public:
+ CPU_ATTR CallbackImpl(const Write<Type>& config) : config(config) {}
+
+ CPU_ATTR void operator()(vector_t<CPUType::CPU_NAME, Type> input, const OutputBufferInfo& info) {
+ kernels::write(input, config.output_addr, info.row_idx * info.cols + info.col_idx);
+ }
+
+private:
+ Write<Type> config;
+};
+
+/*
+ * Unquantize
+ */
+template <> class CallbackImpl<CPUType::CPU_NAME, Unquantize> {
+public:
+ CPU_ATTR CallbackImpl(const Unquantize& config) : config(config) {
+ unquant_mult = set1_ps<vf>(config.unquant_mult);
+ }
+
+ CPU_ATTR vf operator()(vi input, const OutputBufferInfo&) {
+ return kernels::unquantize(input, unquant_mult);
+ }
+
+private:
+ vf unquant_mult;
+ Unquantize config;
+};
+
+/*
+ * UnquantizeAndWrite
+ */
+template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndWrite> {
+public:
+ CPU_ATTR CallbackImpl(const UnquantizeAndWrite& config) : config(config) {
+ unquant_mult = set1_ps<vf>(config.unquant_mult);
+ }
+
+ CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
+ // Workaround gcc 5 internal compiler error that can't read register members in debug.
+ vf mult_reg;
+#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+ asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
+#else
+ mult_reg = unquant_mult;
+#endif
+ auto result = kernels::unquantize(input, mult_reg);
+ kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
+ }
+
+private:
+ vf unquant_mult;
+ UnquantizeAndWrite config;
+};
+
+/*
+ * AddBiasAndWrite
+ */
+template <> class CallbackImpl<CPUType::CPU_NAME, AddBiasAndWrite> {
+public:
+ CPU_ATTR CallbackImpl(const AddBiasAndWrite& config) : config(config) {}
+
+ CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
+ auto result = kernels::add_bias(input, config.bias_addr, info.col_idx);
+ kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
+ }
+
+private:
+ AddBiasAndWrite config;
+};
+
+/*
+ * UnquantizeAndAddBiasAndWrite
+ */
+template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndAddBiasAndWrite> {
+public:
+ CPU_ATTR CallbackImpl(const UnquantizeAndAddBiasAndWrite& config) : config(config) {
+ unquant_mult = set1_ps<vf>(config.unquant_mult);
+ }
+
+ CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
+ // Workaround gcc 5 internal compiler error that can't read register members in debug.
+ vf mult_reg;
+#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+ asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
+#else
+ mult_reg = unquant_mult;
+#endif
+ auto result = kernels::unquantize(input, mult_reg);
+ result = kernels::add_bias(result, config.bias_addr, info.col_idx);
+ kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
+ }
+private:
+ vf unquant_mult;
+ UnquantizeAndAddBiasAndWrite config;
+};
+
+}
+}
+
+#undef CPU_NAME
+#undef CPU_ATTR
+#undef vi
+#undef vf
+#undef vd