From 721f4802464431dfecbc7c4bed68850f81b7af70 Mon Sep 17 00:00:00 2001 From: Mateusz Chudyk Date: Fri, 19 Jul 2019 17:20:39 +0100 Subject: Add AddBiasAndWrite callback --- callbacks/configs.h | 7 +++++++ callbacks/implementations.inl | 17 +++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/callbacks/configs.h b/callbacks/configs.h index 4a01e2e..8e2eacc 100644 --- a/callbacks/configs.h +++ b/callbacks/configs.h @@ -20,6 +20,13 @@ struct UnquantizeAndWrite { UnquantizeAndWrite(float unquant_mult, float* addr) : unquant_mult(unquant_mult), addr(addr) {} }; +struct AddBiasAndWrite { + const int* bias_addr; + int* output_addr; + + AddBiasAndWrite(const int* bias_addr, int* output_addr) : bias_addr(bias_addr), output_addr(output_addr) {} +}; + struct UnquantizeAndAddBiasAndWrite { float unquant_mult; const float* bias_addr; diff --git a/callbacks/implementations.inl b/callbacks/implementations.inl index 773a5f8..f80b2ed 100644 --- a/callbacks/implementations.inl +++ b/callbacks/implementations.inl @@ -80,11 +80,28 @@ public: auto result = kernels::unquantize(input, unquant_mult); kernels::write(result, config.addr, info.row_idx * info.cols + info.col_idx); } + private: UnquantizeAndWrite config; vf unquant_mult; }; +/* + * AddBiasAndWrite + */ +template <> class CallbackImpl { +public: + CPU_ATTR CallbackImpl(const AddBiasAndWrite& config) : config(config) {} + + CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) { + auto result = kernels::add_bias(input, config.bias_addr, info.col_idx); + kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx); + } + +private: + AddBiasAndWrite config; +}; + /* * UnquantizeAndAddBiasAndWrite */ -- cgit v1.2.3