Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2019-07-19 19:20:39 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2019-07-19 19:20:39 +0300
commit721f4802464431dfecbc7c4bed68850f81b7af70 (patch)
tree5dd4d7a7e56ef8ec77a011bf7da7d7f491c547b9
parent87e51cd18a05f503d4f04709ac4121388b206c48 (diff)
Add AddBiasAndWrite callback
-rw-r--r--callbacks/configs.h7
-rw-r--r--callbacks/implementations.inl17
2 files changed, 24 insertions, 0 deletions
diff --git a/callbacks/configs.h b/callbacks/configs.h
index 4a01e2e..8e2eacc 100644
--- a/callbacks/configs.h
+++ b/callbacks/configs.h
@@ -20,6 +20,13 @@ struct UnquantizeAndWrite {
UnquantizeAndWrite(float unquant_mult, float* addr) : unquant_mult(unquant_mult), addr(addr) {}
};
+struct AddBiasAndWrite {
+ const int* bias_addr;
+ int* output_addr;
+
+ AddBiasAndWrite(const int* bias_addr, int* output_addr) : bias_addr(bias_addr), output_addr(output_addr) {}
+};
+
struct UnquantizeAndAddBiasAndWrite {
float unquant_mult;
const float* bias_addr;
diff --git a/callbacks/implementations.inl b/callbacks/implementations.inl
index 773a5f8..f80b2ed 100644
--- a/callbacks/implementations.inl
+++ b/callbacks/implementations.inl
@@ -80,12 +80,29 @@ public:
auto result = kernels::unquantize(input, unquant_mult);
kernels::write(result, config.addr, info.row_idx * info.cols + info.col_idx);
}
+
private:
UnquantizeAndWrite config;
vf unquant_mult;
};
/*
+ * AddBiasAndWrite
+ */
+template <> class CallbackImpl<CPUType::CPU_NAME, AddBiasAndWrite> {
+public:
+ CPU_ATTR CallbackImpl(const AddBiasAndWrite& config) : config(config) {}
+
+ CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
+ auto result = kernels::add_bias(input, config.bias_addr, info.col_idx);
+ kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
+ }
+
+private:
+ AddBiasAndWrite config;
+};
+
+/*
* UnquantizeAndAddBiasAndWrite
*/
template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndAddBiasAndWrite> {