github.com/marian-nmt/marian.git
author     Marcin Junczys-Dowmunt <junczys@amu.edu.pl>   2018-04-11 06:46:29 +0300
committer  Marcin Junczys-Dowmunt <junczys@amu.edu.pl>   2018-04-11 06:46:29 +0300
commit     afc3bde59710c2d5b8f4b0b4b731752dc2cf9a41 (patch)
tree       4b32ed16d45c98516b190ab2c2e0d51ea77a99dd
parent     b0622430caab902efd9681ac9582d0f274168764 (diff)
add int16 operators, attempt at memoization
-rw-r--r--  src/common/config_parser.cpp         3
-rw-r--r--  src/graph/chainable.h                2
-rw-r--r--  src/graph/expression_graph.cpp       4
-rw-r--r--  src/graph/expression_graph.h        10
-rw-r--r--  src/graph/expression_operators.cpp  32
-rw-r--r--  src/graph/expression_operators.h     1
-rw-r--r--  src/graph/node.cpp                  10
-rw-r--r--  src/graph/node.h                    11
-rw-r--r--  src/graph/node_operators.cpp         2
-rw-r--r--  src/graph/node_operators.h           6
-rw-r--r--  src/graph/node_operators_unary.h     6
-rw-r--r--  src/tensors/cpu/int16.h            128
-rw-r--r--  src/tensors/cpu/prod.cpp            19
-rwxr-xr-x  src/tensors/cpu/sharp/sse_gemm.h   222
-rw-r--r--  src/translator/translator.h          4
15 files changed, 281 insertions, 179 deletions
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index aa93a027..a6fb1633 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -652,6 +652,8 @@ void ConfigParser::addOptionsTranslate(po::options_description& desc) {
("cpu-threads", po::value<size_t>()->default_value(1),
"Use CPU-based computation with this many independent threads, 0 means GPU-based computation")
#endif
+ ("optimize", po::value<bool>()->zero_tokens()->default_value(false),
+ "Optimize speed aggressively sacrificing memory or precision")
("mini-batch", po::value<int>()->default_value(1),
"Size of mini-batch used during update")
("maxi-batch", po::value<int>()->default_value(1),
@@ -945,6 +947,7 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
SET_OPTION_NONDEFAULT("weights", std::vector<float>);
SET_OPTION_NONDEFAULT("shortlist", std::vector<std::string>);
SET_OPTION("port", size_t);
+ SET_OPTION("optimize", bool);
}
/** valid **/
diff --git a/src/graph/chainable.h b/src/graph/chainable.h
index 9e5d8fb2..281d3653 100644
--- a/src/graph/chainable.h
+++ b/src/graph/chainable.h
@@ -72,7 +72,9 @@ struct Chainable {
// virtual const std::string& type() = 0;
virtual Ptr<ExpressionGraph> graph() = 0;
+
virtual const Shape& shape() = 0;
+ virtual const Type& value_type() = 0;
virtual std::vector<Expr>& children() = 0;
virtual Expr child(size_t) = 0;
diff --git a/src/graph/expression_graph.cpp b/src/graph/expression_graph.cpp
index 4a0edb34..e89f45f3 100644
--- a/src/graph/expression_graph.cpp
+++ b/src/graph/expression_graph.cpp
@@ -5,8 +5,8 @@
namespace marian {
-ExpressionGraph::ExpressionGraph(bool inference)
- : inferenceOnly_(inference), backend_(nullptr) {}
+ExpressionGraph::ExpressionGraph(bool inference, bool optimized)
+ : inferenceOnly_(inference), optimized_(optimized), backend_(nullptr) {}
void ExpressionGraph::setDevice(DeviceId deviceId) {
if(!backend_) {
diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h
index c5f12335..c9f2b245 100644
--- a/src/graph/expression_graph.h
+++ b/src/graph/expression_graph.h
@@ -38,6 +38,8 @@ private:
std::unordered_map<size_t, std::vector<WExpr>> hashMap_;
bool inferenceOnly_{false};
+ bool optimized_{false};
+
bool reloaded_{false};
std::string namespace_;
@@ -53,9 +55,10 @@ public:
*
* Constructor should be used as New<ExpressionGraph>()
*/
- ExpressionGraph(bool inference = false);
+ ExpressionGraph(bool inference = false, bool optimized = false);
void setInference(bool inference) { inferenceOnly_ = inference; }
+ bool isInference() { return inferenceOnly_; }
~ExpressionGraph() {
clear();
@@ -63,9 +66,14 @@ public:
}
void setDevice(DeviceId deviceId = {0, DeviceType::gpu});
+
DeviceId getDevice() { return backend_->getDevice(); }
+
Ptr<Backend> getBackend() { return backend_; }
+ void setOptimized(bool optimized) { optimized_ = optimized; }
+ bool isOptimized() { return (optimized_ && inferenceOnly_); }
+
void switchParams(const std::string& newNamespace) {
namespace_ = newNamespace;
}
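
The int16 path is gated on both flags: isOptimized() also requires inferenceOnly_, so a training graph never routes through the quantized operators even when optimized_ is set. A minimal sketch of the gating (hypothetical usage, not part of the commit):

    auto g = New<ExpressionGraph>(/*inference=*/false, /*optimized=*/true);
    assert(!g->isOptimized());  // optimized_ && inferenceOnly_ is still false
    g->setInference(true);
    assert(g->isOptimized());   // now dot()/affine() may take the int16 path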
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index a4a8b079..5f649318 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -5,6 +5,8 @@
#include "graph/node_operators_binary.h"
#include "graph/node_operators_unary.h"
+#include "tensors/cpu/int16.h"
+
namespace marian {
Expr debug(Expr a, const std::string& message) {
@@ -200,13 +202,36 @@ Expr weighted_average(Expr in, Expr weights, keywords::axis_k ax) {
}
Expr dot(Expr a, Expr b, bool transA, bool transB, float scalar) {
- return Expression<DotNodeOp>(a, b, transA, transB, scalar);
+ auto device = a->graph()->getDevice().type;
+ if(a->graph()->isOptimized() && device == DeviceType::cpu) {
+ // dotInt16 computes A * B.T, hence the transpose for B to get A * B
+ // if transA = false and transB = false.
+ return cpu::int16::dot(cpu::int16::quantize(transA ? transpose(a) : a),
+ cpu::int16::quantize(transB ? b : transpose(b)),
+ scalar);
+ }
+ else {
+ return Expression<DotNodeOp>(a, b, transA, transB, scalar);
+ }
}
Expr bdot(Expr a, Expr b, bool transA, bool transB, float scalar) {
return Expression<DotBatchedNodeOp>(a, b, transA, transB, scalar);
}
+Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scalar) {
+ auto device = a->graph()->getDevice().type;
+ if(a->graph()->isOptimized() && device == DeviceType::cpu) {
+ return cpu::int16::affine(cpu::int16::quantize(transA ? transpose(a) : a),
+ cpu::int16::quantize(transB ? b : transpose(b)),
+ bias, scalar);
+ }
+ else {
+ std::vector<Expr> nodes = {a, b, bias};
+ return Expression<AffineNodeOp>(nodes, transA, transB, scalar);
+ }
+}
+
Expr transpose(Expr a) {
std::vector<int> axes(a->shape().size());
for(int i = 0; i < axes.size(); ++i) {
@@ -237,11 +262,6 @@ Expr cross_entropy(Expr a, Expr b) {
return Expression<CrossEntropyNodeOp>(a, b);
}
-Expr affine(Expr a, Expr b, Expr c, bool transA, bool transB, float scalar) {
- std::vector<Expr> nodes = {a, b, c};
- return Expression<AffineNodeOp>(nodes, transA, transB, scalar);
-}
-
Expr plus(const std::vector<Expr>&) {
ABORT("Not implemented");
}
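
Since cpu::int16::dot always computes A * B^T on quantized operands, the routing above folds the caller's transpose flags into explicit transpose() nodes before quantization. A sketch of what dot(a, b) becomes on an optimized CPU inference graph in the default case (transA = transB = false; for illustration only):

    Expr c = cpu::int16::dot(cpu::int16::quantize(a),             // A unchanged
                             cpu::int16::quantize(transpose(b)),  // pre-transpose B
                             /*scalar=*/1.f);
    // the kernel then computes A * (B^T)^T = A * B at int16 precision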
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index c637105f..070ee0ba 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -63,6 +63,7 @@ Expr dot(Expr a,
bool transA = false,
bool transB = false,
float scalar = 1.f);
+
Expr bdot(Expr a,
Expr b,
bool transA = false,
diff --git a/src/graph/node.cpp b/src/graph/node.cpp
index 12399f53..66c32d39 100644
--- a/src/graph/node.cpp
+++ b/src/graph/node.cpp
@@ -7,7 +7,7 @@ namespace marian {
size_t Node::allocate() {
size_t elements = 0;
if(!val_) {
- graph()->tensor(val_, shape_);
+ graph()->tensor(val_, shape_, value_type_);
elements = val_->shape().elements();
}
return elements;
@@ -24,15 +24,15 @@ void Node::free() {
void Node::init_dependent() {
if(!adj_) {
- graph()->tensor(adj_, shape_);
- adj_->set(1);
+ graph()->tensor(adj_, shape_, value_type_);
+ adj_->set(1.f);
}
}
void Node::set_zero_adjoint() {
if(!adj_) {
- graph()->tensor(adj_, shape_);
- adj_->set(0);
+ graph()->tensor(adj_, shape_, value_type_);
+ adj_->set(0.f);
}
}
diff --git a/src/graph/node.h b/src/graph/node.h
index 15f223aa..6cc1595d 100644
--- a/src/graph/node.h
+++ b/src/graph/node.h
@@ -23,6 +23,8 @@ protected:
Weak<ExpressionGraph> graph_;
Shape shape_{1, 1, 1, 1};
+ Type value_type_{Type::float32};
+
std::string name_{"none"};
Tensor val_{nullptr};
@@ -32,8 +34,8 @@ protected:
std::string debugMessage_;
public:
- Node(Ptr<ExpressionGraph> graph, Shape shape)
- : graph_(graph), shape_(shape) {}
+ Node(Ptr<ExpressionGraph> graph, Shape shape, Type value_type = Type::float32)
+ : graph_(graph), shape_(shape), value_type_(value_type) {}
virtual ~Node() {
if(destroy_) {
@@ -99,6 +101,7 @@ public:
virtual Tensor& grad() { return adj_; };
virtual const Shape& shape() { return shape_; }
+ virtual const Type& value_type() { return value_type_; }
void set_name(const std::string& name) { name_ = name; }
@@ -139,8 +142,8 @@ public:
struct NaryNodeOp : public Node {
size_t hash_{0};
- NaryNodeOp(const std::vector<Expr>& nodes, Shape shape)
- : Node(nodes.front()->graph(), shape) {
+ NaryNodeOp(const std::vector<Expr>& nodes, Shape shape, Type value_type = Type::float32)
+ : Node(nodes.front()->graph(), shape, value_type) {
children_.resize(nodes.size());
for(int i = 0; i < nodes.size(); ++i)
children_[i] = nodes[i];
diff --git a/src/graph/node_operators.cpp b/src/graph/node_operators.cpp
index 146da4b1..4125d4d1 100644
--- a/src/graph/node_operators.cpp
+++ b/src/graph/node_operators.cpp
@@ -7,7 +7,6 @@
namespace marian {
size_t ConstantNode::allocate() {
- // @TODO params
size_t elements = 0;
if(!val_) {
graph()->tensor(val_, shape_);
@@ -25,7 +24,6 @@ void ConstantNode::init() {
}
size_t ParamNode::allocate() {
- // @TODO params
size_t elements = 0;
if(!val_) {
graph()->tensor(val_, shape_);
diff --git a/src/graph/node_operators.h b/src/graph/node_operators.h
index 7626073d..c53cfe1b 100644
--- a/src/graph/node_operators.h
+++ b/src/graph/node_operators.h
@@ -10,7 +10,7 @@ struct ConstantNode : public Node {
ConstantNode(Ptr<ExpressionGraph> graph,
const Shape& shape,
const NodeInitializer& init)
- : Node(graph, shape),
+ : Node(graph, shape), // TODO: add value_type
init_(new NodeInitializer(init)),
initialized_(false) {
setTrainable(false);
@@ -28,7 +28,7 @@ struct ConstantNode : public Node {
const std::string color() { return "white"; }
virtual size_t hash() {
- std::size_t seed = boost::hash<std::string>()(name());
+ std::size_t seed = boost::hash<std::string>()(name()); // TODO: add value_type
boost::hash_combine(seed, type());
boost::hash_combine(seed, this);
return seed;
@@ -49,7 +49,7 @@ struct ParamNode : public Node {
const Shape& shape,
const NodeInitializer& init,
bool fixed = false)
- : Node(graph, shape),
+ : Node(graph, shape), // TODO: add value_type
init_(new NodeInitializer(init)),
initialized_(false) {
setTrainable(!fixed);
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 8ab249a8..dc4015b2 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -12,9 +12,11 @@
namespace marian {
struct UnaryNodeOp : public NaryNodeOp {
- UnaryNodeOp(Expr a, Shape shape) : NaryNodeOp({a}, shape) {}
+ UnaryNodeOp(Expr a, Shape shape, Type value_type = Type::float32)
+ : NaryNodeOp({a}, shape, value_type) {}
- UnaryNodeOp(Expr a) : NaryNodeOp({a}, a->shape()) {}
+ UnaryNodeOp(Expr a, Type value_type = Type::float32)
+ : NaryNodeOp({a}, a->shape(), value_type) {}
const std::string color() { return "yellow"; }
};
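
This default-argument plumbing is what lets QuantizeNodeOp in the new int16.h below allocate an int16 tensor: the type is threaded from the node constructor down to the graph's tensor allocator. Traced through the constructors touched by this commit:

    // QuantizeNodeOp(a)            -> UnaryNodeOp(a, Type::int16)
    // UnaryNodeOp(a, Type::int16)  -> NaryNodeOp({a}, a->shape(), Type::int16)
    // NaryNodeOp(..., Type::int16) -> Node(graph, shape, Type::int16)
    // Node::allocate()             -> graph()->tensor(val_, shape_, value_type_)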
diff --git a/src/tensors/cpu/int16.h b/src/tensors/cpu/int16.h
new file mode 100644
index 00000000..aca49e17
--- /dev/null
+++ b/src/tensors/cpu/int16.h
@@ -0,0 +1,128 @@
+#pragma once
+
+#include "graph/node.h"
+#include "tensors/cpu/sharp/sse_gemm.h"
+
+namespace marian {
+namespace cpu {
+namespace int16 {
+
+struct QuantizeNodeOp : public UnaryNodeOp {
+ QuantizeNodeOp(Expr a) : UnaryNodeOp(a, Type::int16) {}
+
+ NodeOps forwardOps() {
+ return {
+ NodeOp(Quantize(val_, child(0)->val()))
+ };
+ }
+
+ NodeOps backwardOps() {
+ ABORT("Only used for inference");
+ return {NodeOp()};
+ }
+
+ const std::string type() { return "quantizeInt16"; }
+};
+
+class DotNodeOp : public NaryNodeOp {
+private:
+ float scalar_;
+
+public:
+ DotNodeOp(Expr a, Expr b, float scalar)
+ : NaryNodeOp({a, b}, newShape(a, b)),
+ scalar_(scalar) {}
+
+ Shape newShape(Expr a, Expr b) {
+ auto shapeA = a->shape();
+ auto shapeB = b->shape();
+
+ // Computing A * B^T
+ shapeB.set(-2, b->shape()[-1]);
+ shapeB.set(-1, b->shape()[-2]);
+
+ Shape outShape = shapeA;
+ outShape.set(-1, shapeB[-1]);
+ ABORT_IF(shapeA[-1] != shapeB[-2],
+ "matrix product requires dimensions to match");
+ return outShape;
+ }
+
+ NodeOps forwardOps() {
+ return {
+ NodeOp(ProdInt(val_,
+ child(0)->val(),
+ child(1)->val(),
+ scalar_))
+ };
+ }
+
+ NodeOps backwardOps() {
+ ABORT("Only used for inference");
+ return {NodeOp()};
+ }
+
+ const std::string type() { return "dotInt16"; }
+};
+
+
+class AffineNodeOp : public NaryNodeOp {
+private:
+ float scalar_;
+
+public:
+ AffineNodeOp(const std::vector<Expr>& nodes,
+ float scalar)
+ : NaryNodeOp(nodes, newShape(nodes[0], nodes[1])),
+ scalar_(scalar) {}
+
+ Shape newShape(Expr a, Expr b) {
+ auto shapeA = a->shape();
+ auto shapeB = b->shape();
+
+ // Computing A * B^T
+ shapeB.set(-2, b->shape()[-1]);
+ shapeB.set(-1, b->shape()[-2]);
+
+ Shape outShape = shapeA;
+ outShape.set(-1, shapeB[-1]);
+ ABORT_IF(shapeA[-1] != shapeB[-2],
+ "matrix product requires dimensions to match");
+ return outShape;
+ }
+
+ NodeOps forwardOps() {
+ return {
+ NodeOp(ProdInt(val_,
+ child(0)->val(),
+ child(1)->val(),
+ scalar_);
+ AddBias(val_, child(2)->val()))
+ };
+ }
+
+ NodeOps backwardOps() {
+ ABORT("Only used for inference");
+ return {NodeOp()};
+ }
+
+ const std::string type() { return "affineInt16"; }
+};
+
+static inline Expr dot(Expr a, Expr b, float scalar) {
+ return Expression<cpu::int16::DotNodeOp>(a, b, scalar);
+}
+
+static inline Expr affine(Expr a, Expr b, Expr c, float scalar) {
+ std::vector<Expr> nodes = {a, b, c};
+ return Expression<cpu::int16::AffineNodeOp>(nodes, scalar);
+}
+
+static inline Expr quantize(Expr a) {
+ return Expression<cpu::int16::QuantizeNodeOp>(a);
+}
+
+
+}
+}
+}
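
Taken together, the header exposes a three-node API: quantize() turns a float32 expression into an int16 one, and dot()/affine() consume pre-quantized operands; the bias passed to affine() stays float32, since AddBias runs on the unquantized result. A hypothetical usage sketch (w, x, b are placeholder expressions, not names from the commit):

    Expr wq = cpu::int16::quantize(transpose(w));  // weights: can be prepared once
    Expr xq = cpu::int16::quantize(x);             // activations: quantized per batch
    Expr y  = cpu::int16::affine(xq, wq, b, /*scalar=*/1.f);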
diff --git a/src/tensors/cpu/prod.cpp b/src/tensors/cpu/prod.cpp
index c0a81eb2..cc2a5b58 100644
--- a/src/tensors/cpu/prod.cpp
+++ b/src/tensors/cpu/prod.cpp
@@ -53,10 +53,6 @@ void Prod(marian::Tensor C,
bool transB,
float beta,
float scalar) {
- //if(B->type() == Type::int16) {
- // ProdInt(C, A, B, transA, transB, beta, scalar);
- // return;
- //}
#if BLAS_FOUND
float alpha = scalar;
@@ -168,19 +164,8 @@ void ProdWithBias(marian::Tensor C,
bool transB,
float beta,
float scalar) {
-
- //if(B->type() == Type::int16) {
- // ProdIntWithBias(C, A, B, bias, transA, transB, beta, scalar);
- //}
- //else {
- cpu::Prod(C, A, B, transA, transB, beta, scalar);
- //cpu::Add(functional::_1, 1.f, C, bias);
- SSE_AddBias(C->data(),
- C->data(),
- bias->data(),
- C->shape().elements() / C->shape()[-1],
- C->shape()[-1]);
- //}
+ cpu::Prod(C, A, B, transA, transB, beta, scalar);
+ cpu::int16::AddBias(C, bias);
}
}
diff --git a/src/tensors/cpu/sharp/sse_gemm.h b/src/tensors/cpu/sharp/sse_gemm.h
index fe4a8d26..84f8f908 100755
--- a/src/tensors/cpu/sharp/sse_gemm.h
+++ b/src/tensors/cpu/sharp/sse_gemm.h
@@ -92,12 +92,22 @@ namespace marian {
// graph()->free(temp);
//}
-static inline void Quantize(const float* input,
- __m128i* output,
- float quant_mult,
- int num_rows,
- int width) {
- assert(width % 8 == 0);
+namespace cpu {
+namespace int16 {
+
+const int BITS = 10;
+
+static inline void Quantize(marian::Tensor out,
+ const marian::Tensor in) {
+
+ int num_rows = in->shape().elements() / in->shape()[-1];
+ int width = in->shape()[-1];
+ ABORT_IF(width % 8 != 0, "Width {} is not divisible by 8", width);
+
+ const float* input = in->data();
+ __m128i* output = out->data<__m128i>();
+
+ float quant_mult = pow(2.0, (float)BITS);
int num_input_chunks = width / 8;
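
With BITS = 10, quant_mult is 2^10 = 1024, so inputs roughly in (-32, 32) fit into int16 after scaling. A scalar sketch of what each SSE lane of the quantization loop computes (assuming round-to-nearest, the default rounding mode of _mm_cvtps_epi32; not part of the commit):

    static inline int16_t quantize_scalar(float x) {
      const float quant_mult = 1024.f;              // 2^BITS
      return (int16_t)std::lrintf(x * quant_mult);  // scale, round to nearest
    }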
@@ -136,15 +146,21 @@ static inline void Quantize(const float* input,
//
// B is typically a weight matrix, so it can be pre-processed offline, and therefore this transpose does not cost anything.
// A is typically an activation minibatch matrix.
-static inline void SSE_MatrixMult(const __m128i* A,
- const __m128i* B,
- float* C,
+static inline void SSE_MatrixMult(marian::Tensor C,
+ const marian::Tensor A,
+ const marian::Tensor B,
float unquant_mult,
- int num_A_rows,
- int num_B_rows,
- int width)
+ float scale)
{
- assert(width % 8 == 0);
+ const __m128i* qA = A->data<__m128i>();
+ const __m128i* qB = B->data<__m128i>();
+ float* fC = C->data();
+
+ int num_A_rows = A->shape().elements() / A->shape()[-1];
+ int num_B_rows = B->shape().elements() / B->shape()[-1];
+ int width = B->shape()[-1];
+
+ ABORT_IF(width % 8 != 0, "Width {} is not divisible by 8", width);
int sse_width = width / 8;
@@ -165,13 +181,13 @@ static inline void SSE_MatrixMult(const __m128i* A,
int i = 0;
for (; i < mult4; i += 4) {
- const __m128i* A1_row = A + (i + 0) * sse_width;
- const __m128i* A2_row = A + (i + 1) * sse_width;
- const __m128i* A3_row = A + (i + 2) * sse_width;
- const __m128i* A4_row = A + (i + 3) * sse_width;
+ const __m128i* A1_row = qA + (i + 0) * sse_width;
+ const __m128i* A2_row = qA + (i + 1) * sse_width;
+ const __m128i* A3_row = qA + (i + 2) * sse_width;
+ const __m128i* A4_row = qA + (i + 3) * sse_width;
for (int j = 0; j < num_B_rows; j++) {
- const __m128i* B_row = B + j * sse_width;
+ const __m128i* B_row = qB + j * sse_width;
__m128i sum1 = _mm_setzero_si128();
__m128i sum2 = _mm_setzero_si128();
@@ -216,10 +232,10 @@ static inline void SSE_MatrixMult(const __m128i* A,
sum4 = _mm_hadd_epi32(sum4, sum4);
sum4 = _mm_hadd_epi32(sum4, sum4);
- float* C1 = C + (i + 0) * num_B_rows + j;
- float* C2 = C + (i + 1) * num_B_rows + j;
- float* C3 = C + (i + 2) * num_B_rows + j;
- float* C4 = C + (i + 3) * num_B_rows + j;
+ float* C1 = fC + (i + 0) * num_B_rows + j;
+ float* C2 = fC + (i + 1) * num_B_rows + j;
+ float* C3 = fC + (i + 2) * num_B_rows + j;
+ float* C4 = fC + (i + 3) * num_B_rows + j;
// Now that we have the full sum in each 32-bit register, we convert them to floats with _mm_cvtepi32_ps
// and take the first one with _mm_store_ss.
@@ -228,26 +244,26 @@ static inline void SSE_MatrixMult(const __m128i* A,
// loop over the width.
//
// Also note that the memory accesses on C are not consecutive, but this is a tradeoff that we have to make.
- // We can't have consecutive accesses of A, B, *and* C. But we access A and B a lot more so it makes
+ // We can't have consecutive accesses of qA, qB, *and* C. But we access qA and qB a lot more so it makes
// sense to do it this way.
_mm_store_ss(C1, _mm_cvtepi32_ps(sum1));
- *(C1) *= unquant_mult;
+ *(C1) *= unquant_mult * scale;
_mm_store_ss(C2, _mm_cvtepi32_ps(sum2));
- *(C2) *= unquant_mult;
+ *(C2) *= unquant_mult * scale;
_mm_store_ss(C3, _mm_cvtepi32_ps(sum3));
- *(C3) *= unquant_mult;
+ *(C3) *= unquant_mult * scale;
_mm_store_ss(C4, _mm_cvtepi32_ps(sum4));
- *(C4) *= unquant_mult;
+ *(C4) *= unquant_mult * scale;
}
}
if(rest == 1) {
- const __m128i *A1_row = A + (i+0)*sse_width;
+ const __m128i *A1_row = qA + (i+0)*sse_width;
for (int j = 0; j < num_B_rows; j++) {
- const __m128i *B_row = B + j * sse_width;
+ const __m128i *B_row = qB + j * sse_width;
__m128i sum1 = _mm_setzero_si128();
@@ -262,18 +278,18 @@ static inline void SSE_MatrixMult(const __m128i* A,
sum1 = _mm_hadd_epi32(sum1, sum1);
sum1 = _mm_hadd_epi32(sum1, sum1);
- float * C1 = C + (i + 0) * num_B_rows + j;
+ float * C1 = fC + (i + 0) * num_B_rows + j;
_mm_store_ss(C1, _mm_cvtepi32_ps(sum1));
- *(C1) *= unquant_mult;
+ *(C1) *= unquant_mult * scale;
}
}
else if(rest == 2) {
- const __m128i *A1_row = A + (i + 0) * sse_width;
- const __m128i *A2_row = A + (i + 1) * sse_width;
+ const __m128i *A1_row = qA + (i + 0) * sse_width;
+ const __m128i *A2_row = qA + (i + 1) * sse_width;
for (int j = 0; j < num_B_rows; j++) {
- const __m128i *B_row = B + j * sse_width;
+ const __m128i *B_row = qB + j * sse_width;
__m128i sum1 = _mm_setzero_si128();
__m128i sum2 = _mm_setzero_si128();
@@ -293,23 +309,23 @@ static inline void SSE_MatrixMult(const __m128i* A,
sum2 = _mm_hadd_epi32(sum2, sum2);
sum2 = _mm_hadd_epi32(sum2, sum2);
- float * C1 = C + (i+0)*num_B_rows + j;
- float * C2 = C + (i+1)*num_B_rows + j;
+ float * C1 = fC + (i+0)*num_B_rows + j;
+ float * C2 = fC + (i+1)*num_B_rows + j;
_mm_store_ss(C1, _mm_cvtepi32_ps(sum1));
- *(C1) *= unquant_mult;
+ *(C1) *= unquant_mult * scale;
_mm_store_ss(C2, _mm_cvtepi32_ps(sum2));
- *(C2) *= unquant_mult;
+ *(C2) *= unquant_mult * scale;
}
}
else if(rest == 3) {
- const __m128i * A1_row = A + (i+0)*sse_width;
- const __m128i * A2_row = A + (i+1)*sse_width;
- const __m128i * A3_row = A + (i+2)*sse_width;
+ const __m128i * A1_row = qA + (i+0)*sse_width;
+ const __m128i * A2_row = qA + (i+1)*sse_width;
+ const __m128i * A3_row = qA + (i+2)*sse_width;
for (int j = 0; j < num_B_rows; j++) {
- const __m128i * B_row = B + j*sse_width;
+ const __m128i * B_row = qB + j*sse_width;
__m128i sum1 = _mm_setzero_si128();
__m128i sum2 = _mm_setzero_si128();
@@ -334,66 +350,30 @@ static inline void SSE_MatrixMult(const __m128i* A,
sum3 = _mm_hadd_epi32(sum3, sum3);
sum3 = _mm_hadd_epi32(sum3, sum3);
- float * C1 = C + (i+0)*num_B_rows + j;
- float * C2 = C + (i+1)*num_B_rows + j;
- float * C3 = C + (i+2)*num_B_rows + j;
+ float * C1 = fC + (i+0)*num_B_rows + j;
+ float * C2 = fC + (i+1)*num_B_rows + j;
+ float * C3 = fC + (i+2)*num_B_rows + j;
_mm_store_ss(C1, _mm_cvtepi32_ps(sum1));
- *(C1) *= unquant_mult;
+ *(C1) *= unquant_mult * scale;
_mm_store_ss(C2, _mm_cvtepi32_ps(sum2));
- *(C2) *= unquant_mult;
+ *(C2) *= unquant_mult * scale;
_mm_store_ss(C3, _mm_cvtepi32_ps(sum3));
- *(C3) *= unquant_mult;
+ *(C3) *= unquant_mult * scale;
}
}
}
+static void AddBias(marian::Tensor C, const marian::Tensor Bias) {
+ float* y = C->data();
+ const float* x = C->data();
+ const float* bias = Bias->data();
-//// Program takes no input
-//static void ProdInt(marian::Tensor C,
-// const marian::Tensor A,
-// const marian::Tensor B,
-// bool transA,
-// bool transB,
-// float beta,
-// float scalar) {
-//
-// double quant_mult = pow(2.0, 12.0);
-//
-// int width = B->shape()[-2];
-// int num_B_rows = B->shape()[-1];
-//
-// __m128i* quant_B = B->data<__m128i>();
-// assert(width % 8 == 0);
-//
-// assert(transA == false);
-// int num_A_rows = A->shape().elements() / A->shape()[-1];
-//
-// // Each __m128i fits 8 16-bit integers, so we assume the width is a multiple of 8.
-// // We could pad with 0 in the general case.
-//
-// __m128i *quant_A = new __m128i[num_A_rows * width / 8];
-// // The activation matrix must be quantized on-the-fly.
-// Quantize(A->data(), quant_A, (float)quant_mult, num_A_rows, width);
-//
-// // If we quantize to n bits and then multiple the values together, the result will be quantized to n^2 bits.
-// // So we must divide by 1.0/(n^2) to get back the original value.
-// double unquant_mult = 1.0 / (quant_mult * quant_mult);
-//
-// SSE_MatrixMult(quant_A, quant_B, C->data(),
-// (float)unquant_mult,
-// num_A_rows,
-// num_B_rows,
-// width);
-//
-// //std::cerr << C->debug() << std::endl;
-//
-// delete[] quant_A;
-//}
+ int m = C->shape().elements() / C->shape()[-1];
+ int n = C->shape()[-1];
-static void SSE_AddBias(float* y, const float* x, const float* bias, int m, int n) {
for(int j = 0; j < m; ++j) {
int i = 0;
for (; i < n; i += 4) {
@@ -408,50 +388,22 @@ static void SSE_AddBias(float* y, const float* x, const float* bias, int m, int n) {
}
}
+static void ProdInt(marian::Tensor C,
+ const marian::Tensor A,
+ const marian::Tensor B,
+ float scale) {
+
+ // @TODO: make this a parameter
+ float quant_mult = pow(2.0, (float)BITS);
+
+ // If we quantize with multiplier 2^n and then multiply two values together, the product carries a factor of 2^(2n).
+ // So we must multiply by 1.0 / (quant_mult * quant_mult) to get back the original scale.
+ float unquant_mult = 1.0 / (quant_mult * quant_mult);
+
+ SSE_MatrixMult(C, A, B, unquant_mult, scale);
+}
+
+}
+}
-//static void ProdIntWithBias(marian::Tensor C,
-// const marian::Tensor A,
-// const marian::Tensor B,
-// const marian::Tensor bias,
-// bool transA,
-// bool transB,
-// float beta,
-// float scalar) {
-//
-// double quant_mult = pow(2.0, 10.0);
-//
-// int width = B->shape()[-2];
-// int num_B_rows = B->shape()[-1];
-//
-// __m128i* quant_B = B->data<__m128i>();
-// assert(width % 8 == 0);
-//
-// assert(transA == false);
-// int num_A_rows = A->shape().elements() / A->shape()[-1];
-//
-// // Each __m128i fits 8 16-bit integers, so we assume the width is a multiple of 8.
-// // We could pad with 0 in the general case.
-//
-// __m128i *quant_A = new __m128i[num_A_rows * width / 8];
-// // The activation matrix must be quantized on-the-fly.
-// Quantize(A->data(), quant_A, (float)quant_mult, num_A_rows, width);
-//
-// // If we quantize to n bits and then multiple the values together, the result will be quantized to n^2 bits.
-// // So we must divide by 1.0/(n^2) to get back the original value.
-// double unquant_mult = 1.0 / (quant_mult * quant_mult);
-//
-// SSE_MatrixMult(quant_A,
-// quant_B,
-// C->data(),
-// (float)unquant_mult,
-// num_A_rows,
-// num_B_rows,
-// width);
-// SSE_AddBias(C->data(), C->data(), bias->data(), num_A_rows, num_B_rows);
-//
-//
-// //std::cerr << C->debug() << std::endl;
-//
-// delete[] quant_A;
-//}
}
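
The unquantization factor follows from the scaling: if qa ≈ a * 2^B and qb ≈ b * 2^B, each int32 dot product carries a factor of 2^(2B), so ProdInt multiplies the accumulated sum by unquant_mult = 2^(-2B) together with the user's scalar. A worked scalar sketch for one output cell with B = 10 (hypothetical, for illustration):

    float unquantize_cell(int32_t sum, float scale) {
      const float quant_mult   = 1024.f;                           // 2^10
      const float unquant_mult = 1.f / (quant_mult * quant_mult);  // 2^-20
      return sum * unquant_mult * scale;  // recovers dot(a_row, b_row) * scale
    }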
diff --git a/src/translator/translator.h b/src/translator/translator.h
index d58fb28e..71e37475 100644
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -53,7 +53,7 @@ public:
size_t id = 0;
for(auto device : devices) {
auto task = [&](DeviceId device, size_t id) {
- auto graph = New<ExpressionGraph>(true);
+ auto graph = New<ExpressionGraph>(true, options_->get<bool>("optimize"));
graph->setDevice(device);
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_[id] = graph;
@@ -152,7 +152,7 @@ public:
// initialize scorers
for(auto device : devices_) {
- auto graph = New<ExpressionGraph>(true);
+ auto graph = New<ExpressionGraph>(true, options_->get<bool>("optimize"));
graph->setDevice(device);
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_.push_back(graph);
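
Net effect of the translator changes, assuming --optimize is passed on the command line: both construction sites now build graphs for which isOptimized() is true; the dot()/affine() dispatch above additionally requires a CPU device. Sketched (hypothetical values):

    auto graph = New<ExpressionGraph>(/*inference=*/true, /*optimized=*/true);
    graph->setDevice({0, DeviceType::cpu});  // int16 path also checks for a CPU device
    // graph->isOptimized() == true, so dot()/affine() route to cpu::int16::*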