
github.com/marian-nmt/marian.git
author    Hieu Hoang <hihoan@microsoft.com>  2021-06-08 00:16:41 +0300
committer Hieu Hoang <hihoan@microsoft.com>  2021-06-08 00:16:41 +0300
commit    bc4ad2408c308fe3d9ac31accdaa5019dc2187ba (patch)
tree      d259be53ae1c05fd47c26b62d51cc62d0acf15d1
parent    f19ebbae69e06d85979ac16b76fb1acf0dc4e695 (diff)
parent    ce34df4d985d3ff86e4babfc9529c4aaa0aba57d (diff)
Merge branch 'mjd/bdot' into hihoan/lsh7
-rw-r--r--  CHANGELOG.md                          5
-rw-r--r--  CMakeLists.txt                       16
-rw-r--r--  VERSION                               2
m---------  regression-tests                      0
m---------  src/3rd_party/sentencepiece           0
-rw-r--r--  src/graph/expression_graph.h          2
-rw-r--r--  src/graph/expression_operators.cpp    4
-rw-r--r--  src/graph/expression_operators.h      6
-rw-r--r--  src/graph/node_operators_binary.h   174
-rw-r--r--  src/models/transformer.h              4
-rw-r--r--  src/tensors/allocator.h               5
-rwxr-xr-x  src/tensors/cpu/prod.cpp            156
-rw-r--r--  src/tensors/device.h                  5
-rwxr-xr-x  src/tensors/gpu/prod.cpp            158
-rw-r--r--  src/tensors/tensor_operators.h        1
-rw-r--r--  src/tests/units/operator_tests.cpp   60
-rw-r--r--  src/training/scheduler.h              3
17 files changed, 575 insertions, 26 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7f41b8d1..2f4141c8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
### Added
- Early stopping based on first, all, or any validation metrics via `--early-stopping-on`
+- Compute 8.6 support if using CUDA>=11.1
- Support for RMSNorm as drop-in replace for LayerNorm from `Biao Zhang; Rico Sennrich (2019). Root Mean Square Layer Normalization`. Enabled in Transformer model via `--transformer-postprocess dar` instead of `dan`.
- Extend suppression of unwanted output symbols, specifically "\n" from default vocabulary if generated by SentencePiece with byte-fallback. Deactivates with --allow-special
- Allow for fine-grained CPU intrinsics overrides when BUILD_ARCH != native e.g. -DBUILD_ARCH=x86-64 -DCOMPILE_AVX512=off
@@ -25,8 +26,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Dynamic gradient-scaling with `--dynamic-gradient-scaling`.
- Add unit tests for binary files.
- Fix compilation with OMP
+- Compute aligned memory sizes using exact sizing
### Fixed
+- Adding new validation metrics when training is restarted and --reset-valid-stalled is used
- Missing depth-scaling in transformer FFN
- Fixed an issue when loading intgemm16 models from unaligned memory.
- Fix building marian with gcc 9.3+ and FBGEMM
@@ -41,6 +44,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Broken links to MNIST data sets
### Changed
+- Set REQUIRED_BIAS_ALIGNMENT = 16 in tensors/gpu/prod.cpp to avoid memory-misalignment on certain Ampere GPUs.
- For BUILD_ARCH != native enable all intrinsics types by default, can be disabled like this: -DCOMPILE_AVX512=off
- Moved FBGEMM pointer to commit c258054 for gcc 9.3+ fix
- Change compile options a la -DCOMPILE_CUDA_SM35 to -DCOMPILE_KEPLER, -DCOMPILE_MAXWELL,
@@ -50,6 +54,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Developer documentation framework based on Sphinx+Doxygen+Breathe+Exhale
- Expresion graph documentation (#788)
- Graph operators documentation (#801)
+- Remove unused variable from expression graph
## [1.10.0] - 2021-02-06
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 79c8585e..119bc01f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -325,6 +325,16 @@ if(CUDA_FOUND)
option(COMPILE_AMPERE "Compile GPU version with SM80 support" ON)
LIST(APPEND COMPUTE -Wno-deprecated-gpu-targets)
endif()
+ if(CUDA_VERSION VERSION_EQUAL "11.1" OR CUDA_VERSION VERSION_GREATER "11.1")
+ option(COMPILE_KEPLER "Compile GPU version with SM35 support" OFF) # deprecated for CUDA 11
+ option(COMPILE_MAXWELL "Compile GPU version with SM50 support" OFF) # deprecated for CUDA 11
+ option(COMPILE_PASCAL "Compile GPU version with SM60 support" ON)
+ option(COMPILE_VOLTA "Compile GPU version with SM70 support" ON)
+ option(COMPILE_TURING "Compile GPU version with SM75 support" ON)
+ option(COMPILE_AMPERE "Compile GPU version with SM80 support" ON)
+ option(COMPILE_AMPERE_RTX "Compile GPU version with SM86 support" ON)
+ LIST(APPEND COMPUTE -Wno-deprecated-gpu-targets)
+ endif()
if(COMPILE_KEPLER)
message(STATUS "Compiling code for Kepler GPUs")
@@ -354,6 +364,12 @@ if(CUDA_FOUND)
LIST(APPEND COMPUTE -gencode=arch=compute_80,code=sm_80; -gencode=arch=compute_80,code=compute_80) # Ampere GPUs
endif(COMPILE_AMPERE)
endif()
+ if(CUDA_VERSION VERSION_EQUAL "11.1" OR CUDA_VERSION VERSION_GREATER "11.1")
+ if(COMPILE_AMPERE_RTX)
+ message(STATUS "Compiling code for Ampere RTX GPUs")
+ LIST(APPEND COMPUTE -gencode=arch=compute_86,code=sm_86; -gencode=arch=compute_86,code=compute_86) # Ampere RTX GPUs
+ endif(COMPILE_AMPERE_RTX)
+ endif()
if(USE_STATIC_LIBS)
set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusparse_LIBRARY})
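
Note: the new SM86 target is only defined when CUDA >= 11.1 and can be toggled at configure time like the other compute options, e.g. `cmake .. -DCOMPILE_AMPERE_RTX=off` (a sketch assuming an out-of-source build directory) to skip compute_86 code generation.
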
diff --git a/VERSION b/VERSION
index b609d445..c1cadea1 100644
--- a/VERSION
+++ b/VERSION
@@ -1,2 +1,2 @@
-v1.10.18
+v1.10.20
diff --git a/regression-tests b/regression-tests
-Subproject 1afd4eb1014ac451c6a3d6f9b5d34c322902e62
+Subproject 7d612ca5e4b27a76f92584dad76d240e34f216d
diff --git a/src/3rd_party/sentencepiece b/src/3rd_party/sentencepiece
-Subproject 8336bbd0c1cfba02a879afe625bf1ddaf7cd93c
+Subproject 6f24a6b52a521a3467e99a9c175ba9e13690521
diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h
index adc0aeae..fce7d532 100644
--- a/src/graph/expression_graph.h
+++ b/src/graph/expression_graph.h
@@ -145,8 +145,6 @@ protected: // (these are protected, not private, for ONNX exporting)
Ptr<Tensors> tensors_;
private:
- std::unordered_map<size_t, std::vector<Expr>> memoized_;
-
Type defaultElementType_{Type::float32}; // Type used for storing parameters, currently all parameters have to have the same type
bool inferenceOnly_{false}; // a flag holds whether the graph is used for inference only
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 6c7ef91c..baec94df 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -519,6 +519,10 @@ Expr bdot(Expr a, Expr b, bool transA, bool transB, float scale) {
return Expression<DotBatchedNodeOp>(a, b, transA, transB, scale);
}
+Expr bdot_legacy(Expr a, Expr b, bool transA, bool transB, float scale) {
+ return Expression<DotBatchedLegacyNodeOp>(a, b, transA, transB, scale);
+}
+
Expr affineDefault(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
// general version, MKL, CBlas or CUDA
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index f3d84eb6..c1570eff 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -478,6 +478,12 @@ Expr bdot(Expr a,
bool transB = false,
float scalar = 1.f);
+Expr bdot_legacy(Expr a,
+ Expr b,
+ bool transA = false,
+ bool transB = false,
+ float scalar = 1.f);
+
/**
* Performs an affine transformation.
* Computes
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 91fc29da..169b1420 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -529,11 +529,27 @@ public:
shapeB.set(-1, b->shape()[-2]);
}
- Shape outShape = shapeA;
- outShape.set(-1, shapeB[-1]);
ABORT_IF(shapeA[-1] != shapeB[-2],
- "Batched matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
- return outShape;
+ "Batched matrix product requires inner dimensions to match in {}{} * {}{}",
+ std::string(shapeA), transA, std::string(shapeB), transB);
+
+ // create shapes for batch dimensions only
+ auto shapeBatchA = shapeA;
+ shapeBatchA.set(-1, 1);
+ shapeBatchA.set(-2, 1);
+
+ auto shapeBatchB = shapeB;
+ shapeBatchB.set(-1, 1);
+ shapeBatchB.set(-2, 1);
+
+ // broadcast batch dimensions
+ auto shapeOut = Shape::broadcast({shapeBatchA, shapeBatchB});
+
+ // set non-batch dimensions in output
+ shapeOut.set(-2, shapeA[-2]);
+ shapeOut.set(-1, shapeB[-1]);
+
+ return shapeOut;
}
NodeOps forwardOps() override {
@@ -655,6 +671,156 @@ public:
const std::string color() override { return "orange"; }
};
+class DotBatchedLegacyNodeOp : public NaryNodeOp {
+private:
+ friend class SerializationHelpers;
+ bool transA_;
+ bool transB_;
+ float scalar_;
+
+public:
+ DotBatchedLegacyNodeOp(Expr a, Expr b, bool transA, bool transB, float scalar)
+ : NaryNodeOp({a, b}, newShape(a, b, transA, transB)),
+ transA_(transA),
+ transB_(transB),
+ scalar_(scalar) {}
+
+ Shape newShape(Expr a, Expr b, bool transA, bool transB) {
+ auto shapeA = a->shape();
+ if(transA) {
+ shapeA.set(-2, a->shape()[-1]);
+ shapeA.set(-1, a->shape()[-2]);
+ }
+
+ auto shapeB = b->shape();
+ if(transB) {
+ shapeB.set(-2, b->shape()[-1]);
+ shapeB.set(-1, b->shape()[-2]);
+ }
+
+ Shape outShape = shapeA;
+ outShape.set(-1, shapeB[-1]);
+ ABORT_IF(shapeA[-1] != shapeB[-2],
+ "Batched matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
+ return outShape;
+ }
+
+ NodeOps forwardOps() override {
+ // C = alpha * dot(op(A), op(B))
+ return {NodeOp(ProdBatchedLegacy(val_,
+ graph()->allocator(),
+ child(0)->val(),
+ child(1)->val(),
+ transA_,
+ transB_,
+ 0.f,
+ scalar_))};
+ }
+
+ NodeOps backwardOps() override {
+ // D is the adjoint, the matrix of derivatives
+ // df/dA += alpha * dot(D, op(B).T)
+ // df/dB += alpha * dot(op(A).T, D)
+ // beta set to 1.0 in gemm, C = alpha * dot(op(A), op(B)) + beta * C
+ // to sum gradients from different graph parts
+
+ if(!transA_ && transB_)
+ return {NodeOp(ProdBatchedLegacy(child(0)->grad(),
+ graph()->allocator(),
+ adj_,
+ child(1)->val(),
+ false,
+ false,
+ 1.0,
+ scalar_)),
+ NodeOp(ProdBatchedLegacy(child(1)->grad(),
+ graph()->allocator(),
+ adj_,
+ child(0)->val(),
+ true,
+ false,
+ 1.0,
+ scalar_))};
+ if(transA_ && !transB_)
+ return {NodeOp(ProdBatchedLegacy(child(0)->grad(),
+ graph()->allocator(),
+ child(1)->val(),
+ adj_,
+ false,
+ true,
+ 1.0,
+ scalar_)),
+ NodeOp(ProdBatchedLegacy(child(1)->grad(),
+ graph()->allocator(),
+ child(0)->val(),
+ adj_,
+ false,
+ false,
+ 1.0,
+ scalar_))};
+ if(transA_ && transB_)
+ return {NodeOp(ProdBatchedLegacy(child(0)->grad(),
+ graph()->allocator(),
+ child(1)->val(),
+ adj_,
+ true,
+ true,
+ 1.0,
+ scalar_)),
+ NodeOp(ProdBatchedLegacy(child(1)->grad(),
+ graph()->allocator(),
+ adj_,
+ child(0)->val(),
+ true,
+ true,
+ 1.0,
+ scalar_))};
+ return {NodeOp(ProdBatchedLegacy(child(0)->grad(),
+ graph()->allocator(),
+ adj_,
+ child(1)->val(),
+ false,
+ true,
+ 1.0,
+ scalar_)),
+ NodeOp(ProdBatchedLegacy(child(1)->grad(),
+ graph()->allocator(),
+ child(0)->val(),
+ adj_,
+ true,
+ false,
+ 1.0,
+ scalar_))};
+ }
+
+ const std::string type() override { return "bdot_legacy"; }
+
+ virtual size_t hash() override {
+ size_t seed = NaryNodeOp::hash();
+ util::hash_combine(seed, transA_);
+ util::hash_combine(seed, transB_);
+ util::hash_combine(seed, scalar_);
+ return seed;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<DotBatchedLegacyNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(transA_ != cnode->transA_)
+ return false;
+ if(transB_ != cnode->transB_)
+ return false;
+ if(scalar_ != cnode->scalar_)
+ return false;
+ return true;
+ }
+
+ const std::string color() override { return "orange"; }
+};
+
// Note: To reduce code duplication, we use the same NodeOp for C = op(S) x D and C = D x op(S).
// Set swapOperands to select the latter.
class CSRDotNodeOp : public NaryNodeOp {
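
The reworked newShape above (the non-legacy batched dot) now broadcasts the batch dimensions of A and B instead of copying A's shape wholesale; DotBatchedLegacyNodeOp keeps the old behaviour. A minimal standalone sketch of the broadcast rule, assuming equal rank and using plain vectors instead of marian::Shape:

    #include <cassert>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Broadcast batch dims elementwise (a dimension of 1 broadcasts against anything),
    // then append the matmul dims [m, n].
    std::vector<int> bdotShape(std::vector<int> a, std::vector<int> b) {
      // a = (..., m, k), b = (..., k, n), same rank assumed for brevity
      assert(a.size() == b.size() && a.size() >= 2);
      size_t r = a.size();
      assert(a[r - 1] == b[r - 2]);               // inner dimensions must match
      std::vector<int> out(r);
      for(size_t i = 0; i < r - 2; ++i) {
        assert(a[i] == b[i] || a[i] == 1 || b[i] == 1);
        out[i] = a[i] == 1 ? b[i] : a[i];         // broadcast batch dimension
      }
      out[r - 2] = a[r - 2];                      // m
      out[r - 1] = b[r - 1];                      // n
      return out;
    }

    int main() {
      // matches the new unit test in operator_tests.cpp further down:
      // {2,1,2,2} x {1,3,2,2} -> {2,3,2,2}
      auto c = bdotShape({2, 1, 2, 2}, {1, 3, 2, 2});
      printf("%d %d %d %d\n", c[0], c[1], c[2], c[3]);  // prints: 2 3 2 2
    }
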
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 1da02318..a792de8b 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -249,7 +249,7 @@ public:
// multiplicative attention with flattened softmax
float scale = 1.0f / std::sqrt((float)dk); // scaling to avoid extreme values due to matrix multiplication
- auto z = bdot(q, k, false, true, scale); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: max src length]
+ auto z = bdot_legacy(q, k, false, true, scale); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: max src length]
// mask out garbage beyond end of sequences
z = z + mask;
@@ -264,7 +264,7 @@ public:
weights = dropout(weights, inference_ ? 0 : opt<float>("transformer-dropout-attention"));
// apply attention weights to values
- auto output = bdot(weights, v); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: split vector dim]
+ auto output = bdot_legacy(weights, v); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: split vector dim]
return output;
}
diff --git a/src/tensors/allocator.h b/src/tensors/allocator.h
index 9dc44f58..1844be14 100644
--- a/src/tensors/allocator.h
+++ b/src/tensors/allocator.h
@@ -175,8 +175,9 @@ public:
reserve(bytes);
}
- size_t alignedSize(size_t size) {
- return (size_t)(ceil(size / (double)alignment_) * alignment_);
+ size_t alignedSize(size_t size) const {
+ size_t over = size + alignment_ - 1;
+ return over - (over % alignment_);
}
void throwAtReallocation(bool throwRealloc) { throw_ = throwRealloc; }
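
The replacement alignedSize rounds up using integer arithmetic only; per the changelog entry "Compute aligned memory sizes using exact sizing", the point is presumably to avoid the double-precision ceil of the old version. A minimal sketch of the same formula (also used by Device::align further down):

    #include <cstddef>
    #include <cstdio>

    // Round size up to the next multiple of alignment using integer arithmetic only.
    // Same formula as Allocator::alignedSize / Device::align in this commit.
    static size_t alignUp(size_t size, size_t alignment) {
      size_t over = size + alignment - 1;
      return over - (over % alignment);
    }

    int main() {
      printf("%zu\n", alignUp(1, 256));    // 256
      printf("%zu\n", alignUp(256, 256));  // 256
      printf("%zu\n", alignUp(257, 256));  // 512
    }
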
diff --git a/src/tensors/cpu/prod.cpp b/src/tensors/cpu/prod.cpp
index 6e28bdd2..07cc2b99 100755
--- a/src/tensors/cpu/prod.cpp
+++ b/src/tensors/cpu/prod.cpp
@@ -93,6 +93,162 @@ void ProdBatched(marian::Tensor C,
#if BLAS_FOUND
float alpha = scalar;
+ // determine meta-shape of bdot operation. Essentially treat the last two dimensions as single elements
+ // such that (..., m, k) x (..., k, n) -> (..., m, n) where ... is a broadcastable shape as in element-wise kernels.
+
+ auto aShape = A->shape();
+ auto bShape = B->shape();
+
+ // make sure both shape have the same number of dimensions via broadcasting
+ size_t maxLength = std::max(aShape.size(), bShape.size());
+ if(aShape.size() != bShape.size()) {
+ Shape ones(std::vector<int>(maxLength, 1));
+ aShape = Shape::broadcast({aShape, ones});
+ bShape = Shape::broadcast({bShape, ones});
+ }
+
+ // Create meta-shapes without last 2 dimensions
+ Shape aShapeMeta, bShapeMeta, cShapeMeta;
+ aShapeMeta.resize(maxLength - 2);
+ bShapeMeta.resize(maxLength - 2);
+ for(size_t i = 0; i < maxLength - 2; ++i) {
+ aShapeMeta.set(i, aShape[i]);
+ bShapeMeta.set(i, bShape[i]);
+ }
+ cShapeMeta = Shape::broadcast({aShapeMeta, bShapeMeta});
+
+ size_t m = aShape[-2];
+ size_t k = aShape[-1];
+ if(transA)
+ std::swap(m, k);
+
+ size_t l = bShape[-2];
+ size_t n = bShape[-1];
+ if(transB)
+ std::swap(l, n);
+
+ size_t lda = aShape[-1];
+ size_t ldb = bShape[-1];
+ size_t ldc = bShape[-1];
+
+ if(transB)
+ ldc = bShape[-2];
+
+ auto strideA = m * k;
+ auto strideB = n * k;
+ auto strideC = n * m;
+
+ auto batchC = cShapeMeta.elements();
+
+ // Convert to functional shapes to be able to map dimensions. @TODO merge this
+ functional::Shape aShapeMetaF = aShapeMeta;
+ functional::Shape bShapeMetaF = bShapeMeta;
+ functional::Shape cShapeMetaF = cShapeMeta;
+
+#if MKL_FOUND
+ CBLAS_TRANSPOSE transA_forarr = CblasNoTrans;
+ CBLAS_TRANSPOSE transB_forarr = CblasNoTrans;
+
+ if(transA)
+ transA_forarr = CblasTrans;
+
+ if(transB)
+ transB_forarr = CblasTrans;
+
+ /* cblas_sgemm_batch allows us to group all the small GEMMs that are done in a for loop with sgemm and compute
+ * them in only one MKL call. For the API documentation refer to
+ * https://software.intel.com/content/www/us/en/develop/documentation/mkl-developer-reference-c/top/blas-and-sparse-blas-routines/blas-like-extensions/cblas-gemm-batch.html
+ * The API supports dependencies, where you can specify one "group" of GEMMs to be computed after another. (This controlled by the group_count parameter).
+ * In our case, the operations are not dependent on one another so we hardcode one group. The rest of the arguments (with the exception of group_size) are
+ * the same as the ones that cblas_sgemm expects, with the difference that we are supposed to provide an array pointer (One element per group).
+ * Weirdly enough, we are required to to provide all of the integer arguments as the MKL_INT datatype
+ */
+
+ static const constexpr size_t group_count = 1; // We have one group
+ const std::vector<CBLAS_TRANSPOSE> transa_arr(group_count, transA_forarr);
+ const std::vector<CBLAS_TRANSPOSE> transb_arr(group_count, transB_forarr);
+ const std::vector<MKL_INT> m_arr(group_count, (MKL_INT)m);
+ const std::vector<MKL_INT> n_arr(group_count, (MKL_INT)n);
+ const std::vector<MKL_INT> k_arr(group_count, (MKL_INT)k);
+ const std::vector<float> alpha_arr(group_count, alpha);
+ const std::vector<float> beta_arr(group_count, beta);
+ const std::vector<MKL_INT> lda_arr(group_count, (MKL_INT)lda);
+ const std::vector<MKL_INT> ldb_arr(group_count, (MKL_INT)ldb);
+ const std::vector<MKL_INT> ldc_arr(group_count, (MKL_INT)ldc);
+ const std::vector<MKL_INT> group_size(group_count, (MKL_INT)batchC); // Group size specifies number of GEMM operations per group (Which is batchC)
+
+ std::vector<const float *> a_array(batchC, nullptr);
+ std::vector<const float *> b_array(batchC, nullptr);
+ std::vector<float *> c_array(batchC, nullptr);
+
+ // This loop initializes the array pointers in the same way as the for loop
+ // in the normal sgemm version a few lines below
+ functional::Array<int, functional::Shape::size()> dims;
+ for(size_t i = 0; i < batchC; ++i) {
+ cShapeMetaF.dims(i, dims);
+ auto aIndex = aShapeMetaF.bindex(dims);
+ auto bIndex = bShapeMetaF.bindex(dims);
+
+ a_array[i] = A->data() + aIndex * strideA;
+ b_array[i] = B->data() + bIndex * strideB;
+ c_array[i] = C->data() + i * strideC;
+ }
+ cblas_sgemm_batch (CblasRowMajor,
+ &transa_arr[0],
+ &transb_arr[0],
+ &m_arr[0],
+ &n_arr[0],
+ &k_arr[0],
+ &alpha_arr[0],
+ &a_array[0],
+ &lda_arr[0],
+ &b_array[0],
+ &ldb_arr[0],
+ &beta_arr[0],
+ &c_array[0],
+ &ldc_arr[0],
+ group_count,
+ &group_size[0]);
+#else
+ functional::Array<int, functional::Shape::size()> dims;
+ for(size_t i = 0; i < batchC; ++i) {
+ cShapeMetaF.dims(i, dims);
+ auto aIndex = aShapeMetaF.bindex(dims);
+ auto bIndex = bShapeMetaF.bindex(dims);
+
+ sgemm(transA,
+ transB,
+ (int)m,
+ (int)n,
+ (int)k,
+ alpha,
+ A->data() + aIndex * strideA,
+ (int)lda,
+ B->data() + bIndex * strideB,
+ (int)ldb,
+ beta,
+ C->data() + i * strideC,
+ (int)ldc);
+ }
+#endif
+#else
+ C; A; B; transA; transB; beta; scalar;
+ ABORT("You need to compile with MKL in order to use the CPU version");
+#endif
+}
+
+
+void ProdBatchedLegacy(marian::Tensor C,
+ Ptr<Allocator> /*allocator*/,
+ const marian::Tensor A,
+ const marian::Tensor B,
+ bool transA,
+ bool transB,
+ float beta,
+ float scalar) {
+#if BLAS_FOUND
+ float alpha = scalar;
+
size_t batchA = A->shape().elements() / (A->shape()[-1] * A->shape()[-2]);
size_t batchB = B->shape().elements() / (B->shape()[-1] * B->shape()[-2]);
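
The reworked CPU ProdBatched derives a broadcast meta-shape over the batch dimensions and maps every output batch index back to an A/B sub-matrix before handing the work to sgemm / cblas_sgemm_batch. A reference sketch of those semantics, simplified to a single flattened batch dimension (the real code broadcasts each leading dimension independently via bindex) and ignoring transposes and beta:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Reference semantics only, no BLAS: for each output batch i, pick the A/B
    // sub-matrix via a broadcast index (a batch of size 1 is reused), then do a
    // plain row-major GEMM of (m x k) by (k x n).
    void prodBatchedRef(std::vector<float>& C, const std::vector<float>& A,
                        const std::vector<float>& B, size_t batchA, size_t batchB,
                        size_t m, size_t k, size_t n) {
      size_t batchC = std::max(batchA, batchB);  // assumes one batch is 1 or both are equal
      for(size_t i = 0; i < batchC; ++i) {
        const float* a = A.data() + (batchA == 1 ? 0 : i) * m * k;
        const float* b = B.data() + (batchB == 1 ? 0 : i) * k * n;
        float* c = C.data() + i * m * n;
        for(size_t r = 0; r < m; ++r)
          for(size_t col = 0; col < n; ++col) {
            float sum = 0.f;
            for(size_t x = 0; x < k; ++x)
              sum += a[r * k + x] * b[x * n + col];
            c[r * n + col] = sum;
          }
      }
    }
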
diff --git a/src/tensors/device.h b/src/tensors/device.h
index 0be6c076..5fe3c1fb 100644
--- a/src/tensors/device.h
+++ b/src/tensors/device.h
@@ -15,8 +15,9 @@ protected:
size_t size_{0};
size_t alignment_;
- size_t align(size_t size) {
- return (size_t)(ceil(size / (float)alignment_) * alignment_);
+ size_t align(size_t size) const {
+ size_t over = size + alignment_ - 1;
+ return over - (over % alignment_);
}
public:
diff --git a/src/tensors/gpu/prod.cpp b/src/tensors/gpu/prod.cpp
index 8cfa78ca..e996f58f 100755
--- a/src/tensors/gpu/prod.cpp
+++ b/src/tensors/gpu/prod.cpp
@@ -22,7 +22,7 @@ namespace gpu {
// It seems that the bias must be 8 byte aligned for the cublasLt epilogue to work. Therefore,
// if the bias pointer is not 8 byte aligned, we do a normal matmul in cublasLt and invoke a
// custom epilogue kernel.
-static constexpr int REQUIRED_BIAS_ALIGNMENT = 8;
+static constexpr int REQUIRED_BIAS_ALIGNMENT = 16; // @TODO: MJD: changed this to 16 to avoid alignment error on A100. Seems to work fine.
// Used to set preferences for cublasLt to filter out algos if matrices to not meet default 256 byte alignment
int getAlignmentUpTo256(const void *ptr) {
@@ -347,6 +347,139 @@ void ProdBatchedTyped(marian::Tensor C,
CUDA_CHECK(cudaSetDevice((int)C->getDeviceId().no));
ComputeType alpha = scalar;
+ // determine meta-shape of bdot operation. Essentially treat the last two dimensions as single elements
+ // such that (..., m, k) x (..., k, n) -> (..., m, n) where ... is a broadcastable shape as in element-wise kernels.
+
+ auto aShape = A->shape();
+ auto bShape = B->shape();
+
+ // make sure both shape have the same number of dimensions via broadcasting
+ size_t maxLength = std::max(aShape.size(), bShape.size());
+ if(aShape.size() != bShape.size()) {
+ Shape ones(std::vector<int>(maxLength, 1));
+ aShape = Shape::broadcast({aShape, ones});
+ bShape = Shape::broadcast({bShape, ones});
+ }
+
+ // Create meta-shapes without last 2 dimensions
+ Shape aShapeMeta, bShapeMeta, cShapeMeta;
+ aShapeMeta.resize(maxLength - 2);
+ bShapeMeta.resize(maxLength - 2);
+ for(size_t i = 0; i < maxLength - 2; ++i) {
+ aShapeMeta.set(i, aShape[i]);
+ bShapeMeta.set(i, bShape[i]);
+ }
+ cShapeMeta = Shape::broadcast({aShapeMeta, bShapeMeta});
+
+ size_t m = aShape[-2];
+ size_t k = aShape[-1];
+ if(transA)
+ std::swap(m, k);
+
+ size_t l = bShape[-2];
+ size_t n = bShape[-1];
+ if(transB)
+ std::swap(l, n);
+
+ size_t lda = aShape[-1];
+ size_t ldb = bShape[-1];
+ size_t ldc = bShape[-1];
+
+ if(transB)
+ ldc = bShape[-2];
+
+ cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
+ cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+ auto backend = std::static_pointer_cast<gpu::Backend>(C->getBackend());
+ auto cublasHandle = backend->getCublasHandle();
+ auto compute = backend->getCudaComputeCapability();
+
+ auto strideA = m * k;
+ auto strideB = n * k;
+ auto strideC = n * m;
+
+ auto batchC = cShapeMeta.elements();
+
+ // Convert to functional shapes to be able to map dimensions. @TODO merge this
+ functional::Shape aShapeMetaF = aShapeMeta;
+ functional::Shape bShapeMetaF = bShapeMeta;
+ functional::Shape cShapeMetaF = cShapeMeta;
+
+ std::vector<const ElementType*> aptr;
+ std::vector<const ElementType*> bptr;
+ std::vector<ElementType*> cptr;
+
+ functional::Array<int, functional::Shape::size()> dims;
+ for(int i = 0; i < batchC; i++) {
+ cShapeMetaF.dims(i, dims);
+ auto aIndex = aShapeMetaF.bindex(dims);
+ auto bIndex = bShapeMetaF.bindex(dims);
+
+ aptr.push_back(A->data<ElementType>() + aIndex * strideA);
+ bptr.push_back(B->data<ElementType>() + bIndex * strideB);
+ cptr.push_back(C->data<ElementType>() + i * strideC);
+ }
+
+ // auto fails here from weird reason
+ IPtr<MemoryPiece> mp_aptr = allocator->alloc<const ElementType*>(aptr.size());
+ CudaCopy(aptr.data(), aptr.data() + aptr.size(), mp_aptr->data<const ElementType*>());
+
+ IPtr<MemoryPiece> mp_bptr = allocator->alloc<const ElementType*>(bptr.size());
+ CudaCopy(bptr.data(), bptr.data() + bptr.size(), mp_bptr->data<const ElementType*>());
+
+ IPtr<MemoryPiece> mp_cptr = allocator->alloc<ElementType*>(cptr.size());
+ CudaCopy(cptr.data(), cptr.data() + cptr.size(), mp_cptr->data<ElementType*>());
+
+ setTensorMode(cublasHandle);
+ TypedGemm<ElementType, ComputeType>::batchedGemm(cublasHandle, compute,
+ opB, opA,
+ n, m, k,
+ &alpha,
+ mp_bptr->data<const ElementType*>(), ldb,
+ mp_aptr->data<const ElementType*>(), lda,
+ &beta,
+ mp_cptr->data<ElementType*>(), ldc,
+ batchC);
+ unsetTensorMode(cublasHandle);
+
+ allocator->free(mp_aptr);
+ allocator->free(mp_bptr);
+ allocator->free(mp_cptr);
+}
+
+// @TODO: add version with compute type for completeness
+void ProdBatched(marian::Tensor C,
+ Ptr<Allocator> allocator,
+ const marian::Tensor A,
+ const marian::Tensor B,
+ bool transA,
+ bool transB,
+ float beta,
+ float scalar) {
+ if(C->type() == Type::float32) {
+ ProdBatchedTyped<float, float>(C, allocator, A, B, transA, transB, beta, scalar);
+#if COMPILE_FP16
+ } else if(C->type() == Type::float16) { // not a *.cu file
+ ProdBatchedTyped<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar));
+#endif
+ } else {
+ ABORT("ProdBatched not implemented for element type {}", C->type());
+ }
+}
+
+template <typename ElementType, typename ComputeType>
+void ProdBatchedTypedLegacy(marian::Tensor C,
+ Ptr<Allocator> allocator,
+ const marian::Tensor A,
+ const marian::Tensor B,
+ bool transA,
+ bool transB,
+ ComputeType beta,
+ ComputeType scalar) {
+ CUDA_CHECK(cudaSetDevice((int)C->getDeviceId().no));
+ ComputeType alpha = scalar;
+
int batchA = A->shape().elements() / (A->shape()[-1] * A->shape()[-2]);
int batchB = B->shape().elements() / (B->shape()[-1] * B->shape()[-2]);
@@ -417,25 +550,26 @@ void ProdBatchedTyped(marian::Tensor C,
}
// @TODO: add version with compute type for completeness
-void ProdBatched(marian::Tensor C,
- Ptr<Allocator> allocator,
- const marian::Tensor A,
- const marian::Tensor B,
- bool transA,
- bool transB,
- float beta,
- float scalar) {
+void ProdBatchedLegacy(marian::Tensor C,
+ Ptr<Allocator> allocator,
+ const marian::Tensor A,
+ const marian::Tensor B,
+ bool transA,
+ bool transB,
+ float beta,
+ float scalar) {
if(C->type() == Type::float32) {
- ProdBatchedTyped<float, float>(C, allocator, A, B, transA, transB, beta, scalar);
+ ProdBatchedTypedLegacy<float, float>(C, allocator, A, B, transA, transB, beta, scalar);
#if COMPILE_FP16
} else if(C->type() == Type::float16) { // not a *.cu file
- ProdBatchedTyped<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar));
+ ProdBatchedTypedLegacy<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar));
#endif
} else {
- ABORT("ProdBatched not implemented for element type {}", C->type());
+ ABORT("ProdBatchedLegacy not implemented for element type {}", C->type());
}
}
+
#if CUDA_VERSION >= 11000 // Earlier versions of cublasLT do not support bias addition for fp32 and fp16.
static cublasStatus_t cublasLtAffineHelper(cublasLtHandle_t ltHandle, cublasOperation_t transA, cublasOperation_t transB,
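
Raising REQUIRED_BIAS_ALIGNMENT from 8 to 16 only tightens the threshold at which the cublasLt bias epilogue is trusted; per the comment above, a bias that fails the check falls back to a plain matmul plus the custom epilogue kernel. The check itself is presumably a simple modulo test on the raw pointer, along these lines (a sketch, not the actual helper in prod.cpp):

    #include <cstdint>

    // True if ptr is aligned to `alignment` bytes; with REQUIRED_BIAS_ALIGNMENT = 16,
    // a bias pointer failing this test would take the non-epilogue fallback path.
    static bool isAligned(const void* ptr, std::uintptr_t alignment) {
      return reinterpret_cast<std::uintptr_t>(ptr) % alignment == 0;
    }
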
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h
index ef485068..6e587953 100644
--- a/src/tensors/tensor_operators.h
+++ b/src/tensors/tensor_operators.h
@@ -104,6 +104,7 @@ DISPATCH7(Prod, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bo
DISPATCH8(Prod, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float, Type) // overloading since we want the default to for computeType be C->type() which difficult otherwise.
DISPATCH8(ProdBatched, marian::Tensor, Ptr<Allocator>, const marian::Tensor, const marian::Tensor, bool, bool, float, float)
+DISPATCH8(ProdBatchedLegacy, marian::Tensor, Ptr<Allocator>, const marian::Tensor, const marian::Tensor, bool, bool, float, float)
DISPATCH9(CSRProd, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float)
DISPATCH10(Affine, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float, bool)
diff --git a/src/tests/units/operator_tests.cpp b/src/tests/units/operator_tests.cpp
index 1a18da99..f3b5fda3 100644
--- a/src/tests/units/operator_tests.cpp
+++ b/src/tests/units/operator_tests.cpp
@@ -615,6 +615,66 @@ void tests(DeviceType device, Type floatType = Type::float32) {
CHECK(values2 == values);
}
+ SECTION("bdot") {
+ graph->clear();
+ values.clear();
+
+ std::vector<T> vA({ 1, 2,
+ 3, 4,
+ 5, 6,
+ 7, 8});
+
+ std::vector<T> vB({ 1, 2,
+ 3, 4,
+ 5, 6,
+ 7, 8,
+ 9, 10,
+ 11, 12});
+
+ std::vector<T> vC({ 7, 10,
+ 15, 22,
+ 19, 22,
+ 43, 50,
+ 31, 34,
+ 71, 78,
+ 23, 34,
+ 31, 46,
+ 67, 78,
+ 91, 106,
+ 111, 122,
+ 151, 166});
+
+ std::vector<T> vCt({ 5, 11,
+ 11, 25,
+ 17, 23,
+ 39, 53,
+ 29, 35,
+ 67, 81,
+ 17, 39,
+ 23, 53,
+ 61, 83,
+ 83, 113,
+ 105, 127,
+ 143, 173});
+
+ auto A = graph->param("A", {2, 1, 2, 2}, inits::fromVector(vA));
+ auto B = graph->param("B", {1, 3, 2, 2}, inits::fromVector(vB));
+
+ auto C = bdot(A, B, /*transA=*/false, /*transB=*/false);
+ auto Ct = bdot(A, B, /*transA=*/false, /*transB=*/true);
+
+ graph->forward();
+
+ CHECK(C->shape() == Shape({2, 3, 2, 2}));
+ CHECK(Ct->shape() == Shape({2, 3, 2, 2}));
+
+ C->val()->get(values);
+ CHECK(vC == values);
+
+ Ct->val()->get(values);
+ CHECK(vCt == values);
+ }
+
SECTION("repeat") {
graph->clear();
values.clear();
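
A quick sanity check of the expected values in the new bdot test: A has batch shape {2,1} and B has {1,3}, so the batch dimensions broadcast to {2,3} and C has shape {2,3,2,2}. The first output block multiplies A's first 2x2 sub-matrix by B's first:

    [1 2]   [1 2]   [1*1+2*3  1*2+2*4]   [ 7 10]
    [3 4] x [3 4] = [3*1+4*3  3*2+4*4] = [15 22]

which matches the first four entries of vC; with transB=true the same pair gives [[5,11],[11,25]], the start of vCt.
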
diff --git a/src/training/scheduler.h b/src/training/scheduler.h
index 8d4fa30c..3cc3b207 100644
--- a/src/training/scheduler.h
+++ b/src/training/scheduler.h
@@ -511,7 +511,8 @@ public:
state_->stalled = 0;
state_->maxStalled = 0;
for(const auto& validator : validators_) {
- state_->validators[validator->type()]["stalled"] = 0;
+ if(state_->validators[validator->type()])
+ state_->validators[validator->type()]["stalled"] = 0;
}
}