diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-09 16:20:51 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-09 22:39:50 +0300 |
commit | 6d2dc039394300b29ef463bef3e3338fb7f68bd4 (patch) | |
tree | 116d2cd13e7ff561358088a1a11d09a9c23df4c4 | |
parent | e9f616514ca31d96cd5eb55ad07d35efc4b1fad0 (diff) |
Add UnquantizeAndWrite callback
-rw-r--r-- | benchmark.cc | 4 |
-rw-r--r-- | callbacks/configs.h | 7 |
-rw-r--r-- | callbacks/implementations.inl | 31 |
-rw-r--r-- | example.cc | 4 |
-rw-r--r-- | test/multiply_test.cc | 2 |
5 files changed, 41 insertions, 7 deletions
diff --git a/benchmark.cc b/benchmark.cc index 669ea77..ad546cc 100644 --- a/benchmark.cc +++ b/benchmark.cc @@ -79,10 +79,10 @@ template <class Backend> void Run(const RandomMatrices &m, std::vector<uint64_t> Backend::PrepareB(m.B.begin(), B_prepared.begin(), quant_mult, m.width, m.B_cols); AlignedVector<float> output(m.A_rows * m.B_cols); // Burn in - Backend::Multiply(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::Dummy()); + Backend::Multiply(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::UnquantizeAndWrite(unquant_mult, output.begin())); { StopWatch w(stats); - Backend::Multiply(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::Dummy()); + Backend::Multiply(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::UnquantizeAndWrite(unquant_mult, output.begin())); } } diff --git a/callbacks/configs.h b/callbacks/configs.h index fae60bd..a9836f4 100644 --- a/callbacks/configs.h +++ b/callbacks/configs.h @@ -6,5 +6,12 @@ namespace callbacks { struct Dummy { }; +struct UnquantizeAndWrite { + float unquant_mult; + float* addr; + + UnquantizeAndWrite(float unquant_mult, float* addr) : unquant_mult(unquant_mult), addr(addr) {} +}; + } } diff --git a/callbacks/implementations.inl b/callbacks/implementations.inl index 01195f1..4e3b81b 100644 --- a/callbacks/implementations.inl +++ b/callbacks/implementations.inl @@ -1,6 +1,7 @@ #include "callbacks/configs.h" #include "intrinsics.h" +#include "kernels.h" #include "types.h" #include "vec_traits.h" @@ -25,9 +26,15 @@ #define dvd dvector_t<CPUType::CPU_NAME, double> #if defined(THIS_IS_SSE2) -#define vinput dvector_t<CPUType::SSE2, int> + #define vinput dvector_t<CPUType::SSE2, int> + #define vinput_i vector_t<CPUType::SSE2, int> + #define vinput_f vector_t<CPUType::SSE2, float> + #define vinput_d vector_t<CPUType::SSE2, double> #else -#define vinput vector_t<CPUType::AVX2, int> + #define vinput 
vector_t<CPUType::AVX2, int> + #define vinput_i vector_t<CPUType::AVX2, int> + #define vinput_f vector_t<CPUType::AVX2, float> + #define vinput_d vector_t<CPUType::AVX2, double> #endif namespace intgemm { @@ -53,6 +60,23 @@ public: CPU_ATTR void operator()(vinput, Index, Index, Index, Index, Index) {} }; +/* + * UnquantizeAndWrite + */ +template <> class CallbackImpl<UnquantizeAndWrite, CPUType::CPU_NAME> { +public: + CPU_ATTR CallbackImpl(const UnquantizeAndWrite& config) : config(config) { + unquant_mult = set1_ps<vinput_f>(config.unquant_mult); + } + CPU_ATTR void operator()(vinput input, Index A_rowidx, Index B_colidx, Index A_rows, Index width, Index B_cols) { + auto result = kernels::unquantize(input, unquant_mult); + kernels::write(result, config.addr, A_rowidx * B_cols + B_colidx); + } +private: + UnquantizeAndWrite config; + vinput_f unquant_mult; +}; + } } @@ -65,3 +89,6 @@ public: #undef dvf #undef dvd #undef vinput +#undef vinput_i +#undef vinput_f +#undef vinput_d @@ -52,7 +52,7 @@ int main() { AlignedVector<float> C(A_rows * B_cols); // Do the actual multiply. - intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, intgemm::callbacks::Dummy()); + intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, intgemm::callbacks::UnquantizeAndWrite(1.0 / (quant_mult * quant_mult), C.begin())); // Sanity check. C will be row major. assert(fabs(C[0] - top_left_reference) < 0.05); } @@ -71,7 +71,7 @@ int main() { AlignedVector<float> C(A_rows * B_cols); // Do the actual multiply. - intgemm::Int8::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, intgemm::callbacks::Dummy()); + intgemm::Int8::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, intgemm::callbacks::UnquantizeAndWrite(1.0 / (quant_mult * quant_mult), C.begin())); // Sanity check. C will be row major. 
assert(fabs(C[0] - top_left_reference) < 0.05); } diff --git a/test/multiply_test.cc b/test/multiply_test.cc index e7fcf77..148f618 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -361,7 +361,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); AlignedVector<float> test_C(A_rows * B_cols); - Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Dummy()); + Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin())); AlignedVector<Integer> B_quant(B.size()); Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size()); |