diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-05-12 18:01:05 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-05-12 18:01:05 +0300 |
commit | 1eb105919544d49200d93050c3eb7434f8ae5190 (patch) | |
tree | d049c4cbc1294db89fe5102f2fc8b02a28f3fccd | |
parent | 5babff8f474b677155109146c888dde24d5bbd9b (diff) |
Add AddBias callback
-rw-r--r-- | callbacks/configs.h | 6 | ||||
-rw-r--r-- | callbacks/implementations.inl | 19 |
2 files changed, 25 insertions, 0 deletions
diff --git a/callbacks/configs.h b/callbacks/configs.h index 326d14f..1f6f491 100644 --- a/callbacks/configs.h +++ b/callbacks/configs.h @@ -26,6 +26,12 @@ struct Unquantize { Unquantize(float unquant_mult) : unquant_mult(unquant_mult) {} }; +struct AddBias { + const float* bias_addr; + + AddBias(const float* bias_addr) : bias_addr(bias_addr) {} +}; + struct UnquantizeAndWrite { float unquant_mult; float* output_addr; diff --git a/callbacks/implementations.inl b/callbacks/implementations.inl index 2615c32..2f76f84 100644 --- a/callbacks/implementations.inl +++ b/callbacks/implementations.inl @@ -138,6 +138,25 @@ private: }; /* + * AddBias + */ +template <> class CallbackImpl<CPUType::CPU_NAME, AddBias> { +public: + CPU_ATTR CallbackImpl(const AddBias& config) : config(config) {} + + CPU_ATTR vf operator()(vf input, const OutputBufferInfo& info) { + return kernels::add_bias(input, config.bias_addr, info.col_idx); + } + + float operator()(float input, const OutputBufferInfo& info) { + return input + config.bias_addr[info.col_idx]; + } + +private: + AddBias config; +}; + +/* * UnquantizeAndWrite */ template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndWrite> { |