Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-06-08 00:34:39 +0300
committerMarcin Junczys-Dowmunt <junczys@amu.edu.pl>2018-06-08 00:34:39 +0300
commit7bb558ecfcdfef5c629f5a9d85ea2b4680bb60aa (patch)
tree0664fc883448c68ad81b91e4b76bef319d45b638 /src
parent68d61a662294cb3f26b3935da95a8ce1c404c293 (diff)
parentac21830517e75e31a0bca3b071292acff0d9610d (diff)
Merge branch 'master' into jonathac/windows_build
Diffstat (limited to 'src')
-rw-r--r--src/CMakeLists.txt5
-rw-r--r--src/common/config_parser.cpp44
-rw-r--r--src/data/dataset.h6
-rw-r--r--src/examples/iris/helper.cpp18
-rw-r--r--src/functional/predicates.h3
-rw-r--r--src/graph/expression_operators.cpp48
-rw-r--r--src/graph/expression_operators.h2
-rw-r--r--src/graph/node_operators.cpp1
-rw-r--r--src/graph/node_operators_unary.h39
-rw-r--r--src/models/encoder_decoder.cpp7
-rw-r--r--src/models/s2s.h8
-rw-r--r--src/models/states.h16
-rw-r--r--src/models/transformer.h270
-rw-r--r--src/tensors/backend.h11
-rw-r--r--src/tensors/cpu/element.h84
-rw-r--r--src/tensors/cpu/int16.h17
-rw-r--r--src/tensors/cpu/prod.cpp16
-rw-r--r--src/tensors/cpu/sharp/avx_gemm.cpp554
-rw-r--r--src/tensors/cpu/sharp/int_gemm.h137
-rw-r--r--[-rwxr-xr-x]src/tensors/cpu/sharp/sse_gemm.cpp (renamed from src/tensors/cpu/sharp/sse_gemm.h)121
-rw-r--r--src/tensors/gpu/add.inc2
-rw-r--r--src/tensors/gpu/element.inc3
-rw-r--r--src/tensors/gpu/prod.cu23
-rw-r--r--src/tensors/gpu/prod.h14
-rw-r--r--src/tensors/tensor_allocator.h2
-rw-r--r--src/tensors/tensor_operators.h6
-rw-r--r--src/tests/sqlite_test.cpp157
-rw-r--r--src/training/graph_group_async.h1
-rw-r--r--src/training/graph_group_multinode.cpp115
-rw-r--r--src/training/graph_group_multinode.h40
-rw-r--r--src/training/graph_group_singleton.h1
-rw-r--r--src/training/graph_group_sync.h2
-rw-r--r--src/training/training.h5
-rw-r--r--src/training/training_state.h4
-rw-r--r--src/translator/beam_search.h2
-rw-r--r--src/translator/scorers.cpp5
-rw-r--r--src/translator/translator.h3
37 files changed, 1429 insertions, 363 deletions
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index ca799770..fc8cc8b3 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -27,6 +27,9 @@ add_library(marian STATIC
tensors/cpu/prod.cpp
tensors/cpu/tensor_operators.cpp
+ tensors/cpu/sharp/avx_gemm.cpp
+ tensors/cpu/sharp/sse_gemm.cpp
+
graph/expression_graph.cpp
graph/expression_operators.cpp
graph/node.cpp
@@ -128,7 +131,7 @@ if(COMPILE_SERVER)
endif(COMPILE_SERVER)
foreach(exec ${EXECUTABLES})
- target_link_libraries(${exec} marian ${EXT_LIBS})
+ target_link_libraries(${exec} marian ${EXT_LIBS} ${EXT_LIBS} ${CMAKE_THREAD_LIBS_INIT})
if(CUDA_FOUND)
target_link_libraries(${exec} marian marian_cuda ${EXT_LIBS} ${CMAKE_THREAD_LIBS_INIT})
cuda_add_cublas_to_target(${exec})
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index e455cc3e..988e4746 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -290,6 +290,8 @@ void ConfigParser::addOptionsCommon(po::options_description& desc) {
"Suppress logging for translation")
("seed", po::value<size_t>()->default_value(0),
"Seed for all random number generators. 0 means initialize randomly")
+ ("clip-gemm", po::value<float>()->default_value(0.f),
+ "If not 0 clip GEMM input values to +/- arg")
("interpolate-env-vars", po::value<bool>()->zero_tokens()->default_value(false),
"allow the use of environment variables in paths, of the form ${VAR_NAME}")
("relative-paths", po::value<bool>()->zero_tokens()->default_value(false),
@@ -374,14 +376,24 @@ void ConfigParser::addOptionsModel(po::options_description& desc) {
"Tie all embedding layers and output layer")
("transformer-heads", po::value<int>()->default_value(8),
"Number of heads in multi-head attention (transformer)")
- ("transformer-dim-ffn", po::value<int>()->default_value(2048),
- "Size of position-wise feed-forward network (transformer)")
("transformer-no-projection", po::value<bool>()->zero_tokens()->default_value(false),
"Omit linear projection after multi-head attention (transformer)")
+ ("transformer-dim-ffn", po::value<int>()->default_value(2048),
+ "Size of position-wise feed-forward network (transformer)")
("transformer-ffn-depth", po::value<int>()->default_value(2),
- "Activation between filters: swish or relu (transformer)")
+ "Depth of filters (transformer)")
("transformer-ffn-activation", po::value<std::string>()->default_value("swish"),
"Activation between filters: swish or relu (transformer)")
+ ("transformer-dim-aan", po::value<int>()->default_value(2048),
+ "Size of position-wise feed-forward network in AAN (transformer)")
+ ("transformer-aan-depth", po::value<int>()->default_value(2),
+ "Depth of filter for AAN (transformer)")
+ ("transformer-aan-activation", po::value<std::string>()->default_value("swish"),
+ "Activation between filters in AAN: swish or relu (transformer)")
+ ("transformer-aan-nogate", po::value<bool>()->zero_tokens()->default_value(false),
+ "Omit gate in AAN (transformer)")
+ ("transformer-decoder-autoreg", po::value<std::string>()->default_value("self-attention"),
+ "Type of autoregressive layer in transformer decoder: self-attention, average-attention (transformer)")
("transformer-preprocess", po::value<std::string>()->default_value(""),
"Operation before each transformer layer: d = dropout, a = add, n = normalize")
("transformer-postprocess-emb", po::value<std::string>()->default_value("d"),
@@ -500,7 +512,6 @@ void ConfigParser::addOptionsTraining(po::options_description& desc) {
"Number of batches to preload for length-based sorting")
("maxi-batch-sort", po::value<std::string>()->default_value("trg"),
"Sorting strategy for maxi-batch: trg (default) src none")
-
("optimizer,o", po::value<std::string>()->default_value("adam"),
"Optimization algorithm (possible values: sgd, adagrad, adam")
("optimizer-params", po::value<std::vector<float>>()
@@ -596,6 +607,10 @@ void ConfigParser::addOptionsTraining(po::options_description& desc) {
("multi-node-overlap", po::value<bool>()
->default_value(true),
"Overlap model computations with MPI communication")
+ ("multi-node-local-optimizers", po::value<bool>()
+ ->zero_tokens()
+ ->default_value(false),
+ "Enable local optimizers with multi-node. Requires optimizer delay to be turned on.")
;
// clang-format on
desc.add(training);
@@ -642,6 +657,8 @@ void ConfigParser::addOptionsValid(po::options_description& desc) {
"Divide translation score by pow(translation length, arg) ")
("word-penalty", po::value<float>()->default_value(0.f)->implicit_value(0.f),
"Subtract (arg * translation length) from translation score ")
+ ("max-length-factor", po::value<float>()->default_value(3),
+ "Maximum target length as source length times factor")
("allow-unk", po::value<bool>()->zero_tokens()->default_value(false),
"Allow unknown words to appear in output")
("n-best", po::value<bool>()->zero_tokens()->default_value(false),
@@ -670,8 +687,12 @@ void ConfigParser::addOptionsTranslate(po::options_description& desc) {
"Subtract (arg * translation length) from translation score ")
("allow-unk", po::value<bool>()->zero_tokens()->default_value(false),
"Allow unknown words to appear in output")
+ ("skip-cost", po::value<bool>()->zero_tokens()->default_value(false),
+ "Ignore model cost during translation, not recommended for beam-size > 1")
("max-length", po::value<size_t>()->default_value(1000),
"Maximum length of a sentence in a training sentence pair")
+ ("max-length-factor", po::value<float>()->default_value(3),
+ "Maximum target length as source length times factor")
("max-length-crop", po::value<bool>()->zero_tokens()->default_value(false),
"Crop a sentence to max-length instead of ommitting it if longer than max-length")
("devices,d", po::value<std::vector<std::string>>()
@@ -691,6 +712,8 @@ void ConfigParser::addOptionsTranslate(po::options_description& desc) {
"Optimize speed aggressively sacrificing memory or precision")
("mini-batch", po::value<int>()->default_value(1),
"Size of mini-batch used during update")
+ ("mini-batch-words", po::value<int>()->default_value(0),
+ "Set mini-batch size based on words instead of sentences")
("maxi-batch", po::value<int>()->default_value(1),
"Number of batches to preload for length-based sorting")
("maxi-batch-sort", po::value<std::string>()->default_value("none"),
@@ -891,6 +914,11 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
SET_OPTION("transformer-dim-ffn", int);
SET_OPTION("transformer-ffn-depth", int);
SET_OPTION("transformer-ffn-activation", std::string);
+ SET_OPTION("transformer-dim-aan", int);
+ SET_OPTION("transformer-aan-depth", int);
+ SET_OPTION("transformer-aan-activation", std::string);
+ SET_OPTION("transformer-aan-nogate", bool);
+ SET_OPTION("transformer-decoder-autoreg", std::string);
#ifdef CUDNN
SET_OPTION("char-stride", int);
@@ -916,7 +944,7 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
SET_OPTION("transformer-dropout", float);
SET_OPTION("transformer-dropout-attention", float);
SET_OPTION("transformer-dropout-ffn", float);
-
+
SET_OPTION("overwrite", bool);
SET_OPTION("no-reload", bool);
if(!vm_["train-sets"].empty()) {
@@ -978,6 +1006,7 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
SET_OPTION("multi-node", bool);
SET_OPTION("multi-node-overlap", bool);
+ SET_OPTION("multi-node-local-optimizers", bool);
}
if(mode_ == ConfigMode::rescoring) {
@@ -999,10 +1028,13 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
SET_OPTION("word-penalty", float);
SET_OPTION("allow-unk", bool);
SET_OPTION("n-best", bool);
+ SET_OPTION("mini-batch-words", int);
SET_OPTION_NONDEFAULT("weights", std::vector<float>);
SET_OPTION_NONDEFAULT("shortlist", std::vector<std::string>);
SET_OPTION("port", size_t);
SET_OPTION("optimize", bool);
+ SET_OPTION("max-length-factor", float);
+ SET_OPTION("skip-cost", bool);
}
/** valid **/
@@ -1024,6 +1056,7 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
SET_OPTION_NONDEFAULT("valid-translation-output", std::string);
SET_OPTION("beam-size", size_t);
SET_OPTION("normalize", float);
+ SET_OPTION("max-length-factor", float);
SET_OPTION("word-penalty", float);
SET_OPTION("allow-unk", bool);
SET_OPTION("n-best", bool);
@@ -1035,6 +1068,7 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
SET_OPTION("quiet-translation", bool);
SET_OPTION_NONDEFAULT("log", std::string);
SET_OPTION("seed", size_t);
+ SET_OPTION("clip-gemm", float);
SET_OPTION("interpolate-env-vars", bool);
SET_OPTION("relative-paths", bool);
SET_OPTION("devices", std::vector<std::string>);
diff --git a/src/data/dataset.h b/src/data/dataset.h
index bb882615..3525de10 100644
--- a/src/data/dataset.h
+++ b/src/data/dataset.h
@@ -2,6 +2,7 @@
#include "common/definitions.h"
#include "data/batch.h"
+#include "data/rng_engine.h"
#include "data/vocab.h"
#include "training/training_state.h"
@@ -97,7 +98,8 @@ public:
}
};
-class Dataset : public DatasetBase<Example, ExampleIterator, DataBatch> {
+class Dataset : public DatasetBase<Example, ExampleIterator, DataBatch>,
+ public RNGEngine {
protected:
Examples examples_;
@@ -110,7 +112,7 @@ public:
iterator end() { return ExampleIterator(examples_.end()); }
- void shuffle() { std::random_shuffle(examples_.begin(), examples_.end()); }
+ void shuffle() { std::shuffle(examples_.begin(), examples_.end(), eng_); }
batch_ptr toBatch(const Examples& batchVector) {
int batchSize = batchVector.size();
diff --git a/src/examples/iris/helper.cpp b/src/examples/iris/helper.cpp
index feea7d49..c47457ca 100644
--- a/src/examples/iris/helper.cpp
+++ b/src/examples/iris/helper.cpp
@@ -38,15 +38,15 @@ void readIrisData(const std::string fileName,
}
void shuffleData(std::vector<float>& features, std::vector<float>& labels) {
- // Create a list of indeces 0...K
- std::vector<int> indeces;
- indeces.reserve(labels.size());
+ // Create a list of indices 0...K
+ std::vector<int> indices;
+ indices.reserve(labels.size());
for(int i = 0; i < labels.size(); ++i)
- indeces.push_back(i);
+ indices.push_back(i);
- // Shuffle indeces
- std::srand(marian::Config::seed);
- std::random_shuffle(indeces.begin(), indeces.end());
+ // Shuffle indices
+ static std::mt19937 urng(marian::Config::seed);
+ std::shuffle(indices.begin(), indices.end(), urng);
std::vector<float> featuresTemp;
featuresTemp.reserve(features.size());
@@ -54,8 +54,8 @@ void shuffleData(std::vector<float>& features, std::vector<float>& labels) {
labelsTemp.reserve(labels.size());
// Get shuffled features and labels
- for(auto i = 0; i < indeces.size(); ++i) {
- auto idx = indeces[i];
+ for(auto i = 0; i < indices.size(); ++i) {
+ auto idx = indices[i];
labelsTemp.push_back(labels[idx]);
featuresTemp.insert(featuresTemp.end(),
features.begin() + (idx * NUM_FEATURES),
diff --git a/src/functional/predicates.h b/src/functional/predicates.h
index 41a741bb..8420e9e4 100644
--- a/src/functional/predicates.h
+++ b/src/functional/predicates.h
@@ -115,6 +115,9 @@ BINARY(Pow, pow, pow(x, y));
BINARY(Clip, clip, fabs(x) >= y ? sgn(x) * y : x);
+// derivative of Clip, cut-off function
+BINARY(Bump, bump, fabs(x) >= y ? 0.f : 1.f);
+
UNARY(sReLU, ReLU, x > 0.f ? x : 0.f);
UNARY(sReLUBack, ReLUback, x > 0.f ? 1.f : 0.f);
BINARY(sPReLU, PReLU, x > 0.f ? x : x * y);
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 287f771d..fd985d4a 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -31,6 +31,13 @@ Expr prelu(Expr a, float alpha) {
return Expression<PReLUNodeOp>(alpha, a);
}
+Expr clip(Expr a, float c) {
+ if(c == 0)
+ return a;
+ else
+ return Expression<ClipNodeOp>(a, c);
+}
+
Expr log(Expr a) {
return Expression<LogNodeOp>(a);
};
@@ -204,16 +211,21 @@ Expr weighted_average(Expr in, Expr weights, keywords::axis_k ax) {
Expr dot(Expr a, Expr b, bool transA, bool transB, float scale) {
auto device = a->graph()->getDevice().type;
+ float clipValue = a->graph()->getBackend()->getClip();
+
+ // Currently only true when command line options
+ // --optimize --cpu-thread=N with N > 0 are set.
if(a->graph()->isOptimized() && device == DeviceType::cpu) {
// dotInt16 computes A * B.T, hence the transpose for B to get A * B
// if transA = false and transB = false.
- return cpu::int16::dot(cpu::int16::quantize(transA ? transpose(a) : a),
- cpu::int16::quantize(transB ? b : transpose(b)),
+ return cpu::int16::dot(cpu::int16::quantize(transA ? transpose(a) : a, clipValue),
+ cpu::int16::quantize(transB ? b : transpose(b), clipValue),
scale);
}
else {
- return Expression<DotNodeOp>(a, b, transA, transB, scale);
+ return Expression<DotNodeOp>(clip(a, clipValue), clip(b, clipValue),
+ transA, transB, scale);
}
}
@@ -223,6 +235,9 @@ Expr bdot(Expr a, Expr b, bool transA, bool transB, float scale) {
Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
auto device = a->graph()->getDevice().type;
+
+ float clipValue = a->graph()->getBackend()->getClip();
+
if(a->graph()->isOptimized() && device == DeviceType::cpu) {
bool autotune = true;
@@ -255,8 +270,8 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
return e;
};
auto alg1 = [=]() {
- return rec1(cpu::int16::affine(rec1(cpu::int16::quantize(transA ? rec1(transpose(a)) : a)),
- cpu::int16::quantize(transB ? b : transpose(b)),
+ return rec1(cpu::int16::affine(rec1(cpu::int16::quantize(transA ? rec1(transpose(a)) : a, clipValue)),
+ cpu::int16::quantize(transB ? b : transpose(b), clipValue),
bias,
scale),
true);
@@ -270,8 +285,18 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
e->record(tuner, hash2, stop);
return e;
};
+
+
auto alg2 = [=]() {
- std::vector<Expr> nodes = {a, b, bias};
+ auto ac = clip(a, clipValue);
+ if(ac != a)
+ ac = rec2(ac);
+
+ auto bc = clip(b, clipValue);
+ if(bc != b)
+ bc = rec2(bc);
+
+ std::vector<Expr> nodes = {ac, bc, bias};
return rec2(Expression<AffineNodeOp>(nodes, transA, transB, scale),
true);
};
@@ -283,16 +308,21 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
}
else {
// cpu int16 version
- return cpu::int16::affine(cpu::int16::quantize(transA ? transpose(a) : a),
- cpu::int16::quantize(transB ? b : transpose(b)),
+ return cpu::int16::affine(cpu::int16::quantize(transA ? transpose(a) : a, clipValue),
+ cpu::int16::quantize(transB ? b : transpose(b), clipValue),
bias,
scale);
}
}
else {
// general version, MKL, CBlas or CUDA
- std::vector<Expr> nodes = {a, b, bias};
+ // if clipValue > 0, the inputs will be clipped to range [-clipValue, clipValue]
+ // This is meant to keep values at the same range as used during training when
+ // optimizing for 8-bit integer products. Likely to be removed in the future
+ // when we explore better ways to handle this.
+ std::vector<Expr> nodes = {clip(a, clipValue), clip(b, clipValue), bias};
return Expression<AffineNodeOp>(nodes, transA, transB, scale);
+
}
}
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 070ee0ba..036445ee 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -34,6 +34,8 @@ Expr log(Expr a);
Expr exp(Expr a);
+Expr clip(Expr a, float c);
+
Expr operator-(Expr a);
/*********************************************************/
diff --git a/src/graph/node_operators.cpp b/src/graph/node_operators.cpp
index 788b3fc6..315cb5eb 100644
--- a/src/graph/node_operators.cpp
+++ b/src/graph/node_operators.cpp
@@ -2,7 +2,6 @@
#include "expression_graph.h"
#include "tensors/tensor_operators.h"
-#include "tensors/cpu/sharp/sse_gemm.h"
namespace marian {
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index e14f6546..259e6072 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -99,6 +99,45 @@ public:
}
};
+struct ClipNodeOp : public UnaryNodeOp {
+private:
+ float clip_{0};
+
+public:
+ ClipNodeOp(Expr a, float clip) : UnaryNodeOp(a), clip_{clip} {}
+
+ NodeOps forwardOps() {
+ using namespace functional;
+ return {NodeOp(Element(_1 = clip(_2, clip_), val_, child(0)->val()))};
+ }
+
+ NodeOps backwardOps() {
+ using namespace functional;
+ return {NodeOp(Add(bump(_1, clip_) * _2, child(0)->grad(), child(0)->val(), adj_))};
+ }
+
+ const std::string type() { return "clip"; }
+
+ virtual size_t hash() {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ boost::hash_combine(hash_, clip_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<ClipNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(clip_ != cnode->clip_)
+ return false;
+ return true;
+ }
+};
+
struct LogitNodeOp : public UnaryNodeOp {
LogitNodeOp(Expr a) : UnaryNodeOp(a) {}
diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp
index 10110313..c17adb69 100644
--- a/src/models/encoder_decoder.cpp
+++ b/src/models/encoder_decoder.cpp
@@ -26,7 +26,7 @@ EncoderDecoder::EncoderDecoder(Ptr<Options> options)
"special-vocab",
"tied-embeddings",
"tied-embeddings-src",
- "tied-embeddings-all",
+ "tied-embeddings-all"
};
modelFeatures_.insert("transformer-heads");
@@ -34,9 +34,14 @@ EncoderDecoder::EncoderDecoder(Ptr<Options> options)
modelFeatures_.insert("transformer-dim-ffn");
modelFeatures_.insert("transformer-ffn-depth");
modelFeatures_.insert("transformer-ffn-activation");
+ modelFeatures_.insert("transformer-dim-aan");
+ modelFeatures_.insert("transformer-aan-depth");
+ modelFeatures_.insert("transformer-aan-activation");
+ modelFeatures_.insert("transformer-aan-nogate");
modelFeatures_.insert("transformer-preprocess");
modelFeatures_.insert("transformer-postprocess");
modelFeatures_.insert("transformer-postprocess-emb");
+ modelFeatures_.insert("transformer-decoder-autoreg");
}
std::vector<Ptr<EncoderBase>>& EncoderDecoder::getEncoders() {
diff --git a/src/models/s2s.h b/src/models/s2s.h
index abfb8bdd..e0d376d8 100644
--- a/src/models/s2s.h
+++ b/src/models/s2s.h
@@ -271,6 +271,7 @@ public:
options_->has("original-type")
&& opt<std::string>("original-type") == "nematus") //
);
+
start = mlp->apply(meanContexts);
} else {
int dimBatch = batch->size();
@@ -364,8 +365,13 @@ public:
else
logits = output_->apply(embeddings, decoderContext);
+
// return unormalized(!) probabilities
- return New<DecoderState>(decoderStates, logits, state->getEncoderStates(), state->getBatch());
+ auto nextState = New<DecoderState>(decoderStates, logits, state->getEncoderStates(), state->getBatch());
+
+ // Advance current target token position by one
+ nextState->setPosition(state->getPosition() + 1);
+ return nextState;
}
// helper function for guided alignment
diff --git a/src/models/states.h b/src/models/states.h
index 0584a111..f7818289 100644
--- a/src/models/states.h
+++ b/src/models/states.h
@@ -38,6 +38,9 @@ protected:
rnn::States states_;
Ptr<data::CorpusBatch> batch_;
+ // Keep track of current target token position during translation
+ size_t position_{0};
+
public:
DecoderState(const rnn::States& states,
Expr probs,
@@ -54,8 +57,12 @@ public:
virtual Ptr<DecoderState> select(const std::vector<size_t>& selIdx,
int beamSize) {
- return New<DecoderState>(
+ auto selectedState = New<DecoderState>(
states_.select(selIdx, beamSize), probs_, encStates_, batch_);
+
+ // Set positon of new state based on the target token position of current state
+ selectedState->setPosition(getPosition());
+ return selectedState;
}
virtual const rnn::States& getStates() { return states_; }
@@ -84,6 +91,13 @@ public:
return batch_;
}
+
+ // Set current target token position in state when decoding
+ size_t getPosition() { return position_; }
+
+ // Set current target token position in state when decoding
+ void setPosition(size_t position) { position_ = position; }
+
virtual void blacklist(Expr totalCosts, Ptr<data::CorpusBatch> batch) {}
};
}
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 25e04567..878ca16a 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -323,11 +323,43 @@ public:
return output;
}
- static Expr LayerFFN(Ptr<ExpressionGraph> graph,
- Ptr<Options> options,
- std::string prefix,
- Expr input,
- bool inference = false) {
+ Expr DecoderLayerSelfAttention(rnn::State& decoderState,
+ const rnn::State& prevDecoderState,
+ Ptr<ExpressionGraph> graph,
+ Ptr<Options> options,
+ std::string prefix,
+ Expr input,
+ Expr selfMask,
+ int startPos,
+ bool inference = false) {
+
+ using namespace keywords;
+
+ selfMask = transposedLogMask(selfMask);
+
+ auto values = input;
+ if(startPos > 0) {
+ values = concatenate({prevDecoderState.output, input},
+ axis = -2);
+ }
+ decoderState.output = values;
+
+ // TODO: do not recompute matrix multiplies
+ return LayerAttention(graph,
+ options,
+ prefix,
+ input,
+ values,
+ values,
+ selfMask,
+ inference);
+ }
+
+ Expr LayerFFN(Ptr<ExpressionGraph> graph,
+ Ptr<Options> options,
+ std::string prefix,
+ Expr input,
+ bool inference = false) {
using namespace keywords;
int dimModel = input->shape()[-1];
@@ -377,6 +409,108 @@ public:
return output;
}
+
+ // Implementation of Average Attention Network Layer (ANN) from
+ // https://arxiv.org/pdf/1805.00631.pdf
+ Expr LayerAAN(Ptr<ExpressionGraph> graph,
+ Ptr<Options> options,
+ std::string prefix,
+ Expr x,
+ Expr y,
+ bool inference = false) {
+ using namespace keywords;
+
+ int dimModel = x->shape()[-1];
+
+ float dropProb = inference ? 0 : options->get<float>("transformer-dropout");
+ auto opsPre = options->get<std::string>("transformer-preprocess");
+
+ y = PreProcess(graph, prefix + "_ffn", opsPre, y, dropProb);
+
+ // FFN
+ int dimAan = options->get<int>("transformer-dim-aan");
+ int depthAan = options->get<int>("transformer-aan-depth");
+ auto act = options->get<std::string>("transformer-aan-activation");
+ float aanDropProb = inference ? 0 : options->get<float>("transformer-dropout-ffn");
+
+ int i = 1;
+ int dimLast = dimModel;
+ for(; i < depthAan; ++i) {
+ int dimFirst = i == 1 ? dimModel : dimAan;
+ auto W = graph->param(
+ prefix + "_W" + std::to_string(i), {dimFirst, dimAan}, inits::glorot_uniform);
+ auto b = graph->param(prefix + "_b" + std::to_string(i), {1, dimAan}, inits::zeros);
+
+ y = affine(y, W, b);
+
+ if(act == "relu")
+ y = relu(y);
+ else
+ y = swish(y);
+
+ if(aanDropProb)
+ y = dropout(y, aanDropProb);
+
+ dimLast = dimAan;
+ }
+
+ if(dimLast != dimModel) {
+ auto W = graph->param(
+ prefix + "_W" + std::to_string(i), {dimLast, dimModel}, inits::glorot_uniform);
+ auto b = graph->param(prefix + "_b" + std::to_string(i), {1, dimModel}, inits::zeros);
+ y = affine(y, W, b);
+ }
+
+ bool noGate = options->get<bool>("transformer-aan-nogate");
+ if(!noGate) {
+ auto Wi = graph->param(prefix + "_Wi", {dimModel, dimModel}, inits::glorot_uniform);
+ auto bi = graph->param(prefix + "_bi", {1, dimModel}, inits::zeros);
+
+ auto Wf = graph->param(prefix + "_Wf", {dimModel, dimModel}, inits::glorot_uniform);
+ auto bf = graph->param(prefix + "_bf", {1, dimModel}, inits::zeros);
+
+ auto gi = logit(affine(x, Wi, bi));
+ auto gf = logit(affine(y, Wf, bf));
+ y = gi * x + gf * y;
+ }
+
+ auto opsPost = options->get<std::string>("transformer-postprocess");
+ y = PostProcess(graph, prefix + "_ffn", opsPost, y, x, dropProb);
+
+ return y;
+ }
+
+ // Implementation of Average Attention Network Layer (ANN) from
+ // https://arxiv.org/pdf/1805.00631.pdf
+ // Function wrapper using decoderState as input.
+ Expr DecoderLayerAAN(rnn::State& decoderState,
+ const rnn::State& prevDecoderState,
+ Ptr<ExpressionGraph> graph,
+ Ptr<Options> options,
+ std::string prefix,
+ Expr input,
+ Expr selfMask,
+ int startPos,
+ bool inference = false) {
+
+ using namespace keywords;
+
+ auto output = input;
+ if(startPos > 0) {
+ // we are decoding at a position after 0
+ output = (prevDecoderState.output * startPos + input) / (startPos + 1);
+ }
+ else if(startPos == 0 && output->shape()[-2] > 1) {
+ // we are training or scoring, because there is no history and
+ // the context is larger than a single time step. We do not need
+ // to average batch with only single words.
+ selfMask = selfMask / sum(selfMask, axis=-1);
+ output = bdot(selfMask, output);
+ }
+ decoderState.output = output;
+
+ return LayerAAN(graph, options, prefix, input, output, inference);
+ }
};
class EncoderTransformer : public EncoderBase, public Transformer {
@@ -508,7 +642,12 @@ public:
selectedStates.push_back({sel, nullptr});
}
- return New<TransformerState>(selectedStates, probs_, encStates_, batch_);
+ // Create hypothesis-selected state based on current state and hyp indices
+ auto selectedState = New<TransformerState>(selectedStates, probs_, encStates_, batch_);
+
+ // Set the same target token position as the current state
+ selectedState->setPosition(getPosition());
+ return selectedState;
}
};
@@ -551,10 +690,10 @@ public:
// according to paper embeddings are scaled by \sqrt(d_m)
auto scaledEmbeddings = std::sqrt(dimEmb) * embeddings;
- int startPos = 0;
- auto prevDecoderStates = state->getStates();
- if(prevDecoderStates.size() > 0)
- startPos = prevDecoderStates[0].output->shape()[-2];
+ // set current target token position during decoding or training. At training
+ // this should be 0. During translation the current length of the translation.
+ // Used for position embeddings and creating new decoder states.
+ int startPos = state->getPosition();
scaledEmbeddings
= AddPositionalEmbeddings(graph, scaledEmbeddings, startPos);
@@ -569,7 +708,6 @@ public:
query = PreProcess(graph, prefix_ + "_emb", opsEmb, query, dropProb);
- rnn::States decoderStates;
int dimTrgWords = query->shape()[-2];
int dimBatch = query->shape()[-3];
auto selfMask = TriangleMask(graph, dimTrgWords); // [ (1,) 1, max length, max length]
@@ -582,9 +720,6 @@ public:
// selfMask = repeat(selfMask, dimBeam, axis = -4);
}
- selfMask = transposedLogMask(selfMask);
-
- // reorganize batch and timestep for encoder embeddings
std::vector<Expr> encoderContexts;
std::vector<Expr> encoderMasks;
@@ -608,60 +743,59 @@ public:
encoderMasks.push_back(encoderMask);
}
- // apply decoder layers
- auto decDepth = opt<int>("dec-depth");
- for(int i = 1; i <= decDepth; ++i) {
- auto values = query;
- if(prevDecoderStates.size() > 0)
- values
- = concatenate({prevDecoderStates[i - 1].output, query}, axis = -2);
+ rnn::States prevDecoderStates = state->getStates();
+ rnn::States decoderStates;
+ // apply layers
+ for(int i = 1; i <= opt<int>("dec-depth"); ++i) {
+ rnn::State decoderState;
+ rnn::State prevDecoderState;
- decoderStates.push_back({values, nullptr});
+ if(prevDecoderStates.size() > 0)
+ prevDecoderState = prevDecoderStates[i - 1];
+
+ std::string layerType = opt<std::string>("transformer-decoder-autoreg");
+ if(layerType == "self-attention") {
+ query = DecoderLayerSelfAttention(decoderState,
+ prevDecoderState,
+ graph,
+ options_,
+ prefix_ + "_l" + std::to_string(i) + "_self",
+ query,
+ selfMask,
+ startPos,
+ inference_);
+ } else if(layerType == "average-attention") {
+ query = DecoderLayerAAN(decoderState,
+ prevDecoderState,
+ graph,
+ options_,
+ prefix_ + "_l" + std::to_string(i) + "_aan",
+ query,
+ selfMask,
+ startPos,
+ inference_);
+ } else {
+ ABORT("Unknown auto-regressive layer type in transformer decoder {}", layerType);
+ }
- // TODO: do not recompute matrix multiplies
- // self-attention
- query = LayerAttention(graph,
- options_,
- prefix_ + "_l" + std::to_string(i) + "_self",
- query,
- values,
- values,
- selfMask,
- inference_);
+ decoderStates.push_back(decoderState);
- // attention over encoder
+ // Iterate over multiple encoders and simply stack the attention blocks
if(encoderContexts.size() > 0) {
- // auto comb = opt<std::string>("transformer-multi-encoder");
- std::string comb = "stack";
- if(comb == "concat") {
- query
- = LayerAttention(graph,
- options_,
- prefix_ + "_l" + std::to_string(i) + "_context",
- query,
- encoderContexts,
- encoderContexts,
- encoderMasks,
- inference_);
-
- } else if(comb == "stack") {
- for(int j = 0; j < encoderContexts.size(); ++j) { // multiple encoders are applied one after another
- std::string prefix
- = prefix_ + "_l" + std::to_string(i) + "_context";
- if(j > 0)
- prefix += "_enc" + std::to_string(j + 1);
-
- query = LayerAttention(graph,
- options_,
- prefix,
- query,
- encoderContexts[j],
- encoderContexts[j],
- encoderMasks[j],
- inference_);
- }
- } else {
- ABORT("Unknown value for transformer-multi-encoder: {}", comb);
+ for(int j = 0; j < encoderContexts.size(); ++j) {
+ std::string prefix
+ = prefix_ + "_l" + std::to_string(i) + "_context";
+ if(j > 0)
+ prefix += "_enc" + std::to_string(j + 1);
+
+ query = LayerAttention(graph,
+ options_,
+ prefix,
+ query,
+ encoderContexts[j],
+ encoderContexts[j],
+ encoderMasks[j],
+ inference_);
}
}
@@ -704,8 +838,12 @@ public:
Expr logits = output_->apply(decoderContext);
// return unormalized(!) probabilities
- return New<TransformerState>(
- decoderStates, logits, state->getEncoderStates(), state->getBatch());
+ auto nextState = New<TransformerState>(decoderStates,
+ logits,
+ state->getEncoderStates(),
+ state->getBatch());
+ nextState->setPosition(state->getPosition() + 1);
+ return nextState;
}
// helper function for guided alignment
diff --git a/src/tensors/backend.h b/src/tensors/backend.h
index 3cd51ce1..d3687929 100644
--- a/src/tensors/backend.h
+++ b/src/tensors/backend.h
@@ -9,12 +9,23 @@ protected:
DeviceId deviceId_;
size_t seed_;
+ // global clipping value for matrix-multiplies, should soon be removed.
+ float clipValue_{0.f};
+
public:
Backend(DeviceId deviceId, size_t seed) : deviceId_(deviceId), seed_(seed) {}
virtual DeviceId getDevice() { return deviceId_; };
virtual void setDevice() = 0;
virtual void synchronize() = 0;
+
+ virtual void setClip(float clipValue) {
+ clipValue_ = clipValue;
+ }
+
+ float getClip() {
+ return clipValue_;
+ }
};
Ptr<Backend> BackendByDevice(DeviceId deviceId, size_t seed);
diff --git a/src/tensors/cpu/element.h b/src/tensors/cpu/element.h
index 210b9a6c..23750bcd 100644
--- a/src/tensors/cpu/element.h
+++ b/src/tensors/cpu/element.h
@@ -1,8 +1,3 @@
-/* All or part of this file was contributed by Intel under license:
- * Copyright (C) 2017-2018 Intel Corporation
- * SPDX-License-Identifier: MIT
- */
-
#pragma once
#include "tensors/tensor.h"
@@ -10,40 +5,69 @@
namespace marian {
namespace cpu {
-template <size_t K, bool broadcast, class Functor>
-void gElement(Functor functor,
- functional::Array<functional::Tensor<float>, K> tensors) {
- int length = tensors[0].shape().elements();
- functional::Array<int, functional::Shape::size()> dims;
- functional::Array<int, K> indices;
+// Function in this header are supposed to execute element-wise operations
+// (passed in as a Functor) on arbitrary numbers of tensors. The templates
+// are required to implement correct broadcasting of operations across
+// a fixed-at-compile-time but in principle arbitrary number of dimensions.
+
+// @TODO: generalize to vector operations, possible using specializations
+
+// single loop over outer dimension. Recursively creates nested loops
+// down to inner dimension and to single elements. Since this is based
+// on strides, it correctly broadcasts to all dimensions without additional
+// computation.
+// Compiler optimizes this to single construct with nested(?) loops.
+template <size_t I = 0> struct E {
+ template <size_t K, class Functor>
+ static inline void element(const Functor& functor,
+ functional::Array<functional::Tensor<float>, K>& tensors,
+ functional::Array<int, K> indices) {
-#pragma omp parallel for simd
- for(int index = 0; index < length; ++index) {
- indices.fill(index);
- if(broadcast) {
- tensors[0].shape().dims(index, dims);
- for(int i = 1; i < K; ++i)
- indices[i] = tensors[i].shape().bindex(dims);
+ auto& shape = tensors[0].shape();
+
+ // loop for outer-most dimension
+ for(int i = 0; i < shape[I]; ++i) {
+
+ // call loop for next-inner dimension
+ E<I + 1>::element(functor, tensors, indices);
+
+ // increase index for current dimension by stride or 0 if broadcasting. bstride(i)
+ // is look-up value, either equal to stride if the corresponding dim is larger 1 or
+ // 0 if the dim is 1.
+ for(int k = 0; k < K; ++k)
+ indices[k] += tensors[k].shape().bstride(I);
}
- tensors[0][index] = functional::apply(functor, tensors, indices);
}
-}
+};
+// specialization for inner-most single element (recursive stopping criterion)
+// using const reference for indices here to avoid copying. No loop.
+template <> struct E<functional::Shape::size()> {
+ template <size_t K, class Functor>
+ static inline void element(const Functor& functor,
+ functional::Array<functional::Tensor<float>, K>& tensors,
+ const functional::Array<int, K>& indices) {
+
+ // just apply the function for all indexed elements across all tensors
+ tensors[0][indices[0]] = functional::apply(functor, tensors, indices);
+
+ }
+};
+
+// main call to function executing element-wise operation
template <class Functor, class... Tensors>
-void Element(Functor functor, marian::Tensor out, Tensors... tensors) {
+void Element(const Functor& functor, marian::Tensor out, Tensors... tensors) {
constexpr size_t K = sizeof...(tensors) + 1;
functional::Array<functional::Tensor<float>, K> gTensors = {out, tensors...};
- int length = gTensors[0].shape().elements();
-
- bool broadcast = false;
- for(int i = 1; i < K; ++i)
- broadcast = broadcast || gTensors[0].shape() != gTensors[i].shape();
+ // create and initialize indices to 0
+ functional::Array<int, K> indices;
+ indices.fill(0);
- if(broadcast)
- cpu::gElement<K, true>(functor, gTensors);
- else
- cpu::gElement<K, false>(functor, gTensors);
+ // call elementwise operation going from outer-most dimension
+ // to inner-most element.
+ E<>::element(functor, gTensors, indices);
}
+
}
}
diff --git a/src/tensors/cpu/int16.h b/src/tensors/cpu/int16.h
index aca49e17..822d49f9 100644
--- a/src/tensors/cpu/int16.h
+++ b/src/tensors/cpu/int16.h
@@ -1,18 +1,21 @@
#pragma once
#include "graph/node.h"
-#include "tensors/cpu/sharp/sse_gemm.h"
+#include "tensors/cpu/sharp/int_gemm.h"
namespace marian {
namespace cpu {
namespace int16 {
struct QuantizeNodeOp : public UnaryNodeOp {
- QuantizeNodeOp(Expr a) : UnaryNodeOp(a, Type::int16) {}
+ float clipValue_;
+
+ QuantizeNodeOp(Expr a, float clipValue)
+ : UnaryNodeOp(a, Type::int16), clipValue_{clipValue} {}
NodeOps forwardOps() {
return {
- NodeOp(Quantize(val_, child(0)->val()))
+ NodeOp(Quantize16(val_, child(0)->val(), clipValue_))
};
}
@@ -50,7 +53,7 @@ public:
NodeOps forwardOps() {
return {
- NodeOp(ProdInt(val_,
+ NodeOp(ProdInt16(val_,
child(0)->val(),
child(1)->val(),
scalar_))
@@ -93,7 +96,7 @@ public:
NodeOps forwardOps() {
return {
- NodeOp(ProdInt(val_,
+ NodeOp(ProdInt16(val_,
child(0)->val(),
child(1)->val(),
scalar_);
@@ -118,8 +121,8 @@ static inline Expr affine(Expr a, Expr b, Expr c, float scalar) {
return Expression<cpu::int16::AffineNodeOp>(nodes, scalar);
}
-static inline Expr quantize(Expr a) {
- return Expression<cpu::int16::QuantizeNodeOp>(a);
+static inline Expr quantize(Expr a, float clipValue) {
+ return Expression<cpu::int16::QuantizeNodeOp>(a, clipValue);
}
diff --git a/src/tensors/cpu/prod.cpp b/src/tensors/cpu/prod.cpp
index 6728a9b1..c5d86479 100644
--- a/src/tensors/cpu/prod.cpp
+++ b/src/tensors/cpu/prod.cpp
@@ -14,7 +14,7 @@
#endif
#endif
-#include "sharp/sse_gemm.h"
+#include "sharp/int_gemm.h"
namespace marian {
@@ -49,8 +49,8 @@ inline void sgemm(bool transA, bool transB,
#endif
void Prod(marian::Tensor C,
- const marian::Tensor A,
- const marian::Tensor B,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
bool transA,
bool transB,
float beta,
@@ -95,8 +95,8 @@ void Prod(marian::Tensor C,
}
void ProdBatched(marian::Tensor C,
- const marian::Tensor A,
- const marian::Tensor B,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
bool transA,
bool transB,
float beta,
@@ -159,9 +159,9 @@ void ProdBatched(marian::Tensor C,
}
void ProdWithBias(marian::Tensor C,
- const marian::Tensor A,
- const marian::Tensor B,
- const marian::Tensor bias,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
+ const marian::Tensor& bias,
bool transA,
bool transB,
float beta,
diff --git a/src/tensors/cpu/sharp/avx_gemm.cpp b/src/tensors/cpu/sharp/avx_gemm.cpp
new file mode 100644
index 00000000..ae788be6
--- /dev/null
+++ b/src/tensors/cpu/sharp/avx_gemm.cpp
@@ -0,0 +1,554 @@
+#include <cassert>
+#include <cstddef>
+#include <emmintrin.h>
+#include <immintrin.h>
+#include <math.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <tmmintrin.h>
+#include <xmmintrin.h>
+
+#ifdef __AVX512F__
+
+namespace marian {
+namespace cpu {
+namespace int16 {
+
+namespace {
+// Load from memory, multiply, and convert to int32_t.
+inline __m512i QuantizerGrab(const float *input, const __m512 quant_mult_reg) {
+ // Load 16 floats
+ __m512 val = _mm512_load_ps(input);
+ // Multiply each by the quantization factor.
+ val = _mm512_mul_ps(val, quant_mult_reg);
+ // Cast to 32-bit int
+ return _mm512_cvtps_epi32(val);
+}
+} // namespace
+
+// Convert
+void AVX_Quantize16(const float *input, int16_t *output, float quant_mult, std::size_t size) {
+ assert(size % 16 == 0);
+ assert(reinterpret_cast<uintptr_t>(input) % 64 == 0);
+ // Fill with the quantization multiplier.
+ const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult);
+ const float *end = input + size;
+ for (; input != end; input += 16, output += 16) {
+ // There doesn't seem to be an unmasked version.
+ _mm512_mask_cvtsepi32_storeu_epi16(output, 0xffff, QuantizerGrab(input, quant_mult_reg));
+ }
+}
+
+void AVX_Quantize8(const float *input, int8_t *output, float quant_mult, std::size_t size) {
+ assert(size % 16 == 0);
+ assert(reinterpret_cast<uintptr_t>(input) % 64 == 0);
+ const __m512i neg127 = _mm512_set1_epi32(-127);
+ const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult);
+ const float *end = input + size;
+ for (; input < end; input += 16, output += 16) {
+ __m512i asint = QuantizerGrab(input, quant_mult_reg);
+ /* Ban -128. We can't negate it.
+ * The largest possbile product is -128 * -128 = 2^14. If two of those are
+ * summed that's 2^15 which is too large for int16_t. By banning -128 we
+ * can accumulate two in int16_t w/o saturation before going to int32_t.
+ * But this is ok because apparently the instruction will saturate.
+ */
+ asint = _mm512_max_epi32(asint, neg127);
+ // There doesn't seem to be an unmasked version.
+ _mm512_mask_cvtsepi32_storeu_epi8(output, 0xffff, asint);
+ }
+}
+
+namespace {
+
+union FloatAccess {
+ float as_f[4];
+ __m128 as_n;
+};
+union IntAccess {
+ int32_t as_i[4];
+ __m128i as_n;
+};
+
+/* Convert 16-bit to 32-bit and add, not caring what parts are added.
+ * Implementations:
+ * 1. https://github.com/tesseract-ocr/tesseract/blob/master/src/arch/intsimdmatrixavx2.cpp#L67 under Apache license:
+ * This does a multiply by 1 and horizontal add:
+ * _mm512_madd_epi16(sum, _mm512_set1_epi16(1))
+ * Current fastest.
+ *
+ * 2. Signed extension and fold halves:
+ * sum = _mm512_add_epi32(
+ * _mm512_cvtepi16_epi32(_mm512_castsi512_si256(sum)),
+ * _mm512_cvtepi16_epi32(_mm512_extracti64x4_epi64(sum, 1)));
+ *
+ * 3. Sign extend by abuse of bitshift, then add.
+ * __m128i shift16 = _mm_set_epi32(0,0,0,16);
+ * sum = _mm512_add_epi32(
+ * _mm512_sra_epi32(_mm512_sll_epi32(sum, shift16), shift16),
+ * _mm512_sra_epi32(sum, shift16));
+ */
+inline void Convert32Sum(__m512i &sum) {
+ sum = _mm512_madd_epi16(sum, _mm512_set1_epi16(1));
+}
+
+// Two sum version.
+struct ReducedPair {
+ int32_t result[2];
+};
+inline ReducedPair Reduce16to32(__m512i sum1, __m512i sum2) {
+ Convert32Sum(sum1);
+ Convert32Sum(sum2);
+ // 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2
+ __m512i pack12 = _mm512_add_epi32(_mm512_unpackhi_epi32(sum1, sum2), _mm512_unpacklo_epi32(sum1, sum2));
+ // 1 2 1 2 1 2 1 2
+ __m256i halves = _mm256_add_epi32(_mm512_castsi512_si256(pack12), _mm512_extracti64x4_epi64(pack12, 1));
+ // 1 2 1 2
+ IntAccess a;
+ a.as_n = _mm_add_epi32(_mm256_castsi256_si128(halves), _mm256_extracti128_si256(halves, 1));
+ ReducedPair ret;
+ ret.result[0] = a.as_i[0] + a.as_i[2];
+ ret.result[1] = a.as_i[1] + a.as_i[3];
+ return ret;
+}
+
+// Assuming sum1, sum2, sum3, and sum4 are arrays 32-bit signed integers,
+// reduce within each.
+// Returns [sum(sum1), sum(sum2), sum(sum3), sum(sum4)]
+// TODO: consider doing in 64-bit, allowing 4 more bits of quantization?
+inline __m128i Reduce32(__m512i sum1, __m512i sum2, __m512i sum3, __m512i sum4) {
+ // 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2
+ __m512i pack12 = _mm512_add_epi32(_mm512_unpackhi_epi32(sum1, sum2), _mm512_unpacklo_epi32(sum1, sum2));
+ // 3 4 3 4 3 4 3 4 3 4 3 4 3 4 3 4
+ __m512i pack34 = _mm512_add_epi32(_mm512_unpackhi_epi32(sum3, sum4), _mm512_unpacklo_epi32(sum3, sum4));
+ // 1 2 3 4 1 2 3 4 1 2 3 4 1 2 3 4
+ __m512i pack1234 = _mm512_add_epi32(_mm512_unpackhi_epi64(pack12, pack34), _mm512_unpacklo_epi64(pack12, pack34));
+ // Cut the register into halves and sum those. 1 2 3 4 1 2 3 4
+ __m256i halves = _mm256_add_epi32(_mm512_castsi512_si256(pack1234), _mm512_extracti64x4_epi64(pack1234, 1));
+ // Again: cut the register into halves and sum those. 1 2 3 4
+ return _mm_add_epi32(_mm256_castsi256_si128(halves), _mm256_extracti128_si256(halves, 1));
+}
+
+// Four sum version
+inline __m128i Reduce16to32(__m512i sum1, __m512i sum2, __m512i sum3, __m512i sum4) {
+ Convert32Sum(sum1);
+ Convert32Sum(sum2);
+ Convert32Sum(sum3);
+ Convert32Sum(sum4);
+ return Reduce32(sum1, sum2, sum3, sum4);
+}
+
+// Somewhat inefficient reduce for single __m256i containing int32_t
+inline int32_t Reduce32(__m256i halves) {
+ IntAccess a;
+ a.as_n = _mm_add_epi32(_mm256_castsi256_si128(halves), _mm256_extracti128_si256(halves, 1));
+ // TODO is there a more efficient way?
+ return a.as_i[0] + a.as_i[1] + a.as_i[2] + a.as_i[3];
+}
+
+// Somewhat inefficient reduce for single __m512i containing int32_t
+inline int32_t Reduce32(__m512i sum1) {
+ // Fold register over itself.
+ return Reduce32(_mm256_add_epi32(_mm512_castsi512_si256(sum1), _mm512_extracti64x4_epi64(sum1, 1)));
+}
+
+inline int32_t Reduce16to32(__m512i sum1) {
+ Convert32Sum(sum1);
+ // Fold register over itself.
+ return Reduce32(_mm256_add_epi32(_mm512_castsi512_si256(sum1), _mm512_extracti64x4_epi64(sum1, 1)));
+}
+
+class ScatterPut {
+ public:
+ explicit ScatterPut(float unquant_mult, int num_B_rows)
+ : unquant_mult_(unquant_mult),
+ unquant_mult_sse_(_mm_set1_ps(unquant_mult)),
+#ifdef __AVX512VL__
+ num_b_rows_scatter_(_mm_set_epi32(num_B_rows * 3 * sizeof(float), num_B_rows * 2 * sizeof(float), num_B_rows * 1 * sizeof(float), num_B_rows * 0 * sizeof(float))),
+#endif
+ num_B_rows_(num_B_rows) {}
+
+ inline void Write(float *base, __m128i reduced) {
+ __m128 float_sums = _mm_cvtepi32_ps(reduced);
+ float_sums = _mm_mul_ps(float_sums, unquant_mult_sse_);
+#ifdef __AVX512VL__
+ // The scatter instruction requires avx512vl
+ _mm_i32scatter_ps(base, num_b_rows_scatter_, float_sums, 1);
+#else
+ FloatAccess a;
+ // Get floats for each of the sums to write.
+ a.as_n = float_sums;
+ // Also note that the memory acceses on C are not consecutive, but this is a tradeoff that we have to make.
+ // We can't have consecutive accesses of A, B, *and* C. But we access A and B a lot more so it makes
+ // sense to do it this way.
+ // Scatter to outputs:
+ base[0] = a.as_f[0];
+ base[num_B_rows_] = a.as_f[1];
+ base[2*num_B_rows_] = a.as_f[2];
+ base[3*num_B_rows_] = a.as_f[3];
+#endif
+ }
+
+ inline void Write(float *base, ReducedPair reduced) {
+ base[0] = unquant_mult_ * static_cast<float>(reduced.result[0]);
+ base[num_B_rows_] = unquant_mult_ * static_cast<float>(reduced.result[1]);
+ }
+
+ inline void Write(float *base, int32_t reduced) {
+ base[0] = unquant_mult_ * static_cast<float>(reduced);
+ }
+
+ private:
+ const float unquant_mult_;
+ const __m128 unquant_mult_sse_;
+#ifdef __AVX512VL__
+ const __m128i num_b_rows_scatter_;
+#endif
+ const int num_B_rows_;
+};
+
+} // namespace
+
+
+// This is an AVX512F implementation of int16_t multiply based on Jacob
+// Devlin's SSE code. The original SSE code was:
+
+// Copyright (c) 2017 Microsoft Corporation
+
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+
+// We are multiplying A * B^T, as opposed to A * B. This is important because it means we can do consecutive memory access on A * B^T which allows to to take the most
+// advantage of L1 cache.
+//
+// B is typically a weight matrix, so it can be pre-processed offline, and therefore this transpose does not cost anything.
+// A is typically an activation minibatch matrix.
+// A and B must be 64-byte aligned.
+// C should be the usual 4-byte alignment.
+void AVX_MatrixMult16(const __m512i * A, const __m512i * B, float * C, float unquant_mult, int num_A_rows, int num_B_rows, int width) {
+ assert(width % 32 == 0);
+ assert(reinterpret_cast<uintptr_t>(A) % 64 == 0);
+ assert(reinterpret_cast<uintptr_t>(B) % 64 == 0);
+
+ ScatterPut put(unquant_mult, num_B_rows);
+
+ const int sse_width = width/32;
+
+ // We do loop unrolling over A. This is *significantly* faster
+ // since B can live in the registers. We are assuming that
+ // A is a multiple of 4, but we can add extra code to handle values of 1, 2, 3.
+ //
+ // We could also do loop unrolling over B, which adds some additional speedup.
+ // We don't do that for the sake of clarity.
+ //
+ // There are other memory access patterns we could do, e.g., put B on the outer loop.
+ // The justification is that A is typically small enough that it can live in L1 cache.
+ // B is usually a larger weight matrix, so it might not be able to. However, we are using
+ // each element of B four times while it's still in a register, so caching is not as important.
+
+ // Round down to a multiple of 4.
+ int num_unroll_rows = num_A_rows & ~3;
+ for (int i = 0; i < num_unroll_rows; i += 4) {
+ const __m512i * A1_row = A + (i+0)*sse_width;
+ const __m512i * A2_row = A + (i+1)*sse_width;
+ const __m512i * A3_row = A + (i+2)*sse_width;
+ const __m512i * A4_row = A + (i+3)*sse_width;
+
+ for (int j = 0; j < num_B_rows; j++) {
+ const __m512i * B_row = B + j*sse_width;
+
+ __m512i sum1 = _mm512_setzero_si512();
+ __m512i sum2 = _mm512_setzero_si512();
+ __m512i sum3 = _mm512_setzero_si512();
+ __m512i sum4 = _mm512_setzero_si512();
+
+ // This is just a simple dot product, unrolled four ways.
+ for (int k = 0; k < sse_width; k++) {
+ __m512i b = *(B_row + k);
+
+ __m512i a1 = *(A1_row + k);
+ __m512i a2 = *(A2_row + k);
+ __m512i a3 = *(A3_row + k);
+ __m512i a4 = *(A4_row + k);
+
+ // madd_epi16 does multiply add on 8 16-bit integers and accumulates into a four 32-bit register.
+ // E.g.,
+ // a1 = [f1, f2, f3, f4, f5, f6, f7, h8] (16-bit ints)
+ // b1 = [h1, h2, h3, h4, h5, h6, h7, h8] (16-bit ints)
+ // result = [f1*h1 + f2*h2, f3*h3 + f4*h4, f5*h5 + f6*h6, f7*h7 + f8*h8] (32-bit ints)
+ // Then add_epi32 just effectively does a += on these 32-bit integers.
+ sum1 = _mm512_add_epi32(sum1, _mm512_madd_epi16(b, a1));
+ sum2 = _mm512_add_epi32(sum2, _mm512_madd_epi16(b, a2));
+ sum3 = _mm512_add_epi32(sum3, _mm512_madd_epi16(b, a3));
+ sum4 = _mm512_add_epi32(sum4, _mm512_madd_epi16(b, a4));
+ }
+ put.Write(C + i * num_B_rows + j, Reduce32(sum1, sum2, sum3, sum4));
+ }
+ }
+ // Handle the non-multiples of 4 rows.
+ // TODO: efficient version for 3 rows, 2 rows, etc.
+ for (int i = num_unroll_rows; i < num_A_rows; ++i) {
+ const __m512i * A1_row = A + i * sse_width;
+ for (int j = 0; j < num_B_rows; j++) {
+ __m512i sum1 = _mm512_setzero_si512();
+ for (int k = 0; k < sse_width; k++) {
+ const __m512i * B_row = B + j*sse_width;
+ __m512i b = *(B_row + k);
+ __m512i a1 = *(A1_row + k);
+ sum1 = _mm512_add_epi32(sum1, _mm512_madd_epi16(b, a1));
+ }
+ // TODO is there a more efficient way?
+ *(C + (i)*num_B_rows + j) = unquant_mult * static_cast<float>(Reduce32(sum1));
+ }
+ }
+}
+
+namespace {
+
+/* Three ways considered to apply sign bits:
+ * 1. Use 256-bit sign instruction:
+ * __m256i a_first = _mm256_sign_epi8(_mm512_castsi512_si256(a), _mm512_castsi512_si256(b));
+ * __m256i a_second = _mm256_sign_epi8(_mm512_extracti64x4_epi64(a, 1), b_second);
+ * a = _mm512_inserti64x4(_mm512_castsi256_si512(a_first), a_second, 1);
+ * a = Concat(a_first, a_second);
+ *
+ * 2. Extract a mask and xor + 1
+ * __mmask64 neg_mask _mm512_test_epi8_mask(b, _mm512_set1_epi8(-128));
+ * Use set1 to to build to_xor
+ * a = _mm512_xor_si512(a, to_xor)
+ * And add one:
+ * const __m512i ones8 = _mm512_set1_epi8(1);
+ * a = _mm512_mask_add_epi8(a, neg_mask, a, ones8);
+ *
+ * 3. Extract a mask and subtract from 0
+ * In the outer loop on b:
+ * __mmask64 neg_mask _mm512_test_epi8_mask(b, _mm512_set1_epi8(-128))
+ * For each a:
+ * a = _mm512_mask_sub_epi8(a, neg_mask, _mm512_setzero_si512(), a);
+ *
+ * Finally, subtraction won the benchmark
+ */
+inline void Accum(const __m512i zeros, __m512i a, const __m512i b, const __m512i b_positive, const __mmask64 neg_mask, __m512i &sum) {
+ // Apply sign bits.
+ a = _mm512_mask_sub_epi8(a, neg_mask, zeros, a);
+ // The magic 8-bit multiply then horizontal sum into 16-bit.
+ __m512i multiplied = _mm512_maddubs_epi16(b_positive, a);
+ // Now we have 16-bit results that are the sum of two multiplies.
+ // Choosing to approximate and do adds.
+ // Perhaps every so often we could accumulate by Convert32Sum
+ sum = _mm512_adds_epi16(sum, multiplied);
+}
+
+} // namespace
+
+void AVX_MatrixMult8(const __m512i * A, const __m512i * B, float * C, float unquant_mult, int num_A_rows, int num_B_rows, int width) {
+ assert(width % 32 == 0);
+ assert(reinterpret_cast<uintptr_t>(A) % 64 == 0);
+ assert(reinterpret_cast<uintptr_t>(B) % 64 == 0);
+ ScatterPut put(unquant_mult, num_B_rows);
+ const __m512i zeros = _mm512_setzero_si512();
+
+ const int sse_width = width/64;
+ int i = 0;
+ int mult8rows = num_A_rows & (~7);
+
+ for (; i < mult8rows; i += 8) {
+ const __m512i *A1_row = A + (i+0)*sse_width;
+ const __m512i *A2_row = A + (i+1)*sse_width;
+ const __m512i *A3_row = A + (i+2)*sse_width;
+ const __m512i *A4_row = A + (i+3)*sse_width;
+ const __m512i *A5_row = A + (i+4)*sse_width;
+ const __m512i *A6_row = A + (i+5)*sse_width;
+ const __m512i *A7_row = A + (i+6)*sse_width;
+ const __m512i *A8_row = A + (i+7)*sse_width;
+ for (int j = 0; j < num_B_rows; j++) {
+ const __m512i *B_row = B + j*sse_width;
+ __m512i sum1 = _mm512_setzero_si512();
+ __m512i sum2 = _mm512_setzero_si512();
+ __m512i sum3 = _mm512_setzero_si512();
+ __m512i sum4 = _mm512_setzero_si512();
+ __m512i sum5 = _mm512_setzero_si512();
+ __m512i sum6 = _mm512_setzero_si512();
+ __m512i sum7 = _mm512_setzero_si512();
+ __m512i sum8 = _mm512_setzero_si512();
+ for (int k = 0; k < sse_width; k++) {
+ __m512i b = *(B_row + k);
+ __m512i b_positive = _mm512_abs_epi8(b);
+ /* Didn't seem to make a difference definining sign bits here vs at top */
+ __mmask64 neg_mask = _mm512_test_epi8_mask(b, _mm512_set1_epi8(-128));
+ Accum(zeros, *(A1_row + k), b, b_positive, neg_mask, sum1);
+ Accum(zeros, *(A2_row + k), b, b_positive, neg_mask, sum2);
+ Accum(zeros, *(A3_row + k), b, b_positive, neg_mask, sum3);
+ Accum(zeros, *(A4_row + k), b, b_positive, neg_mask, sum4);
+ Accum(zeros, *(A5_row + k), b, b_positive, neg_mask, sum5);
+ Accum(zeros, *(A6_row + k), b, b_positive, neg_mask, sum6);
+ Accum(zeros, *(A7_row + k), b, b_positive, neg_mask, sum7);
+ Accum(zeros, *(A8_row + k), b, b_positive, neg_mask, sum8);
+ }
+ put.Write(C + i *num_B_rows + j, Reduce16to32(sum1, sum2, sum3, sum4));
+ put.Write(C + (i+4) *num_B_rows + j, Reduce16to32(sum5, sum6, sum7, sum8));
+ }
+ }
+
+ const __m512i *A1_row = A + (i+0)*sse_width;
+ const __m512i *A2_row = A + (i+1)*sse_width;
+ const __m512i *A3_row = A + (i+2)*sse_width;
+ const __m512i *A4_row = A + (i+3)*sse_width;
+ const __m512i *A5_row = A + (i+4)*sse_width;
+ const __m512i *A6_row = A + (i+5)*sse_width;
+ const __m512i *A7_row = A + (i+6)*sse_width;
+ switch (num_A_rows & 7) {
+ case 7:
+ for (int j = 0; j < num_B_rows; j++) {
+ const __m512i *B_row = B + j*sse_width;
+ __m512i sum1 = _mm512_setzero_si512();
+ __m512i sum2 = _mm512_setzero_si512();
+ __m512i sum3 = _mm512_setzero_si512();
+ __m512i sum4 = _mm512_setzero_si512();
+ __m512i sum5 = _mm512_setzero_si512();
+ __m512i sum6 = _mm512_setzero_si512();
+ __m512i sum7 = _mm512_setzero_si512();
+ for (int k = 0; k < sse_width; k++) {
+ __m512i b = *(B_row + k);
+ __m512i b_positive = _mm512_abs_epi8(b);
+ __mmask64 neg_mask = _mm512_test_epi8_mask(b, _mm512_set1_epi8(-128));
+ Accum(zeros, *(A1_row + k), b, b_positive, neg_mask, sum1);
+ Accum(zeros, *(A2_row + k), b, b_positive, neg_mask, sum2);
+ Accum(zeros, *(A3_row + k), b, b_positive, neg_mask, sum3);
+ Accum(zeros, *(A4_row + k), b, b_positive, neg_mask, sum4);
+ Accum(zeros, *(A5_row + k), b, b_positive, neg_mask, sum5);
+ Accum(zeros, *(A6_row + k), b, b_positive, neg_mask, sum6);
+ Accum(zeros, *(A7_row + k), b, b_positive, neg_mask, sum7);
+ }
+ put.Write(C + i *num_B_rows + j, Reduce16to32(sum1, sum2, sum3, sum4));
+ put.Write(C + (i+4) *num_B_rows + j, Reduce16to32(sum5, sum6));
+ put.Write(C + (i+6) *num_B_rows + j, Reduce16to32(sum7));
+ }
+ case 6:
+ for (int j = 0; j < num_B_rows; j++) {
+ const __m512i *B_row = B + j*sse_width;
+ __m512i sum1 = _mm512_setzero_si512();
+ __m512i sum2 = _mm512_setzero_si512();
+ __m512i sum3 = _mm512_setzero_si512();
+ __m512i sum4 = _mm512_setzero_si512();
+ __m512i sum5 = _mm512_setzero_si512();
+ __m512i sum6 = _mm512_setzero_si512();
+ for (int k = 0; k < sse_width; k++) {
+ __m512i b = *(B_row + k);
+ __m512i b_positive = _mm512_abs_epi8(b);
+ __mmask64 neg_mask = _mm512_test_epi8_mask(b, _mm512_set1_epi8(-128));
+ Accum(zeros, *(A1_row + k), b, b_positive, neg_mask, sum1);
+ Accum(zeros, *(A2_row + k), b, b_positive, neg_mask, sum2);
+ Accum(zeros, *(A3_row + k), b, b_positive, neg_mask, sum3);
+ Accum(zeros, *(A4_row + k), b, b_positive, neg_mask, sum4);
+ Accum(zeros, *(A5_row + k), b, b_positive, neg_mask, sum5);
+ Accum(zeros, *(A6_row + k), b, b_positive, neg_mask, sum6);
+ }
+ put.Write(C + i *num_B_rows + j, Reduce16to32(sum1, sum2, sum3, sum4));
+ put.Write(C + (i+4) *num_B_rows + j, Reduce16to32(sum5, sum6));
+ }
+ case 5:
+ for (int j = 0; j < num_B_rows; j++) {
+ const __m512i *B_row = B + j*sse_width;
+ __m512i sum1 = _mm512_setzero_si512();
+ __m512i sum2 = _mm512_setzero_si512();
+ __m512i sum3 = _mm512_setzero_si512();
+ __m512i sum4 = _mm512_setzero_si512();
+ __m512i sum5 = _mm512_setzero_si512();
+ for (int k = 0; k < sse_width; k++) {
+ __m512i b = *(B_row + k);
+ __m512i b_positive = _mm512_abs_epi8(b);
+ __mmask64 neg_mask = _mm512_test_epi8_mask(b, _mm512_set1_epi8(-128));
+ Accum(zeros, *(A1_row + k), b, b_positive, neg_mask, sum1);
+ Accum(zeros, *(A2_row + k), b, b_positive, neg_mask, sum2);
+ Accum(zeros, *(A3_row + k), b, b_positive, neg_mask, sum3);
+ Accum(zeros, *(A4_row + k), b, b_positive, neg_mask, sum4);
+ Accum(zeros, *(A5_row + k), b, b_positive, neg_mask, sum5);
+ }
+ put.Write(C + i *num_B_rows + j, Reduce16to32(sum1, sum2, sum3, sum4));
+ put.Write(C + (i+4) *num_B_rows + j, Reduce16to32(sum5));
+ }
+ case 4:
+ for (int j = 0; j < num_B_rows; j++) {
+ const __m512i *B_row = B + j*sse_width;
+ __m512i sum1 = _mm512_setzero_si512();
+ __m512i sum2 = _mm512_setzero_si512();
+ __m512i sum3 = _mm512_setzero_si512();
+ __m512i sum4 = _mm512_setzero_si512();
+ for (int k = 0; k < sse_width; k++) {
+ __m512i b = *(B_row + k);
+ __m512i b_positive = _mm512_abs_epi8(b);
+ __mmask64 neg_mask = _mm512_test_epi8_mask(b, _mm512_set1_epi8(-128));
+ Accum(zeros, *(A1_row + k), b, b_positive, neg_mask, sum1);
+ Accum(zeros, *(A2_row + k), b, b_positive, neg_mask, sum2);
+ Accum(zeros, *(A3_row + k), b, b_positive, neg_mask, sum3);
+ Accum(zeros, *(A4_row + k), b, b_positive, neg_mask, sum4);
+ }
+ put.Write(C + i *num_B_rows + j, Reduce16to32(sum1, sum2, sum3, sum4));
+ }
+ case 3:
+ for (int j = 0; j < num_B_rows; j++) {
+ const __m512i *B_row = B + j*sse_width;
+ __m512i sum1 = _mm512_setzero_si512();
+ __m512i sum2 = _mm512_setzero_si512();
+ __m512i sum3 = _mm512_setzero_si512();
+ for (int k = 0; k < sse_width; k++) {
+ __m512i b = *(B_row + k);
+ __m512i b_positive = _mm512_abs_epi8(b);
+ __mmask64 neg_mask = _mm512_test_epi8_mask(b, _mm512_set1_epi8(-128));
+ Accum(zeros, *(A1_row + k), b, b_positive, neg_mask, sum1);
+ Accum(zeros, *(A2_row + k), b, b_positive, neg_mask, sum2);
+ Accum(zeros, *(A3_row + k), b, b_positive, neg_mask, sum3);
+ }
+ put.Write(C + i *num_B_rows + j, Reduce16to32(sum1, sum2));
+ put.Write(C + (i+2) *num_B_rows + j, Reduce16to32(sum3));
+ }
+ case 2:
+ for (int j = 0; j < num_B_rows; j++) {
+ const __m512i *B_row = B + j*sse_width;
+ __m512i sum1 = _mm512_setzero_si512();
+ __m512i sum2 = _mm512_setzero_si512();
+ for (int k = 0; k < sse_width; k++) {
+ __m512i b = *(B_row + k);
+ __m512i b_positive = _mm512_abs_epi8(b);
+ __mmask64 neg_mask = _mm512_test_epi8_mask(b, _mm512_set1_epi8(-128));
+ Accum(zeros, *(A1_row + k), b, b_positive, neg_mask, sum1);
+ Accum(zeros, *(A2_row + k), b, b_positive, neg_mask, sum2);
+ }
+ put.Write(C + i *num_B_rows + j, Reduce16to32(sum1, sum2));
+ }
+ case 1:
+ for (int j = 0; j < num_B_rows; j++) {
+ const __m512i *B_row = B + j*sse_width;
+ __m512i sum1 = _mm512_setzero_si512();
+ for (int k = 0; k < sse_width; k++) {
+ __m512i b = *(B_row + k);
+ __m512i b_positive = _mm512_abs_epi8(b);
+ __mmask64 neg_mask = _mm512_test_epi8_mask(b, _mm512_set1_epi8(-128));
+ Accum(zeros, *(A1_row + k), b, b_positive, neg_mask, sum1);
+ }
+ put.Write(C + i *num_B_rows + j, Reduce16to32(sum1));
+ }
+ }
+}
+
+}}} // namespaces
+#endif
diff --git a/src/tensors/cpu/sharp/int_gemm.h b/src/tensors/cpu/sharp/int_gemm.h
new file mode 100644
index 00000000..908c7cd6
--- /dev/null
+++ b/src/tensors/cpu/sharp/int_gemm.h
@@ -0,0 +1,137 @@
+#pragma once
+
+#include "tensors/tensor.h"
+#include "tensors/tensor_allocator.h"
+#include "tensors/tensor_operators.h"
+
+#include <cassert>
+#include <cstddef>
+#include <emmintrin.h>
+#include <immintrin.h>
+#include <tmmintrin.h>
+#include <xmmintrin.h>
+
+namespace marian {
+namespace cpu {
+namespace int16 {
+
+const int BITS = 10;
+
+#ifdef __AVX512F__
+void AVX_Quantize16(const float *input, int16_t *output, float quant_mult, std::size_t size);
+void AVX_Quantize8(const float *input, int8_t *output, float quant_mult, std::size_t size);
+void AVX_MatrixMult16(const __m512i * A, const __m512i * B, float * C, float unquant_mult, int num_A_rows, int num_B_rows, int width);
+void AVX_MatrixMult8(const __m512i * A, const __m512i * B, float * C, float unquant_mult, int num_A_rows, int num_B_rows, int width);
+#endif
+void SSE_Quantize16(const float * input, __m128i * output, float quant_mult, int num_rows, int width);
+void SSE_MatrixMult16(const __m128i * A, const __m128i * B, float * C, float unquant_mult, int num_A_rows, int num_B_rows, int width);
+
+static inline void Quantize16(marian::Tensor out,
+ const marian::Tensor in,
+ float clipValue) {
+
+ float quant_mult = pow(2.0, (float)BITS);
+#ifdef __AVX512F__
+ AVX_Quantize16(in->data(), out->data<int16_t>(), quant_mult, in->shape().elements());
+#else
+ int num_rows = in->shape().elements() / in->shape()[-1];
+ int width = in->shape()[-1];
+ SSE_Quantize16(in->data(), out->data<__m128i>(), quant_mult, num_rows, width);
+#endif
+}
+
+static inline void Quantize8(marian::Tensor out,
+ const marian::Tensor in,
+ float clipValue) {
+#ifdef __AVX512F__
+ float quant_mult = 127.0 / clipValue;
+ AVX_Quantize8(in->data(), out->data<int8_t>(), quant_mult, in->shape().elements());
+#else
+ ABORT("8-bit is currently only AVX512");
+#endif
+}
+
+// This operates on floats after processing so doesn't care about int8_t vs int16_t.
+static void AddBias(marian::Tensor C, const marian::Tensor Bias) {
+ float* y = C->data();
+ const float* x = C->data();
+ const float* bias = Bias->data();
+
+ int m = C->shape().elements() / C->shape()[-1];
+ int n = C->shape()[-1];
+#ifdef __AVX512F__
+ int n16 = n & ~15;
+#else
+ int n4 = (n / 4) * 4;
+#endif
+
+ for(int j = 0; j < m; ++j) {
+ int i = 0;
+#ifdef __AVX512F__
+ for (; i < n16; i += 16) {
+ __m512 ai = _mm512_loadu_ps(x + j * n + i);
+ __m512 bi = _mm512_loadu_ps(bias + i);
+ __m512 yi = _mm512_add_ps(ai, bi);
+ _mm512_storeu_ps(y + j * n + i, yi);
+ }
+#else
+ for (; i < n4; i += 4) {
+ __m128 ai = _mm_loadu_ps(x + j * n + i);
+ __m128 bi = _mm_loadu_ps(bias + i);
+ __m128 yi = _mm_add_ps(ai, bi);
+ _mm_storeu_ps(y + j * n + i, yi);
+ }
+#endif
+ for (; i < n; i++) {
+ y[j * n + i] = x[j * n + i] + bias[i];
+ }
+ }
+}
+
+static void ProdInt16(marian::Tensor C,
+ const marian::Tensor A,
+ const marian::Tensor B,
+ float scale) {
+ ABORT_IF(scale != 1, "Scale other than 1 not supported");
+
+ // @TODO: make this a parameter
+ float quant_mult = pow(2.0, (float)BITS);
+
+ // If we quantize to n bits and then multiple the values together, the result will be quantized to n^2 bits.
+ // So we must divide by 1.0/(n^2) to get back the original value.
+ float unquant_mult = 1.0 / (quant_mult * quant_mult);
+
+ float* fC = C->data();
+ int num_A_rows = A->shape().elements() / A->shape()[-1];
+ int num_B_rows = B->shape().elements() / B->shape()[-1];
+ int width = B->shape()[-1];
+#ifdef __AVX512F__
+ AVX_MatrixMult16(A->data<__m512i>(), B->data<__m512i>(), fC, unquant_mult, num_A_rows, num_B_rows, width);
+#else
+ SSE_MatrixMult16(A->data<__m128i>(), B->data<__m128i>(), fC, unquant_mult, num_A_rows, num_B_rows, width);
+#endif
+}
+
+static void ProdInt8(marian::Tensor C,
+ const marian::Tensor A,
+ const marian::Tensor B,
+ float scale,
+ float clipValue) {
+#ifdef __AVX512F__
+ // This would be easy...
+ ABORT_IF(scale != 1, "Scale other than 1 not supported");
+ float quant_mult = 127.0 / clipValue;
+ float unquant_mult = 1.0 / (quant_mult * quant_mult);
+
+ float* fC = C->data();
+ int num_A_rows = A->shape().elements() / A->shape()[-1];
+ int num_B_rows = B->shape().elements() / B->shape()[-1];
+ int width = B->shape()[-1];
+ AVX_MatrixMult8(A->data<__m512i>(), B->data<__m512i>(), fC, unquant_mult, num_A_rows, num_B_rows, width);
+#else
+ ABORT("8-bit is currently only AVX512");
+#endif
+
+}
+
+}}} // namespaces
diff --git a/src/tensors/cpu/sharp/sse_gemm.h b/src/tensors/cpu/sharp/sse_gemm.cpp
index 542c3f6d..bbe2ace4 100755..100644
--- a/src/tensors/cpu/sharp/sse_gemm.h
+++ b/src/tensors/cpu/sharp/sse_gemm.cpp
@@ -18,12 +18,6 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
-#pragma once
-
-#include "tensors/tensor.h"
-#include "tensors/tensor_allocator.h"
-#include "tensors/tensor_operators.h"
-
#include <cassert>
#include <emmintrin.h>
#include <immintrin.h>
@@ -35,6 +29,10 @@
#include <tmmintrin.h>
#include <xmmintrin.h>
+namespace marian {
+namespace cpu {
+namespace int16 {
+
// This is a reference implementation of 16-bit matrix multiplication described in "Sharp Models on Dull Hardware: Fast and Accurate Neural Machine Translation Decoding on the CPU".
// This model is not as fast as the one in the paper, becuase it uses SSE2 instead of AVX2. AVX2 instructions are only available on more modern CPUs (Haswell or later).
// The only difference between SSE2 and AVX2 is that SSE operates on 128-bit vectors and AVX2 operates on 256-bit vecetors. So AVX2 can fit 16 16-bit integers intead of 8 8-bit integers.
@@ -52,60 +50,44 @@
// *impossible* to overflow. If we used, say, n = 12 bits, then we have 32-(12*2) = 8 bits left over. So we *could* overflow if width > 2^8.
//
// So, the tradeoff is between quantization precision and possibility of overflow. A good general value is 10 bits, since this gives high precision
-// (precision is 1/2^10 ~= 0.001, which is more than what's needed for almost all neural nets), and cannot overflow unless the matrix width is > 4096.
+// (precision is 1/2^10 ~= 0.001, which is more than what's needed for almost all neural nets), and cannot overflow unless the matrix width is > 4096.
// This quantizes floating point values into fixed-point 16-bit integers. Effectively, we are performing an SSE version of
// float x = ...;
// int16_t y = (int16_t)(quant_mult*x);
-//
+//
// Except that the casting is saturated. However, you should always ensure that the input fits into a fixed range anyways.
// I.e., you should ensure that quant_mult*x fits into the range [-2^15, 2^15].
// This should always be possible because the value you're quantizing will either be NN weights or NN activations, both of
// which can be clipped to a fixed range during training.
-namespace marian {
-
-namespace cpu {
-namespace int16 {
-
-const int BITS = 10;
-
-static inline void Quantize(marian::Tensor out,
- const marian::Tensor in) {
-
- int num_rows = in->shape().elements() / in->shape()[-1];
- int width = in->shape()[-1];
- ABORT_IF(width % 8 != 0, "Width {} is not divisble by 8", width);
-
- const float* input = in->data();
- __m128i* output = out->data<__m128i>();
-
- float quant_mult = pow(2.0, (float)BITS);
-
- int num_input_chunks = width / 8;
-
+void SSE_Quantize16(const float * input, __m128i * output, float quant_mult, int num_rows, int width) {
+ assert(width % 8 == 0);
+
+ int num_input_chunks = width/8;
+
// Fill an SSE float with 4 copies of the quant mult
__m128 sse_quant_mult = _mm_set_ps(quant_mult, quant_mult, quant_mult, quant_mult);
-
+
for (int i = 0; i < num_rows; i++) {
- const float* input_row = input + i * width;
- __m128i* output_row = output + i * num_input_chunks;
+ const float * input_row = input + i*width;
+ __m128i * output_row = output + i*num_input_chunks;
for (int j = 0; j < num_input_chunks; j++) {
- const float* x = input_row + j * 8;
+ const float * x = input_row + j*8;
// Process 8 floats at once, since each __m128i can contain 8 16-bit integers.
-
- // Load floats into SSE registers.
+
+ // Load floats floats into SSE registers.
__m128 f_0 = _mm_loadu_ps(x);
__m128 f_1 = _mm_loadu_ps(x + 4);
-
+
// Multiply by quantization factor (e.g., if quant_mult = 1000.0, 0.34291 --> 342.21)
__m128 m_0 = _mm_mul_ps(f_0, sse_quant_mult);
__m128 m_1 = _mm_mul_ps(f_1, sse_quant_mult);
-
+
// Cast float to 32-bit int (e.g., 342.21 --> 342)
__m128i i_0 = _mm_cvtps_epi32(m_0);
__m128i i_1 = _mm_cvtps_epi32(m_1);
-
+
// Cast 32-bit int to 16-bit int. You must ensure that these fit into the 16-bit range
// by clipping values during training.
*(output_row + j) = _mm_packs_epi32(i_0, i_1);
@@ -113,27 +95,14 @@ static inline void Quantize(marian::Tensor out,
}
}
-
// We are multiplying A * B^T, as opposed to A * B. This is important because it means we can do consecutive memory access on A * B^T which allows to to take the most
// advantage of L1 cache.
-//
+//
// B is typically a weight matrix, so it can be pre-processed offline, and therefore this transpose does not cost anything.
// A is typically an activation minibatch matrix.
-static inline void SSE_MatrixMult(marian::Tensor C,
- const marian::Tensor A,
- const marian::Tensor B,
- float unquant_mult,
- float scale)
+void SSE_MatrixMult16(const __m128i * qA, const __m128i * qB, float * fC, float unquant_mult, int num_A_rows, int num_B_rows, int width)
{
- const __m128i* qA = A->data<__m128i>();
- const __m128i* qB = B->data<__m128i>();
- float* fC = C->data();
-
- int num_A_rows = A->shape().elements() / A->shape()[-1];
- int num_B_rows = B->shape().elements() / B->shape()[-1];
- int width = B->shape()[-1];
-
- ABORT_IF(width % 8 != 0, "Width {} is not divisble by 8", width);
+ assert(width % 8 == 0);
int sse_width = width / 8;
@@ -339,46 +308,4 @@ static inline void SSE_MatrixMult(marian::Tensor C,
}
}
-static void AddBias(marian::Tensor C, const marian::Tensor Bias) {
- float* y = C->data();
- const float* x = C->data();
- const float* bias = Bias->data();
-
- int m = C->shape().elements() / C->shape()[-1];
- int n = C->shape()[-1];
- int n4 = (n / 4) * 4;
-
- for(int j = 0; j < m; ++j) {
- for (int i = 0; i < n4; i += 4) {
- __m128 ai = _mm_loadu_ps(x + j * n + i);
- __m128 bi = _mm_loadu_ps(bias + i);
- __m128 yi = _mm_add_ps(ai, bi);
- _mm_storeu_ps(y + j * n + i, yi);
- }
- for (int i = n4; i < n; i++) {
- y[j * n + i] = x[j * n + i] + bias[i];
- }
- }
-}
-
-static void ProdInt(marian::Tensor C,
- const marian::Tensor A,
- const marian::Tensor B,
- float scale) {
-
- ABORT_IF(scale != 1, "Scale other than 1 not supported");
-
- // @TODO: make this a parameter
- float quant_mult = pow(2.0, (float)BITS);
-
- // If we quantize to n bits and then multiple the values together, the result will be quantized to n^2 bits.
- // So we must divide by 1.0/(n^2) to get back the original value.
- float unquant_mult = 1.0 / (quant_mult * quant_mult);
-
- SSE_MatrixMult(C, A, B, unquant_mult, scale);
-}
-
-}
-}
-
-}
+}}} // namespaces
diff --git a/src/tensors/gpu/add.inc b/src/tensors/gpu/add.inc
index e15ba7b3..9668a0b3 100644
--- a/src/tensors/gpu/add.inc
+++ b/src/tensors/gpu/add.inc
@@ -14,4 +14,4 @@ template void Add<BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Div
template void Add<BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::sPReLUBack, Assignee<2>, Capture>>, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::sPReLUBack, Assignee<2>, Capture>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
-
+template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2> >, float, std::shared_ptr<marian::TensorBase>, marian::Tensor, marian::Tensor);
diff --git a/src/tensors/gpu/element.inc b/src/tensors/gpu/element.inc
index 27f12d0d..002c668f 100644
--- a/src/tensors/gpu/element.inc
+++ b/src/tensors/gpu/element.inc
@@ -36,4 +36,5 @@ template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunc
template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Gt, UnaryFunctor<elem::Abs, Assignee<2>>, Capture>, Capture, Assignee<2>>>, marian::Tensor>(Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Gt, UnaryFunctor<elem::Abs, Assignee<2>>, Capture>, Capture, Assignee<2>>>, marian::Tensor, marian::Tensor);
template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Lt, UnaryFunctor<elem::Abs, Assignee<1>>, Capture>, Capture, Assignee<1>>>>(Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Lt, UnaryFunctor<elem::Abs, Assignee<1>>, Capture>, Capture, Assignee<1>>>, marian::Tensor);
template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<1> >, Capture>, Capture, Assignee<1> > >>(Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<1> >, Capture>, Capture, Assignee<1> > >, marian::Tensor);
-template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<2> >, Capture>, Capture, Capture> >, marian::Tensor >(Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<2> >, Capture>, Capture, Capture> >, marian::Tensor, marian::Tensor); \ No newline at end of file
+template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<2> >, Capture>, Capture, Capture> >, marian::Tensor >(Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<2> >, Capture>, Capture, Capture> >, marian::Tensor, marian::Tensor);
+template void Element<Assign<Var<1>, BinaryFunctor<elem::Clip, Assignee<2>, Capture>>, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Clip, Assignee<2>, Capture>>, marian::Tensor, marian::Tensor);
diff --git a/src/tensors/gpu/prod.cu b/src/tensors/gpu/prod.cu
index 2529dcde..e102311d 100644
--- a/src/tensors/gpu/prod.cu
+++ b/src/tensors/gpu/prod.cu
@@ -4,6 +4,7 @@
// clang-format off
#include "tensors/gpu/prod.h"
#include "tensors/gpu/backend.h"
+#include "tensors/gpu/cuda_helpers.h"
// clang-format on
namespace marian {
@@ -11,8 +12,8 @@ namespace marian {
namespace gpu {
void Prod(marian::Tensor C,
- const marian::Tensor A,
- const marian::Tensor B,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
bool transA,
bool transB,
float beta,
@@ -44,7 +45,7 @@ void Prod(marian::Tensor C,
->getCublasHandle();
#if CUDA_VERSION >= 9000
-// cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH);
+ cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH);
#endif
cublasSgemm(cublasHandle,
@@ -62,14 +63,14 @@ void Prod(marian::Tensor C,
C->data(),
ldc);
#if CUDA_VERSION >= 9000
-// cublasSetMathMode(cublasHandle, CUBLAS_DEFAULT_MATH);
+ cublasSetMathMode(cublasHandle, CUBLAS_DEFAULT_MATH);
#endif
}
void ProdWithBias(marian::Tensor C,
- const marian::Tensor A,
- const marian::Tensor B,
- const marian::Tensor bias,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
+ const marian::Tensor& bias,
bool transA,
bool transB,
float beta,
@@ -80,8 +81,8 @@ void ProdWithBias(marian::Tensor C,
void ProdBatched(marian::Tensor C,
- const marian::Tensor A,
- const marian::Tensor B,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
bool transA,
bool transB,
float beta,
@@ -116,7 +117,7 @@ void ProdBatched(marian::Tensor C,
->getCublasHandle();
#if CUDA_VERSION >= 9000
-// cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH);
+ cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH);
#endif
cublasSgemmStridedBatched(cublasHandle,
opB,
@@ -137,7 +138,7 @@ void ProdBatched(marian::Tensor C,
n * m,
std::max(batchA, batchB));
#if CUDA_VERSION >= 9000
-// cublasSetMathMode(cublasHandle, CUBLAS_DEFAULT_MATH);
+ cublasSetMathMode(cublasHandle, CUBLAS_DEFAULT_MATH);
#endif
}
}
diff --git a/src/tensors/gpu/prod.h b/src/tensors/gpu/prod.h
index 4ffba03c..5791fb1a 100644
--- a/src/tensors/gpu/prod.h
+++ b/src/tensors/gpu/prod.h
@@ -9,25 +9,25 @@ namespace marian {
namespace gpu {
void Prod(marian::Tensor C,
- const marian::Tensor A,
- const marian::Tensor B,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
bool transA,
bool transB,
float beta = 0,
float scalar = 1);
void ProdWithBias(marian::Tensor C,
- const marian::Tensor A,
- const marian::Tensor B,
- const marian::Tensor bias,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
+ const marian::Tensor& bias,
bool transA,
bool transB,
float beta = 0,
float scalar = 1);
void ProdBatched(marian::Tensor C,
- const marian::Tensor A,
- const marian::Tensor B,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
bool transA,
bool transB,
float beta = 0,
diff --git a/src/tensors/tensor_allocator.h b/src/tensors/tensor_allocator.h
index 46f18646..e4e79b7c 100644
--- a/src/tensors/tensor_allocator.h
+++ b/src/tensors/tensor_allocator.h
@@ -11,7 +11,7 @@ namespace marian {
class TensorAllocator {
private:
- const size_t CHUNK = 512;
+ const size_t CHUNK = 128;
const size_t MBYTE = 1024 * 1024;
const size_t GROW = CHUNK * MBYTE;
const size_t ALIGN = 256;
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h
index 700b2465..3a4e3edf 100644
--- a/src/tensors/tensor_operators.h
+++ b/src/tensors/tensor_operators.h
@@ -62,9 +62,9 @@ void Reduce(Functor functor, marian::Tensor out, Tensors... tensors) {
}
// clang-format off
- DISPATCH7(Prod, marian::Tensor, const marian::Tensor, const marian::Tensor, bool, bool, float, float)
- DISPATCH8(ProdWithBias, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor, bool, bool, float, float)
- DISPATCH7(ProdBatched, marian::Tensor, const marian::Tensor, const marian::Tensor, bool, bool, float, float)
+ DISPATCH7(Prod, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float)
+ DISPATCH8(ProdWithBias, marian::Tensor, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float)
+ DISPATCH7(ProdBatched, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float)
DISPATCH2(Dropout, marian::Tensor, float)
diff --git a/src/tests/sqlite_test.cpp b/src/tests/sqlite_test.cpp
index 9d8db825..88ca1ed6 100644
--- a/src/tests/sqlite_test.cpp
+++ b/src/tests/sqlite_test.cpp
@@ -1,78 +1,79 @@
-
- #include <iostream>
- #include <memory>
- #include <fstream>
-
- #include <boost/timer/timer.hpp>
-
- #include <SQLiteCpp/SQLiteCpp.h>
-
- int main(int argc, char** argv) {
-
- SQLite::Database db("corpus.db", SQLite::OPEN_READWRITE|SQLite::OPEN_CREATE);
- db.exec("PRAGMA temp_store_directory = '/data1/marcinjd';");
-
- db.exec("drop table if exists lines");
- db.exec("create table lines (_id integer, line0 text, line1 text);");
-
- boost::timer::auto_cpu_timer total;
-
- std::unique_ptr<boost::timer::auto_cpu_timer> t(new boost::timer::auto_cpu_timer());
-
- SQLite::Statement ps(db, "insert into lines values (?, ?, ?)");
-
- std::string line0, line1;
- size_t lines = 0;
-
- std::cerr << "Reading from " << argv[1] << " and " << argv[2] << std::endl;
-
- std::ifstream file0(argv[1]);
- std::ifstream file1(argv[2]);
-
- db.exec("begin;");
- while(GetLine(file0, line0) && GetLine(file1, line1)) {
- ps.bind(1, (int)lines);
- ps.bind(2, line0);
- ps.bind(3, line1);
-
- ps.exec();
- ps.reset();
-
- lines++;
- if(lines % 1000000 == 0) {
- std::cerr << "[" << lines << "]" << std::endl;
- t.reset(new boost::timer::auto_cpu_timer());
-
- db.exec("commit;");
- db.exec("begin;");
- }
- }
- db.exec("commit;");
-
- std::cerr << "[" << lines << "]" << std::endl;
-
- t.reset(new boost::timer::auto_cpu_timer());
- std::cerr << "creating index" << std::endl;
- db.exec("create unique index idx_line on lines (_id);");
-
- t.reset(new boost::timer::auto_cpu_timer());
-
- std::cout << "count : " << db.execAndGet("select count(*) from lines").getInt() << std::endl;
- t.reset(new boost::timer::auto_cpu_timer());
-
- int count = 0;
- SQLite::Statement sel(db, "select * from lines order by random();");
- t.reset(new boost::timer::auto_cpu_timer());
- while(sel.executeStep()) {
- // Demonstrate how to get some typed column value
- int id = sel.getColumn(0);
- std::string value0 = sel.getColumn(1);
- std::string value1 = sel.getColumn(2);
-
- if(count % 1000000 == 0)
- std::cout << count << " " << id << "\t" << value0 << "\t" << value1 << std::endl;
- count++;
- }
-
- return 0;
- } \ No newline at end of file
+#include <iostream>
+#include <memory>
+#include <fstream>
+
+#include <boost/timer/timer.hpp>
+
+#include <SQLiteCpp/SQLiteCpp.h>
+
+#include "common/utils.h"
+
+int main(int argc, char** argv) {
+
+ SQLite::Database db("corpus.db", SQLite::OPEN_READWRITE|SQLite::OPEN_CREATE);
+ db.exec("PRAGMA temp_store_directory = '/data1/marcinjd';");
+
+ db.exec("drop table if exists lines");
+ db.exec("create table lines (_id integer, line0 text, line1 text);");
+
+ boost::timer::auto_cpu_timer total;
+
+ std::unique_ptr<boost::timer::auto_cpu_timer> t(new boost::timer::auto_cpu_timer());
+
+ SQLite::Statement ps(db, "insert into lines values (?, ?, ?)");
+
+ std::string line0, line1;
+ size_t lines = 0;
+
+ std::cerr << "Reading from " << argv[1] << " and " << argv[2] << std::endl;
+
+ std::ifstream file0(argv[1]);
+ std::ifstream file1(argv[2]);
+
+ db.exec("begin;");
+ while(GetLine(file0, line0) && GetLine(file1, line1)) {
+ ps.bind(1, (int)lines);
+ ps.bind(2, line0);
+ ps.bind(3, line1);
+
+ ps.exec();
+ ps.reset();
+
+ lines++;
+ if(lines % 1000000 == 0) {
+ std::cerr << "[" << lines << "]" << std::endl;
+ t.reset(new boost::timer::auto_cpu_timer());
+
+ db.exec("commit;");
+ db.exec("begin;");
+ }
+ }
+ db.exec("commit;");
+
+ std::cerr << "[" << lines << "]" << std::endl;
+
+ t.reset(new boost::timer::auto_cpu_timer());
+ std::cerr << "creating index" << std::endl;
+ db.exec("create unique index idx_line on lines (_id);");
+
+ t.reset(new boost::timer::auto_cpu_timer());
+
+ std::cout << "count : " << db.execAndGet("select count(*) from lines").getInt() << std::endl;
+ t.reset(new boost::timer::auto_cpu_timer());
+
+ int count = 0;
+ SQLite::Statement sel(db, "select * from lines order by random();");
+ t.reset(new boost::timer::auto_cpu_timer());
+ while(sel.executeStep()) {
+ // Demonstrate how to get some typed column value
+ int id = sel.getColumn(0);
+ std::string value0 = sel.getColumn(1);
+ std::string value1 = sel.getColumn(2);
+
+ if(count % 1000000 == 0)
+ std::cout << count << " " << id << "\t" << value0 << "\t" << value1 << std::endl;
+ count++;
+ }
+
+ return 0;
+}
diff --git a/src/training/graph_group_async.h b/src/training/graph_group_async.h
index d0687b80..76f5a10b 100644
--- a/src/training/graph_group_async.h
+++ b/src/training/graph_group_async.h
@@ -75,6 +75,7 @@ public:
for(auto device : devices_) {
auto graph = New<ExpressionGraph>();
graph->setDevice(device);
+ graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_.push_back(graph);
shardOpt_.push_back(Optimizer(options_));
diff --git a/src/training/graph_group_multinode.cpp b/src/training/graph_group_multinode.cpp
index f64c4ec3..16a128c4 100644
--- a/src/training/graph_group_multinode.cpp
+++ b/src/training/graph_group_multinode.cpp
@@ -37,7 +37,6 @@ Tensor MultiNodeGraphGroup::newTensor(int size, Ptr<Backend> backend) {
*/
void MultiNodeGraphGroup::init(Ptr<data::Batch> batch) {
// Setup clients and shards
- setupMPI();
setupClients(batch);
setupServerShards();
if(clientCommOverlap) {
@@ -51,6 +50,22 @@ void MultiNodeGraphGroup::init(Ptr<data::Batch> batch) {
launchCommOverlapThreads(); // For communicating with server shards while
// other threads do computations
}
+
+ // setup delayed gradient storage
+ if (tau_ > 1) {
+ delay_count = std::vector<size_t>(mpi_comm_world_size_);
+ totalBatchWords = std::vector<int>(mpi_comm_world_size_);
+ optDelayMutex_ = std::vector<std::mutex>(mpi_comm_world_size_);
+
+ for (int i = 0;i < mpi_comm_world_size_; i++) {
+ // Shard buffers across GPUs
+ auto backend = clientGraphs_[i % devices_.size()]->getBackend();
+ Tensor accGrad = newTensor(nodeSizes_[i], backend);
+ Tensor accGradBuff = newTensor(nodeSizes_[i], backend);
+ accGradients.push_back(accGrad);
+ accGradientBuffer.push_back(accGradBuff);
+ }
+ }
}
/**
@@ -206,6 +221,9 @@ void MultiNodeGraphGroup::calculateShardSizes() {
*/
void MultiNodeGraphGroup::initShardGpuTensors() {
size_t offset = 0;
+ for (int i = 0; i < mpi_my_rank_; i++) {
+ offset += nodeSizes_[i];
+ }
for(int shard = 0; shard < devices_.size(); shard++) {
Tensor gpuParams
= newTensor(shardSizes_[shard], clientGraphs_[shard]->getBackend());
@@ -214,6 +232,7 @@ void MultiNodeGraphGroup::initShardGpuTensors() {
shardParams_.push_back(gpuParams);
shardGrads_.push_back(
newTensor(shardSizes_[shard], clientGraphs_[shard]->getBackend()));
+ offset += shardSizes_[shard];
}
}
@@ -235,7 +254,7 @@ void MultiNodeGraphGroup::launchServerThread() {
4,
MPI_UNSIGNED_LONG,
MPI_ANY_SOURCE,
- MPI_TAG_GRAD_PUSH_,
+ MPI_TAG_GRAD_PUSH_MSG_,
MPI_COMM_WORLD,
&status);
if(messageInfo[MSG_INFO_STATUS_] == STATUS_NODE_FINISHED_) {
@@ -388,9 +407,31 @@ void MultiNodeGraphGroup::synchronizeWithServerShards(Tensor newGrads,
// Update remotely if node != this node
if(node != mpi_my_rank_) {
+ Tensor gradient;
+
+ // Delayed Gradient Update
+ if (tau_ > 1) {
+ std::lock_guard<std::mutex> guard(optDelayMutex_[node]);
+ accGradientBuffer[node]->copyFrom(newGrads->subtensor(offset, nodeSize));
+ // Accumulate the gradient
+ using namespace functional;
+ Element(_1 += _2, accGradients[node], accGradientBuffer[node]);
+ // Accumulate total batch word
+ totalBatchWords[node] += batchWords;
+ delay_count[node]++;
+
+ if (delay_count[node] < tau_)
+ continue;
+ delay_count[node] = 0;
+ gradient = accGradients[node];
+ batchWords = totalBatchWords[node];
+ } else {
+ gradient = newGrads->subtensor(offset, nodeSize);
+ }
+
// Copy grads from GPU to CPU (for MPI sending)
cudaMemcpy(clientCommBuffersCPU_[gpu].data(),
- newGrads->subtensor(offset, nodeSize)->data(),
+ gradient->data(),
nodeSize * sizeof(float),
cudaMemcpyDeviceToHost);
cudaStreamSynchronize(0);
@@ -405,7 +446,7 @@ void MultiNodeGraphGroup::synchronizeWithServerShards(Tensor newGrads,
4,
MPI_UNSIGNED_LONG,
node,
- MPI_TAG_GRAD_PUSH_,
+ MPI_TAG_GRAD_PUSH_MSG_,
MPI_COMM_WORLD);
MPI_Ssend(clientCommBuffersCPU_[gpu].data(),
nodeSize,
@@ -413,7 +454,12 @@ void MultiNodeGraphGroup::synchronizeWithServerShards(Tensor newGrads,
node,
MPI_TAG_GRAD_PUSH_,
MPI_COMM_WORLD);
-
+ // Reset total gradient and batch words
+ if (tau_ > 1) {
+ std::lock_guard<std::mutex> guard(optDelayMutex_[node]);
+ accGradients[node]->set(0);
+ totalBatchWords[node] = 0;
+ }
// Receive updated params from server node
MPI_Recv(clientCommBuffersCPU_[gpu].data(),
nodeSize,
@@ -492,6 +538,11 @@ void MultiNodeGraphGroup::execute(Ptr<data::Batch> batch) {
thread_local Ptr<ExpressionGraph> graph;
thread_local Ptr<models::ModelBase> builder;
thread_local size_t my_id = 0;
+ thread_local size_t t = 0;
+ // only for scheduler statistic
+ thread_local float cost = 0;
+ thread_local size_t num_seen_words = 0;
+ thread_local size_t num_seen_sentences = 0;
if(!graph) {
std::lock_guard<std::mutex> lock(mutexClientInit_);
@@ -502,10 +553,23 @@ void MultiNodeGraphGroup::execute(Ptr<data::Batch> batch) {
auto costNode = builder->build(graph, batch);
+#if MPI_FOUND
+ if (t == 0) {
+ MPI_Barrier(MPI_COMM_WORLD);
+ if (my_id != 0)
+ graph->params()->vals()->copyFrom(clientGraphs_[0]->params()->vals());
+ MPI_Barrier(MPI_COMM_WORLD);
+ }
+#endif
+
graph->forward();
- float cost = costNode->scalar();
+ cost += costNode->scalar();
+ num_seen_words += batch->words();
+ num_seen_sentences += batch->size();
graph->backward();
+ t++;
+
graph->getBackend()->synchronize();
if(!clientCommOverlap) {
@@ -558,27 +622,50 @@ void MultiNodeGraphGroup::execute(Ptr<data::Batch> batch) {
}
// Run scheduler (if enabled)
- if(scheduler_) {
+ if(t % tau_ == 0 && scheduler_) {
std::unique_lock<std::mutex> lock(schedulerMutex_);
// Wait until the thread that wants to do validation is finished.
clientThreadPool_->wait_for_one(lock);
- scheduler_->update(cost, batch);
+ if (options_->get<std::string>("cost-type") != "ce-sum")
+ cost /= tau_;
+
+ if (tau_ > 1) {
+ std::vector<size_t> fakeLength = {1, 1};
+ auto fb = data::CorpusBatch::fakeBatch(fakeLength,
+ num_seen_sentences,
+ NULL);
+ fb->front()->setWords(num_seen_words);
+ scheduler_->update(cost, fb);
+ } else {
+ scheduler_->update(cost, batch);
+ }
+
+ num_seen_words = 0;
+ num_seen_sentences = 0;
+ cost = 0;
- if(scheduler_->saving() || scheduler_->validating()) {
+ if((scheduler_->saving() || scheduler_->validating())) {
// Wait with validation or saving until all other threads are done with
// update.
// We want to reuse the graphs for validation, so they need to be in
// a safe state.
clientThreadPool_->wait_for_others(lock);
-
- if(scheduler_->saving())
- this->save(graph);
-
- if(scheduler_->validating())
+#if MPI_FOUND
+ //wait until other nodes are ready
+ MPI_Barrier(MPI_COMM_WORLD);
+
+ // TODO: Saving is broken
+ //if(mpi_my_rank_ == 0 && scheduler_->saving())
+ // this->save(graph);
+
+ if(mpi_my_rank_ == 0 && scheduler_->validating())
scheduler_->validate(clientGraphs_);
+ // inform other nodes to continue
+ MPI_Barrier(MPI_COMM_WORLD);
+#endif
// Validation or saving is done, tell other threads to continue work.
clientThreadPool_->notify_others();
}
diff --git a/src/training/graph_group_multinode.h b/src/training/graph_group_multinode.h
index 2d5d8cfe..b106bb1a 100644
--- a/src/training/graph_group_multinode.h
+++ b/src/training/graph_group_multinode.h
@@ -124,15 +124,21 @@ protected:
int mpi_comm_world_size_{1};
/**
- * Flag to indicate that an MPI message contains gradients (client -> server).
+ * Flag to indicate that an MPI message contains message info
+ * before sending the gradient (client -> server).
*/
- static const int MPI_TAG_GRAD_PUSH_{0};
+ static const int MPI_TAG_GRAD_PUSH_MSG_{0};
+
+ /**
+ * Flag to indicate that an MPI message contains gradient (client -> server).
+ */
+ static const int MPI_TAG_GRAD_PUSH_{5};
/**
* Flag to indicate that an MPI message contains parameters (server ->
* client).
*/
- static const int MPI_TAG_PARAM_PUSH_{5};
+ static const int MPI_TAG_PARAM_PUSH_{10};
/**
* Message info indices: 0 = size; 1 = originating client; 2 = number of batch
@@ -216,6 +222,20 @@ protected:
std::vector<std::condition_variable> cvClientCommOverlapBuffersFilled_;
/**
+ * Variables for optimizer delay
+ */
+ size_t tau_{1};
+ std::vector<std::mutex> optDelayMutex_;
+ std::vector<size_t> delay_count;
+ std::vector<int> totalBatchWords;
+ std::vector<Tensor> accGradients, accGradientBuffer;
+
+ /**
+ * LocalOptimizers related variables
+ */
+ bool useLocalOpt_;
+
+ /**
* Allocate new tensor on given GPU and store allocator.
*/
Tensor newTensor(int size, Ptr<Backend> backend);
@@ -384,15 +404,23 @@ public:
*/
MultiNodeGraphGroup(Ptr<Config> options)
: GraphGroup(options),
+ tau_{options_->get<size_t>("optimizer-delay")},
+ useLocalOpt_{options_->get<bool>("multi-node-local-optimizers")},
clientCommOverlap{options_->get<bool>("multi-node-overlap")} {
// Set up devices for this node
- loadDeviceConfig(options_->get<std::vector<size_t>>("devices"));
+ setupMPI(); //Setup MPI before creating device vectors
+ std::vector<size_t> devices;
+ for(auto& d : options_->getDevices())
+ devices.push_back(d.no);
+ loadDeviceConfig(devices);
+
// Create builders and graphs for clients.
- for(int i = 0; i < devices_.size(); i++) {
+ for(size_t i = 0; i < devices_.size(); i++) {
clientGraphs_.push_back(New<ExpressionGraph>());
clientGraphs_[i]->setDevice({devices_[i], DeviceType::gpu});
clientGraphs_[i]->reserveWorkspaceMB(options_->get<size_t>("workspace"));
- clientBuilders_.push_back(models::from_config(options_, models::usage::training));
+ clientBuilders_.push_back(
+ models::from_config(options_, models::usage::training));
}
}
diff --git a/src/training/graph_group_singleton.h b/src/training/graph_group_singleton.h
index 73cdaaac..121742b3 100644
--- a/src/training/graph_group_singleton.h
+++ b/src/training/graph_group_singleton.h
@@ -35,6 +35,7 @@ public:
auto deviceId = options_->getDevices()[0];
graph_ = New<ExpressionGraph>();
graph_->setDevice(deviceId);
+ graph_->getBackend()->setClip(options_->get<float>("clip-gemm"));
graph_->reserveWorkspaceMB(options_->get<size_t>("workspace"));
opt_ = Optimizer(options_);
builder_ = models::from_config(options_, models::usage::training);
diff --git a/src/training/graph_group_sync.h b/src/training/graph_group_sync.h
index ecc1a8e8..b0e2b428 100644
--- a/src/training/graph_group_sync.h
+++ b/src/training/graph_group_sync.h
@@ -49,6 +49,8 @@ public:
auto graph = New<ExpressionGraph>();
graph->setDevice(device);
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
+ graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
+
graphs_.push_back(graph);
shardOpt_.push_back(Optimizer(options_));
builders_.push_back(models::from_config(options_, models::usage::training));
diff --git a/src/training/training.h b/src/training/training.h
index 770bf82c..df641201 100644
--- a/src/training/training.h
+++ b/src/training/training.h
@@ -78,7 +78,10 @@ public:
scheduler->finished();
model->finalize();
- model->save(true);
+
+ // @TODO: romang, can you comment on this?
+ if(!trainState->loaded)
+ model->save(true);
}
};
}
diff --git a/src/training/training_state.h b/src/training/training_state.h
index dd1f3d05..12e9df4d 100644
--- a/src/training/training_state.h
+++ b/src/training/training_state.h
@@ -61,7 +61,10 @@ public:
// The state of the random number generator from a corpus
std::string seedCorpus;
+ // Set flag if training was resumed
bool loaded{false};
+
+ // @TODO: romang, is this doing anything?
bool validated{false};
TrainingState(float learnRate) : eta(learnRate) {}
@@ -82,6 +85,7 @@ public:
void newBatch() {
++batches;
++batchesEpoch;
+ loaded = false;
validated = false;
for(auto observer : observers_)
observer->actAfterBatches(*this);
diff --git a/src/translator/beam_search.h b/src/translator/beam_search.h
index 243d5080..fd394434 100644
--- a/src/translator/beam_search.h
+++ b/src/translator/beam_search.h
@@ -225,7 +225,7 @@ public:
for(int i = 0; i < dimBatch; ++i) {
if(!beams[i].empty()) {
final = final
- || histories[i]->size() >= 3 * batch->front()->batchWidth();
+ || histories[i]->size() >= options_->get<float>("max-length-factor") * batch->front()->batchWidth();
histories[i]->Add(beams[i], prunedBeams[i].empty() || final);
}
}
diff --git a/src/translator/scorers.cpp b/src/translator/scorers.cpp
index 1d84423c..9fe14bf3 100644
--- a/src/translator/scorers.cpp
+++ b/src/translator/scorers.cpp
@@ -18,7 +18,10 @@ Ptr<Scorer> scorerByType(std::string fname,
options->set("index", index);
}
- auto encdec = models::from_options(options, models::usage::translation);
+ bool skipCost = config->get<bool>("skip-cost");
+ auto encdec = models::from_options(options,
+ skipCost ? models::usage::raw
+ : models::usage::translation);
LOG(info, "Loading scorer of type {} as feature {}", type, fname);
diff --git a/src/translator/translator.h b/src/translator/translator.h
index 4cde139d..e44c9f68 100644
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -55,6 +55,7 @@ public:
auto task = [&](DeviceId device, size_t id) {
auto graph = New<ExpressionGraph>(true, options_->get<bool>("optimize"));
graph->setDevice(device);
+ graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_[id] = graph;
@@ -66,6 +67,7 @@ public:
}
scorers_[id] = scorers;
+ graph->forward();
};
threadPool.enqueue(task, device, id++);
@@ -154,6 +156,7 @@ public:
for(auto device : devices_) {
auto graph = New<ExpressionGraph>(true, options_->get<bool>("optimize"));
graph->setDevice(device);
+ graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_.push_back(graph);