Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-07-09 16:20:51 +0300
committer Mateusz Chudyk <mateuszchudyk@gmail.com> 2019-07-09 22:39:50 +0300
commit 6d2dc039394300b29ef463bef3e3338fb7f68bd4 (patch)
tree 116d2cd13e7ff561358088a1a11d09a9c23df4c4
parent e9f616514ca31d96cd5eb55ad07d35efc4b1fad0 (diff)
Add UnquantizeAndWrite callback
-rw-r--r-- benchmark.cc 4
-rw-r--r-- callbacks/configs.h 7
-rw-r--r-- callbacks/implementations.inl 31
-rw-r--r-- example.cc 4
-rw-r--r-- test/multiply_test.cc 2
5 files changed, 41 insertions, 7 deletions
diff --git a/benchmark.cc b/benchmark.cc
index 669ea77..ad546cc 100644
--- a/benchmark.cc
+++ b/benchmark.cc
@@ -79,10 +79,10 @@ template <class Backend> void Run(const RandomMatrices &m, std::vector<uint64_t>
Backend::PrepareB(m.B.begin(), B_prepared.begin(), quant_mult, m.width, m.B_cols);
AlignedVector<float> output(m.A_rows * m.B_cols);
// Burn in
- Backend::Multiply(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::Dummy());
+ Backend::Multiply(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::UnquantizeAndWrite(unquant_mult, output.begin()));
{
StopWatch w(stats);
- Backend::Multiply(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::Dummy());
+ Backend::Multiply(A_prepared.begin(), B_prepared.begin(), m.A_rows, m.width, m.B_cols, callbacks::UnquantizeAndWrite(unquant_mult, output.begin()));
}
}
diff --git a/callbacks/configs.h b/callbacks/configs.h
index fae60bd..a9836f4 100644
--- a/callbacks/configs.h
+++ b/callbacks/configs.h
@@ -6,5 +6,12 @@ namespace callbacks {
struct Dummy {
};
+struct UnquantizeAndWrite {
+ float unquant_mult;
+ float* addr;
+
+ UnquantizeAndWrite(float unquant_mult, float* addr) : unquant_mult(unquant_mult), addr(addr) {}
+};
+
}
}
diff --git a/callbacks/implementations.inl b/callbacks/implementations.inl
index 01195f1..4e3b81b 100644
--- a/callbacks/implementations.inl
+++ b/callbacks/implementations.inl
@@ -1,6 +1,7 @@
#include "callbacks/configs.h"
#include "intrinsics.h"
+#include "kernels.h"
#include "types.h"
#include "vec_traits.h"
@@ -25,9 +26,15 @@
#define dvd dvector_t<CPUType::CPU_NAME, double>
#if defined(THIS_IS_SSE2)
-#define vinput dvector_t<CPUType::SSE2, int>
+ #define vinput dvector_t<CPUType::SSE2, int>
+ #define vinput_i vector_t<CPUType::SSE2, int>
+ #define vinput_f vector_t<CPUType::SSE2, float>
+ #define vinput_d vector_t<CPUType::SSE2, double>
#else
-#define vinput vector_t<CPUType::AVX2, int>
+ #define vinput vector_t<CPUType::AVX2, int>
+ #define vinput_i vector_t<CPUType::AVX2, int>
+ #define vinput_f vector_t<CPUType::AVX2, float>
+ #define vinput_d vector_t<CPUType::AVX2, double>
#endif
namespace intgemm {
@@ -53,6 +60,23 @@ public:
CPU_ATTR void operator()(vinput, Index, Index, Index, Index, Index) {}
};
+/*
+ * UnquantizeAndWrite
+ */
+template <> class CallbackImpl<UnquantizeAndWrite, CPUType::CPU_NAME> {
+public:
+ CPU_ATTR CallbackImpl(const UnquantizeAndWrite& config) : config(config) {
+ unquant_mult = set1_ps<vinput_f>(config.unquant_mult);
+ }
+ CPU_ATTR void operator()(vinput input, Index A_rowidx, Index B_colidx, Index A_rows, Index width, Index B_cols) {
+ auto result = kernels::unquantize(input, unquant_mult);
+ kernels::write(result, config.addr, A_rowidx * B_cols + B_colidx);
+ }
+private:
+ UnquantizeAndWrite config;
+ vinput_f unquant_mult;
+};
+
}
}
@@ -65,3 +89,6 @@ public:
#undef dvf
#undef dvd
#undef vinput
+#undef vinput_i
+#undef vinput_f
+#undef vinput_d
diff --git a/example.cc b/example.cc
index 3e5cce7..b4407ff 100644
--- a/example.cc
+++ b/example.cc
@@ -52,7 +52,7 @@ int main() {
AlignedVector<float> C(A_rows * B_cols);
// Do the actual multiply.
- intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, intgemm::callbacks::Dummy());
+ intgemm::Int16::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, intgemm::callbacks::UnquantizeAndWrite(1.0 / (quant_mult * quant_mult), C.begin()));
// Sanity check. C will be row major.
assert(fabs(C[0] - top_left_reference) < 0.05);
}
@@ -71,7 +71,7 @@ int main() {
AlignedVector<float> C(A_rows * B_cols);
// Do the actual multiply.
- intgemm::Int8::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, intgemm::callbacks::Dummy());
+ intgemm::Int8::Multiply(A_prepared.begin(), B_prepared.begin(), A_rows, width, B_cols, intgemm::callbacks::UnquantizeAndWrite(1.0 / (quant_mult * quant_mult), C.begin()));
// Sanity check. C will be row major.
assert(fabs(C[0] - top_left_reference) < 0.05);
}
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index e7fcf77..148f618 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -361,7 +361,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
AlignedVector<float> test_C(A_rows * B_cols);
- Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Dummy());
+ Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin()));
AlignedVector<Integer> B_quant(B.size());
Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size());