author | Hieu Hoang <hihoan@microsoft.com> | 2021-06-08 00:16:41 +0300
committer | Hieu Hoang <hihoan@microsoft.com> | 2021-06-08 00:16:41 +0300
commit | bc4ad2408c308fe3d9ac31accdaa5019dc2187ba (patch)
tree | d259be53ae1c05fd47c26b62d51cc62d0acf15d1
parent | f19ebbae69e06d85979ac16b76fb1acf0dc4e695 (diff)
parent | ce34df4d985d3ff86e4babfc9529c4aaa0aba57d (diff)
Merge branch 'mjd/bdot' into hihoan/lsh7
-rw-r--r-- | CHANGELOG.md | 5
-rw-r--r-- | CMakeLists.txt | 16
-rw-r--r-- | VERSION | 2
m--------- | regression-tests | 0
m--------- | src/3rd_party/sentencepiece | 0
-rw-r--r-- | src/graph/expression_graph.h | 2
-rw-r--r-- | src/graph/expression_operators.cpp | 4
-rw-r--r-- | src/graph/expression_operators.h | 6
-rw-r--r-- | src/graph/node_operators_binary.h | 174
-rw-r--r-- | src/models/transformer.h | 4
-rw-r--r-- | src/tensors/allocator.h | 5
-rwxr-xr-x | src/tensors/cpu/prod.cpp | 156
-rw-r--r-- | src/tensors/device.h | 5
-rwxr-xr-x | src/tensors/gpu/prod.cpp | 158
-rw-r--r-- | src/tensors/tensor_operators.h | 1
-rw-r--r-- | src/tests/units/operator_tests.cpp | 60
-rw-r--r-- | src/training/scheduler.h | 3
17 files changed, 575 insertions, 26 deletions
```diff
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7f41b8d1..2f4141c8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 
 ### Added
 - Early stopping based on first, all, or any validation metrics via `--early-stopping-on`
+- Compute 8.6 support if using CUDA>=11.1
 - Support for RMSNorm as drop-in replace for LayerNorm from `Biao Zhang; Rico Sennrich (2019). Root Mean Square Layer Normalization`. Enabled in Transformer model via `--transformer-postprocess dar` instead of `dan`.
 - Extend suppression of unwanted output symbols, specifically "\n" from default vocabulary if generated by SentencePiece with byte-fallback. Deactivates with --allow-special
 - Allow for fine-grained CPU intrinsics overrides when BUILD_ARCH != native e.g. -DBUILD_ARCH=x86-64 -DCOMPILE_AVX512=off
@@ -25,8 +26,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 - Dynamic gradient-scaling with `--dynamic-gradient-scaling`.
 - Add unit tests for binary files.
 - Fix compilation with OMP
+- Compute aligned memory sizes using exact sizing
 
 ### Fixed
+- Adding new validation metrics when training is restarted and --reset-valid-stalled is used
 - Missing depth-scaling in transformer FFN
 - Fixed an issue when loading intgemm16 models from unaligned memory.
 - Fix building marian with gcc 9.3+ and FBGEMM
@@ -41,6 +44,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 - Broken links to MNIST data sets
 
 ### Changed
+- Set REQUIRED_BIAS_ALIGNMENT = 16 in tensors/gpu/prod.cpp to avoid memory-misalignment on certain Ampere GPUs.
 - For BUILD_ARCH != native enable all intrinsics types by default, can be disabled like this: -DCOMPILE_AVX512=off
 - Moved FBGEMM pointer to commit c258054 for gcc 9.3+ fix
 - Change compile options a la -DCOMPILE_CUDA_SM35 to -DCOMPILE_KEPLER, -DCOMPILE_MAXWELL,
@@ -50,6 +54,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 - Developer documentation framework based on Sphinx+Doxygen+Breathe+Exhale
 - Expresion graph documentation (#788)
 - Graph operators documentation (#801)
+- Remove unused variable from expression graph
 
 ## [1.10.0] - 2021-02-06
```
```diff
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 79c8585e..119bc01f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -325,6 +325,16 @@ if(CUDA_FOUND)
     option(COMPILE_AMPERE "Compile GPU version with SM80 support" ON)
     LIST(APPEND COMPUTE -Wno-deprecated-gpu-targets)
   endif()
+  if(CUDA_VERSION VERSION_EQUAL "11.1" OR CUDA_VERSION VERSION_GREATER "11.1")
+    option(COMPILE_KEPLER  "Compile GPU version with SM35 support" OFF) # deprecated for CUDA 11
+    option(COMPILE_MAXWELL "Compile GPU version with SM50 support" OFF) # deprecated for CUDA 11
+    option(COMPILE_PASCAL  "Compile GPU version with SM60 support" ON)
+    option(COMPILE_VOLTA   "Compile GPU version with SM70 support" ON)
+    option(COMPILE_TURING  "Compile GPU version with SM75 support" ON)
+    option(COMPILE_AMPERE  "Compile GPU version with SM80 support" ON)
+    option(COMPILE_AMPERE_RTX "Compile GPU version with SM86 support" ON)
+    LIST(APPEND COMPUTE -Wno-deprecated-gpu-targets)
+  endif()
 
   if(COMPILE_KEPLER)
     message(STATUS "Compiling code for Kepler GPUs")
@@ -354,6 +364,12 @@ if(CUDA_FOUND)
     LIST(APPEND COMPUTE -gencode=arch=compute_80,code=sm_80; -gencode=arch=compute_80,code=compute_80) # Ampere GPUs
   endif(COMPILE_AMPERE)
 endif()
+if(CUDA_VERSION VERSION_EQUAL "11.1" OR CUDA_VERSION VERSION_GREATER "11.1")
+  if(COMPILE_AMPERE_RTX)
+    message(STATUS "Compiling code for Ampere RTX GPUs")
+    LIST(APPEND COMPUTE -gencode=arch=compute_86,code=sm_86; -gencode=arch=compute_86,code=compute_86) # Ampere RTX GPUs
+  endif(COMPILE_AMPERE_RTX)
+endif()
 
 if(USE_STATIC_LIBS)
   set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusparse_LIBRARY})
```

```diff
diff --git a/VERSION b/VERSION
@@ -1,2 +1,2 @@
-v1.10.18
+v1.10.20
```

```diff
diff --git a/regression-tests b/regression-tests
-Subproject 1afd4eb1014ac451c6a3d6f9b5d34c322902e62
+Subproject 7d612ca5e4b27a76f92584dad76d240e34f216d
```

```diff
diff --git a/src/3rd_party/sentencepiece b/src/3rd_party/sentencepiece
-Subproject 8336bbd0c1cfba02a879afe625bf1ddaf7cd93c
+Subproject 6f24a6b52a521a3467e99a9c175ba9e13690521
```
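The new `COMPILE_AMPERE_RTX` option only takes effect when CMake detects CUDA 11.1 or newer, since `sm_86` is not a valid compilation target before that. As a minimal sketch (not part of this commit; it only assumes the standard CUDA runtime API), a program can confirm at runtime whether it is actually running on a compute-8.6 device:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaDeviceProp prop;
  if(cudaGetDeviceProperties(&prop, /*device=*/0) != cudaSuccess) {
    std::printf("no CUDA device found\n");
    return 1;
  }
  // Ampere RTX cards (e.g. an RTX 3090) report compute capability 8.6 and
  // benefit from the sm_86 code path enabled by COMPILE_AMPERE_RTX above.
  std::printf("device 0: %s, compute %d.%d\n", prop.name, prop.major, prop.minor);
  return 0;
}
```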
```diff
diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h
index adc0aeae..fce7d532 100644
--- a/src/graph/expression_graph.h
+++ b/src/graph/expression_graph.h
@@ -145,8 +145,6 @@ protected:
   // (these are protected, not private, for ONNX exporting)
   Ptr<Tensors> tensors_;
 
 private:
-  std::unordered_map<size_t, std::vector<Expr>> memoized_;
-
   Type defaultElementType_{Type::float32}; // Type used for storing parameters, currently all parameters have to have the same type
   bool inferenceOnly_{false};              // a flag holds whether the graph is used for inference only
```

```diff
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 6c7ef91c..baec94df 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -519,6 +519,10 @@ Expr bdot(Expr a, Expr b, bool transA, bool transB, float scale) {
   return Expression<DotBatchedNodeOp>(a, b, transA, transB, scale);
 }
 
+Expr bdot_legacy(Expr a, Expr b, bool transA, bool transB, float scale) {
+  return Expression<DotBatchedLegacyNodeOp>(a, b, transA, transB, scale);
+}
+
 Expr affineDefault(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
   // general version, MKL, CBlas or CUDA
```

```diff
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index f3d84eb6..c1570eff 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -478,6 +478,12 @@ Expr bdot(Expr a,
          bool transB = false,
          float scalar = 1.f);
 
+Expr bdot_legacy(Expr a,
+                 Expr b,
+                 bool transA = false,
+                 bool transB = false,
+                 float scalar = 1.f);
+
 /**
  * Performs an affine transformation.
  * Computes
```
```diff
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 91fc29da..169b1420 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -529,11 +529,27 @@ public:
       shapeB.set(-1, b->shape()[-2]);
     }
 
-    Shape outShape = shapeA;
-    outShape.set(-1, shapeB[-1]);
     ABORT_IF(shapeA[-1] != shapeB[-2],
-             "Batched matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
-    return outShape;
+             "Batched matrix product requires inner dimensions to match in {}{} * {}{}",
+             std::string(shapeA), transA, std::string(shapeB), transB);
+
+    // create shapes for batch dimensions only
+    auto shapeBatchA = shapeA;
+    shapeBatchA.set(-1, 1);
+    shapeBatchA.set(-2, 1);
+
+    auto shapeBatchB = shapeB;
+    shapeBatchB.set(-1, 1);
+    shapeBatchB.set(-2, 1);
+
+    // broadcast batch dimensions
+    auto shapeOut = Shape::broadcast({shapeBatchA, shapeBatchB});
+
+    // set non-batch dimensions in output
+    shapeOut.set(-2, shapeA[-2]);
+    shapeOut.set(-1, shapeB[-1]);
+
+    return shapeOut;
   }
 
   NodeOps forwardOps() override {
@@ -655,6 +671,156 @@ public:
   const std::string color() override { return "orange"; }
 };
 
+class DotBatchedLegacyNodeOp : public NaryNodeOp {
+private:
+  friend class SerializationHelpers;
+  bool transA_;
+  bool transB_;
+  float scalar_;
+
+public:
+  DotBatchedLegacyNodeOp(Expr a, Expr b, bool transA, bool transB, float scalar)
+      : NaryNodeOp({a, b}, newShape(a, b, transA, transB)),
+        transA_(transA),
+        transB_(transB),
+        scalar_(scalar) {}
+
+  Shape newShape(Expr a, Expr b, bool transA, bool transB) {
+    auto shapeA = a->shape();
+    if(transA) {
+      shapeA.set(-2, a->shape()[-1]);
+      shapeA.set(-1, a->shape()[-2]);
+    }
+
+    auto shapeB = b->shape();
+    if(transB) {
+      shapeB.set(-2, b->shape()[-1]);
+      shapeB.set(-1, b->shape()[-2]);
+    }
+
+    Shape outShape = shapeA;
+    outShape.set(-1, shapeB[-1]);
+    ABORT_IF(shapeA[-1] != shapeB[-2],
+             "Batched matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
+    return outShape;
+  }
+
+  NodeOps forwardOps() override {
+    // C = alpha * dot(op(A), op(B))
+    return {NodeOp(ProdBatchedLegacy(val_,
+                                     graph()->allocator(),
+                                     child(0)->val(),
+                                     child(1)->val(),
+                                     transA_,
+                                     transB_,
+                                     0.f,
+                                     scalar_))};
+  }
+
+  NodeOps backwardOps() override {
+    // D is the adjoint, the matrix of derivatives
+    // df/dA += alpha * dot(D, op(B).T)
+    // df/dB += alpha * dot(op(A).T, D)
+    // beta set to 1.0 in gemm, C = alpha * dot(op(A), op(B)) + beta * C
+    // to sum gradients from different graph parts
+
+    if(!transA_ && transB_)
+      return {NodeOp(ProdBatchedLegacy(child(0)->grad(),
+                                       graph()->allocator(),
+                                       adj_,
+                                       child(1)->val(),
+                                       false,
+                                       false,
+                                       1.0,
+                                       scalar_)),
+              NodeOp(ProdBatchedLegacy(child(1)->grad(),
+                                       graph()->allocator(),
+                                       adj_,
+                                       child(0)->val(),
+                                       true,
+                                       false,
+                                       1.0,
+                                       scalar_))};
+    if(transA_ && !transB_)
+      return {NodeOp(ProdBatchedLegacy(child(0)->grad(),
+                                       graph()->allocator(),
+                                       child(1)->val(),
+                                       adj_,
+                                       false,
+                                       true,
+                                       1.0,
+                                       scalar_)),
+              NodeOp(ProdBatchedLegacy(child(1)->grad(),
+                                       graph()->allocator(),
+                                       child(0)->val(),
+                                       adj_,
+                                       false,
+                                       false,
+                                       1.0,
+                                       scalar_))};
+    if(transA_ && transB_)
+      return {NodeOp(ProdBatchedLegacy(child(0)->grad(),
+                                       graph()->allocator(),
+                                       child(1)->val(),
+                                       adj_,
+                                       true,
+                                       true,
+                                       1.0,
+                                       scalar_)),
+              NodeOp(ProdBatchedLegacy(child(1)->grad(),
+                                       graph()->allocator(),
+                                       adj_,
+                                       child(0)->val(),
+                                       true,
+                                       true,
+                                       1.0,
+                                       scalar_))};
+    return {NodeOp(ProdBatchedLegacy(child(0)->grad(),
+                                     graph()->allocator(),
+                                     adj_,
+                                     child(1)->val(),
+                                     false,
+                                     true,
+                                     1.0,
+                                     scalar_)),
+            NodeOp(ProdBatchedLegacy(child(1)->grad(),
+                                     graph()->allocator(),
+                                     child(0)->val(),
+                                     adj_,
+                                     true,
+                                     false,
+                                     1.0,
+                                     scalar_))};
+  }
+
+  const std::string type() override { return "bdot_legacy"; }
+
+  virtual size_t hash() override {
+    size_t seed = NaryNodeOp::hash();
+    util::hash_combine(seed, transA_);
+    util::hash_combine(seed, transB_);
+    util::hash_combine(seed, scalar_);
+    return seed;
+  }
+
+  virtual bool equal(Expr node) override {
+    if(!NaryNodeOp::equal(node))
+      return false;
+    auto cnode = std::dynamic_pointer_cast<DotBatchedLegacyNodeOp>(node);
+    if(!cnode)
+      return false;
+    if(transA_ != cnode->transA_)
+      return false;
+    if(transB_ != cnode->transB_)
+      return false;
+    if(scalar_ != cnode->scalar_)
+      return false;
+    return true;
+  }
+
+  const std::string color() override { return "orange"; }
+};
+
 // Note: To reduce code duplication, we use the same NodeOp for C = op(S) x D and C = D x op(S).
 // Set swapOperands to select the latter.
 class CSRDotNodeOp : public NaryNodeOp {
```
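The rewritten `newShape` above is what gives the new `bdot` element-wise-style broadcasting: the trailing two axes are the matrix dimensions, and everything in front is broadcast as in element-wise kernels, while `DotBatchedLegacyNodeOp` preserves the old flat-batch behavior. A standalone sketch of the batch-broadcast rule (a hypothetical helper, not marian's `Shape` class):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Broadcast two batch shapes (all axes except the trailing matrix axes),
// numpy-style: each pair of dims must match or one of them must be 1.
std::vector<int> broadcastBatch(const std::vector<int>& a, const std::vector<int>& b) {
  assert(a.size() == b.size());
  std::vector<int> out(a.size());
  for(size_t i = 0; i < a.size(); ++i) {
    assert(a[i] == b[i] || a[i] == 1 || b[i] == 1);
    out[i] = std::max(a[i], b[i]);
  }
  return out;
}

int main() {
  // batch dims of A = {2, 1} and B = {1, 3} broadcast to {2, 3}, matching the
  // {2,1,2,2} x {1,3,2,2} -> {2,3,2,2} case exercised by the unit test below.
  assert(broadcastBatch({2, 1}, {1, 3}) == (std::vector<int>{2, 3}));
  return 0;
}
```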
```diff
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 1da02318..a792de8b 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -249,7 +249,7 @@ public:
     // multiplicative attention with flattened softmax
     float scale = 1.0f / std::sqrt((float)dk); // scaling to avoid extreme values due to matrix multiplication
-    auto z = bdot(q, k, false, true, scale); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: max src length]
+    auto z = bdot_legacy(q, k, false, true, scale); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: max src length]
 
     // mask out garbage beyond end of sequences
     z = z + mask;
@@ -264,7 +264,7 @@ public:
     weights = dropout(weights, inference_ ? 0 : opt<float>("transformer-dropout-attention"));
 
     // apply attention weights to values
-    auto output = bdot(weights, v); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: split vector dim]
+    auto output = bdot_legacy(weights, v); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: split vector dim]
 
     return output;
   }
```

```diff
diff --git a/src/tensors/allocator.h b/src/tensors/allocator.h
index 9dc44f58..1844be14 100644
--- a/src/tensors/allocator.h
+++ b/src/tensors/allocator.h
@@ -175,8 +175,9 @@ public:
     reserve(bytes);
   }
 
-  size_t alignedSize(size_t size) {
-    return (size_t)(ceil(size / (double)alignment_) * alignment_);
+  size_t alignedSize(size_t size) const {
+    size_t over = size + alignment_ - 1;
+    return over - (over % alignment_);
   }
 
   void throwAtReallocation(bool throwRealloc) { throw_ = throwRealloc; }
```
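This `alignedSize` fix (and the identical `align` fix in `device.h` further down) is the "exact sizing" entry from the changelog: rounding up via `ceil` on a `float` or `double` silently loses precision once `size` exceeds the mantissa range (2^53 for `double`, only 2^24 for the `float` variant in `device.h`), which can return a too-small aligned size for large allocations. A minimal sketch contrasting the two formulas:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// old: round up via floating point -- breaks once size exceeds what the
// mantissa can represent exactly (2^24 for float)
size_t alignedSizeFloat(size_t size, size_t alignment) {
  return (size_t)(std::ceil(size / (float)alignment) * alignment);
}

// new: exact integer round-up to the next multiple of alignment
size_t alignedSizeExact(size_t size, size_t alignment) {
  size_t over = size + alignment - 1;
  return over - (over % alignment);
}

int main() {
  assert(alignedSizeExact(1, 256) == 256);
  assert(alignedSizeExact(256, 256) == 256);
  assert(alignedSizeExact(257, 256) == 512);
  // the float version rounds (1 << 24) + 1 down to 1 << 24 -- an under-allocation
  size_t big = (1ull << 24) + 1;
  assert(alignedSizeExact(big, 1) == big);
  assert(alignedSizeFloat(big, 1) != big); // precision loss
  return 0;
}
```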
```diff
diff --git a/src/tensors/cpu/prod.cpp b/src/tensors/cpu/prod.cpp
index 6e28bdd2..07cc2b99 100755
--- a/src/tensors/cpu/prod.cpp
+++ b/src/tensors/cpu/prod.cpp
@@ -93,6 +93,162 @@ void ProdBatched(marian::Tensor C,
 #if BLAS_FOUND
   float alpha = scalar;
 
+  // determine meta-shape of bdot operation. Essentially treat the last two dimensions as single elements
+  // such that (..., m, k) x (..., k, n) -> (..., m, n) where ... is a broadcastable shape as in element-wise kernels.
+
+  auto aShape = A->shape();
+  auto bShape = B->shape();
+
+  // make sure both shapes have the same number of dimensions via broadcasting
+  size_t maxLength = std::max(aShape.size(), bShape.size());
+  if(aShape.size() != bShape.size()) {
+    Shape ones(std::vector<int>(maxLength, 1));
+    aShape = Shape::broadcast({aShape, ones});
+    bShape = Shape::broadcast({bShape, ones});
+  }
+
+  // create meta-shapes without the last 2 dimensions
+  Shape aShapeMeta, bShapeMeta, cShapeMeta;
+  aShapeMeta.resize(maxLength - 2);
+  bShapeMeta.resize(maxLength - 2);
+  for(size_t i = 0; i < maxLength - 2; ++i) {
+    aShapeMeta.set(i, aShape[i]);
+    bShapeMeta.set(i, bShape[i]);
+  }
+  cShapeMeta = Shape::broadcast({aShapeMeta, bShapeMeta});
+
+  size_t m = aShape[-2];
+  size_t k = aShape[-1];
+  if(transA)
+    std::swap(m, k);
+
+  size_t l = bShape[-2];
+  size_t n = bShape[-1];
+  if(transB)
+    std::swap(l, n);
+
+  size_t lda = aShape[-1];
+  size_t ldb = bShape[-1];
+  size_t ldc = bShape[-1];
+
+  if(transB)
+    ldc = bShape[-2];
+
+  auto strideA = m * k;
+  auto strideB = n * k;
+  auto strideC = n * m;
+
+  auto batchC = cShapeMeta.elements();
+
+  // Convert to functional shapes to be able to map dimensions. @TODO merge this
+  functional::Shape aShapeMetaF = aShapeMeta;
+  functional::Shape bShapeMetaF = bShapeMeta;
+  functional::Shape cShapeMetaF = cShapeMeta;
+
+#if MKL_FOUND
+  CBLAS_TRANSPOSE transA_forarr = CblasNoTrans;
+  CBLAS_TRANSPOSE transB_forarr = CblasNoTrans;
+
+  if(transA)
+    transA_forarr = CblasTrans;
+
+  if(transB)
+    transB_forarr = CblasTrans;
+
+  /* cblas_sgemm_batch allows us to group all the small GEMMs that are done in a for loop with sgemm and compute
+   * them in only one MKL call. For the API documentation refer to
+   * https://software.intel.com/content/www/us/en/develop/documentation/mkl-developer-reference-c/top/blas-and-sparse-blas-routines/blas-like-extensions/cblas-gemm-batch.html
+   * The API supports dependencies, where you can specify one "group" of GEMMs to be computed after another
+   * (this is controlled by the group_count parameter). In our case, the operations are not dependent on one
+   * another, so we hardcode one group. The rest of the arguments (with the exception of group_size) are
+   * the same as the ones that cblas_sgemm expects, with the difference that we are supposed to provide an
+   * array pointer (one element per group). Weirdly enough, we are required to provide all of the integer
+   * arguments as the MKL_INT datatype.
+   */
+
+  static const constexpr size_t group_count = 1; // We have one group
+  const std::vector<CBLAS_TRANSPOSE> transa_arr(group_count, transA_forarr);
+  const std::vector<CBLAS_TRANSPOSE> transb_arr(group_count, transB_forarr);
+  const std::vector<MKL_INT> m_arr(group_count, (MKL_INT)m);
+  const std::vector<MKL_INT> n_arr(group_count, (MKL_INT)n);
+  const std::vector<MKL_INT> k_arr(group_count, (MKL_INT)k);
+  const std::vector<float> alpha_arr(group_count, alpha);
+  const std::vector<float> beta_arr(group_count, beta);
+  const std::vector<MKL_INT> lda_arr(group_count, (MKL_INT)lda);
+  const std::vector<MKL_INT> ldb_arr(group_count, (MKL_INT)ldb);
+  const std::vector<MKL_INT> ldc_arr(group_count, (MKL_INT)ldc);
+  const std::vector<MKL_INT> group_size(group_count, (MKL_INT)batchC); // Group size specifies the number of GEMM operations per group (which is batchC)
+
+  std::vector<const float *> a_array(batchC, nullptr);
+  std::vector<const float *> b_array(batchC, nullptr);
+  std::vector<float *> c_array(batchC, nullptr);
+
+  // This loop initializes the array pointers in the same way as the for loop
+  // in the normal sgemm version a few lines below
+  functional::Array<int, functional::Shape::size()> dims;
+  for(size_t i = 0; i < batchC; ++i) {
+    cShapeMetaF.dims(i, dims);
+    auto aIndex = aShapeMetaF.bindex(dims);
+    auto bIndex = bShapeMetaF.bindex(dims);
+
+    a_array[i] = A->data() + aIndex * strideA;
+    b_array[i] = B->data() + bIndex * strideB;
+    c_array[i] = C->data() + i * strideC;
+  }
+  cblas_sgemm_batch(CblasRowMajor,
+                    &transa_arr[0],
+                    &transb_arr[0],
+                    &m_arr[0],
+                    &n_arr[0],
+                    &k_arr[0],
+                    &alpha_arr[0],
+                    &a_array[0],
+                    &lda_arr[0],
+                    &b_array[0],
+                    &ldb_arr[0],
+                    &beta_arr[0],
+                    &c_array[0],
+                    &ldc_arr[0],
+                    group_count,
+                    &group_size[0]);
+#else
+  functional::Array<int, functional::Shape::size()> dims;
+  for(size_t i = 0; i < batchC; ++i) {
+    cShapeMetaF.dims(i, dims);
+    auto aIndex = aShapeMetaF.bindex(dims);
+    auto bIndex = bShapeMetaF.bindex(dims);
+
+    sgemm(transA,
+          transB,
+          (int)m,
+          (int)n,
+          (int)k,
+          alpha,
+          A->data() + aIndex * strideA,
+          (int)lda,
+          B->data() + bIndex * strideB,
+          (int)ldb,
+          beta,
+          C->data() + i * strideC,
+          (int)ldc);
+  }
+#endif
+#else
+  C; A; B; transA; transB; beta; scalar;
+  ABORT("You need to compile with MKL in order to use the CPU version");
+#endif
+}
+
+void ProdBatchedLegacy(marian::Tensor C,
+                       Ptr<Allocator> /*allocator*/,
+                       const marian::Tensor A,
+                       const marian::Tensor B,
+                       bool transA,
+                       bool transB,
+                       float beta,
+                       float scalar) {
+#if BLAS_FOUND
+  float alpha = scalar;
+
   size_t batchA = A->shape().elements() / (A->shape()[-1] * A->shape()[-2]);
   size_t batchB = B->shape().elements() / (B->shape()[-1] * B->shape()[-2]);
```
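In both the MKL and fallback branches above, each flat output batch index is mapped back to the (possibly broadcast) input batch indices with `bindex`, so an input whose batch dimension is 1 keeps reusing the same underlying matrix. A rough standalone illustration of that mapping (a hypothetical simplified helper, not marian's `functional::Shape::bindex`):

```cpp
#include <cassert>
#include <vector>

// Map a flat index into the broadcast output batch shape back to a flat
// index into an input batch shape, treating size-1 input dims as "stuck at 0".
int bindex(int flatOut, const std::vector<int>& outShape, const std::vector<int>& inShape) {
  int in = 0;
  for(int i = (int)outShape.size() - 1, stride = 1; i >= 0; --i) {
    int dim = flatOut % outShape[i];
    flatOut /= outShape[i];
    if(inShape[i] != 1)  // broadcast dims contribute nothing
      in += dim * stride;
    stride *= inShape[i];
  }
  return in;
}

int main() {
  // output batch {2,3}, A batch {2,1}, B batch {1,3} as in the bdot test:
  // output matrix (1,2) reads A matrix 1 and B matrix 2.
  std::vector<int> out{2, 3}, a{2, 1}, b{1, 3};
  int flat = 1 * 3 + 2; // flat index of batch position (1,2)
  assert(bindex(flat, out, a) == 1);
  assert(bindex(flat, out, b) == 2);
  return 0;
}
```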
```diff
diff --git a/src/tensors/device.h b/src/tensors/device.h
index 0be6c076..5fe3c1fb 100644
--- a/src/tensors/device.h
+++ b/src/tensors/device.h
@@ -15,8 +15,9 @@ protected:
   size_t size_{0};
   size_t alignment_;
 
-  size_t align(size_t size) {
-    return (size_t)(ceil(size / (float)alignment_) * alignment_);
+  size_t align(size_t size) const {
+    size_t over = size + alignment_ - 1;
+    return over - (over % alignment_);
   }
 
 public:
```
```diff
diff --git a/src/tensors/gpu/prod.cpp b/src/tensors/gpu/prod.cpp
index 8cfa78ca..e996f58f 100755
--- a/src/tensors/gpu/prod.cpp
+++ b/src/tensors/gpu/prod.cpp
@@ -22,7 +22,7 @@ namespace gpu {
 // It seems that the bias must be 8 byte aligned for the cublasLt epilogue to work. Therefore,
 // if the bias pointer is not 8 byte aligned, we do a normal matmul in cublasLt and invoke a
 // custom epilogue kernel.
-static constexpr int REQUIRED_BIAS_ALIGNMENT = 8;
+static constexpr int REQUIRED_BIAS_ALIGNMENT = 16; // @TODO: MJD: changed this to 16 to avoid alignment error on A100. Seems to work fine.
 
 // Used to set preferences for cublasLt to filter out algos if matrices to not meet default 256 byte alignment
 int getAlignmentUpTo256(const void *ptr) {
@@ -347,6 +347,139 @@ void ProdBatchedTyped(marian::Tensor C,
   CUDA_CHECK(cudaSetDevice((int)C->getDeviceId().no));
   ComputeType alpha = scalar;
 
+  // determine meta-shape of bdot operation. Essentially treat the last two dimensions as single elements
+  // such that (..., m, k) x (..., k, n) -> (..., m, n) where ... is a broadcastable shape as in element-wise kernels.
+
+  auto aShape = A->shape();
+  auto bShape = B->shape();
+
+  // make sure both shapes have the same number of dimensions via broadcasting
+  size_t maxLength = std::max(aShape.size(), bShape.size());
+  if(aShape.size() != bShape.size()) {
+    Shape ones(std::vector<int>(maxLength, 1));
+    aShape = Shape::broadcast({aShape, ones});
+    bShape = Shape::broadcast({bShape, ones});
+  }
+
+  // create meta-shapes without the last 2 dimensions
+  Shape aShapeMeta, bShapeMeta, cShapeMeta;
+  aShapeMeta.resize(maxLength - 2);
+  bShapeMeta.resize(maxLength - 2);
+  for(size_t i = 0; i < maxLength - 2; ++i) {
+    aShapeMeta.set(i, aShape[i]);
+    bShapeMeta.set(i, bShape[i]);
+  }
+  cShapeMeta = Shape::broadcast({aShapeMeta, bShapeMeta});
+
+  size_t m = aShape[-2];
+  size_t k = aShape[-1];
+  if(transA)
+    std::swap(m, k);
+
+  size_t l = bShape[-2];
+  size_t n = bShape[-1];
+  if(transB)
+    std::swap(l, n);
+
+  size_t lda = aShape[-1];
+  size_t ldb = bShape[-1];
+  size_t ldc = bShape[-1];
+
+  if(transB)
+    ldc = bShape[-2];
+
+  cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+  auto backend = std::static_pointer_cast<gpu::Backend>(C->getBackend());
+  auto cublasHandle = backend->getCublasHandle();
+  auto compute = backend->getCudaComputeCapability();
+
+  auto strideA = m * k;
+  auto strideB = n * k;
+  auto strideC = n * m;
+
+  auto batchC = cShapeMeta.elements();
+
+  // Convert to functional shapes to be able to map dimensions. @TODO merge this
+  functional::Shape aShapeMetaF = aShapeMeta;
+  functional::Shape bShapeMetaF = bShapeMeta;
+  functional::Shape cShapeMetaF = cShapeMeta;
+
+  std::vector<const ElementType*> aptr;
+  std::vector<const ElementType*> bptr;
+  std::vector<ElementType*> cptr;
+
+  functional::Array<int, functional::Shape::size()> dims;
+  for(int i = 0; i < batchC; i++) {
+    cShapeMetaF.dims(i, dims);
+    auto aIndex = aShapeMetaF.bindex(dims);
+    auto bIndex = bShapeMetaF.bindex(dims);
+
+    aptr.push_back(A->data<ElementType>() + aIndex * strideA);
+    bptr.push_back(B->data<ElementType>() + bIndex * strideB);
+    cptr.push_back(C->data<ElementType>() + i * strideC);
+  }
+
+  // auto fails here for some odd reason
+  IPtr<MemoryPiece> mp_aptr = allocator->alloc<const ElementType*>(aptr.size());
+  CudaCopy(aptr.data(), aptr.data() + aptr.size(), mp_aptr->data<const ElementType*>());
+
+  IPtr<MemoryPiece> mp_bptr = allocator->alloc<const ElementType*>(bptr.size());
+  CudaCopy(bptr.data(), bptr.data() + bptr.size(), mp_bptr->data<const ElementType*>());
+
+  IPtr<MemoryPiece> mp_cptr = allocator->alloc<ElementType*>(cptr.size());
+  CudaCopy(cptr.data(), cptr.data() + cptr.size(), mp_cptr->data<ElementType*>());
+
+  setTensorMode(cublasHandle);
+  TypedGemm<ElementType, ComputeType>::batchedGemm(cublasHandle, compute,
+                                                   opB, opA,
+                                                   n, m, k,
+                                                   &alpha,
+                                                   mp_bptr->data<const ElementType*>(), ldb,
+                                                   mp_aptr->data<const ElementType*>(), lda,
+                                                   &beta,
+                                                   mp_cptr->data<ElementType*>(), ldc,
+                                                   batchC);
+  unsetTensorMode(cublasHandle);
+
+  allocator->free(mp_aptr);
+  allocator->free(mp_bptr);
+  allocator->free(mp_cptr);
+}
+
+// @TODO: add version with compute type for completeness
+void ProdBatched(marian::Tensor C,
+                 Ptr<Allocator> allocator,
+                 const marian::Tensor A,
+                 const marian::Tensor B,
+                 bool transA,
+                 bool transB,
+                 float beta,
+                 float scalar) {
+  if(C->type() == Type::float32) {
+    ProdBatchedTyped<float, float>(C, allocator, A, B, transA, transB, beta, scalar);
+#if COMPILE_FP16
+  } else if(C->type() == Type::float16) { // not a *.cu file
+    ProdBatchedTyped<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar));
+#endif
+  } else {
+    ABORT("ProdBatched not implemented for element type {}", C->type());
+  }
+}
+
+template <typename ElementType, typename ComputeType>
+void ProdBatchedTypedLegacy(marian::Tensor C,
+                            Ptr<Allocator> allocator,
+                            const marian::Tensor A,
+                            const marian::Tensor B,
+                            bool transA,
+                            bool transB,
+                            ComputeType beta,
+                            ComputeType scalar) {
+  CUDA_CHECK(cudaSetDevice((int)C->getDeviceId().no));
+  ComputeType alpha = scalar;
+
   int batchA = A->shape().elements() / (A->shape()[-1] * A->shape()[-2]);
   int batchB = B->shape().elements() / (B->shape()[-1] * B->shape()[-2]);
@@ -417,25 +550,26 @@ void ProdBatchedTyped(marian::Tensor C,
 }
 
 // @TODO: add version with compute type for completeness
-void ProdBatched(marian::Tensor C,
-                 Ptr<Allocator> allocator,
-                 const marian::Tensor A,
-                 const marian::Tensor B,
-                 bool transA,
-                 bool transB,
-                 float beta,
-                 float scalar) {
+void ProdBatchedLegacy(marian::Tensor C,
+                       Ptr<Allocator> allocator,
+                       const marian::Tensor A,
+                       const marian::Tensor B,
+                       bool transA,
+                       bool transB,
+                       float beta,
+                       float scalar) {
   if(C->type() == Type::float32) {
-    ProdBatchedTyped<float, float>(C, allocator, A, B, transA, transB, beta, scalar);
+    ProdBatchedTypedLegacy<float, float>(C, allocator, A, B, transA, transB, beta, scalar);
 #if COMPILE_FP16
   } else if(C->type() == Type::float16) { // not a *.cu file
-    ProdBatchedTyped<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar));
+    ProdBatchedTypedLegacy<half, half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar));
 #endif
   } else {
-    ABORT("ProdBatched not implemented for element type {}", C->type());
+    ABORT("ProdBatchedLegacy not implemented for element type {}", C->type());
   }
 }
+
 #if CUDA_VERSION >= 11000 // Earlier versions of cublasLT do not support bias addition for fp32 and fp16.
 static cublasStatus_t cublasLtAffineHelper(cublasLtHandle_t ltHandle, cublasOperation_t transA, cublasOperation_t transB,
```
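Two details worth noting in the GPU path: `batchedGemm` receives the operands as `opB, opA` with `n, m, k` because the row-major product C = A·B can be computed as the column-major product Cᵀ = Bᵀ·Aᵀ that cuBLAS expects, and the `REQUIRED_BIAS_ALIGNMENT` bump from 8 to 16 makes the fused cublasLt bias epilogue require 16-byte-aligned bias pointers. A trivial sketch of the kind of pointer check such a constant gates, per the comment at the top of the file (the check itself is elsewhere in the file and not shown by this diff):

```cpp
#include <cstdint>

static constexpr int REQUIRED_BIAS_ALIGNMENT = 16;

// True if the bias pointer is aligned enough for the fused cublasLt epilogue;
// otherwise the code falls back to a plain matmul plus a custom bias kernel.
inline bool biasIsAligned(const void* bias) {
  return reinterpret_cast<std::uintptr_t>(bias) % REQUIRED_BIAS_ALIGNMENT == 0;
}
```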
```diff
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h
index ef485068..6e587953 100644
--- a/src/tensors/tensor_operators.h
+++ b/src/tensors/tensor_operators.h
@@ -104,6 +104,7 @@ DISPATCH7(Prod, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bo
 DISPATCH8(Prod, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float, Type) // overloading since we want the default to for computeType be C->type() which difficult otherwise.
 DISPATCH8(ProdBatched, marian::Tensor, Ptr<Allocator>, const marian::Tensor, const marian::Tensor, bool, bool, float, float)
+DISPATCH8(ProdBatchedLegacy, marian::Tensor, Ptr<Allocator>, const marian::Tensor, const marian::Tensor, bool, bool, float, float)
 DISPATCH9(CSRProd, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float)
 
 DISPATCH10(Affine, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float, bool)
```

```diff
diff --git a/src/tests/units/operator_tests.cpp b/src/tests/units/operator_tests.cpp
index 1a18da99..f3b5fda3 100644
--- a/src/tests/units/operator_tests.cpp
+++ b/src/tests/units/operator_tests.cpp
@@ -615,6 +615,66 @@ void tests(DeviceType device, Type floatType = Type::float32) {
     CHECK(values2 == values);
   }
 
+  SECTION("bdot") {
+    graph->clear();
+    values.clear();
+
+    std::vector<T> vA({ 1,  2,
+                        3,  4,
+                        5,  6,
+                        7,  8});
+
+    std::vector<T> vB({ 1,  2,
+                        3,  4,
+                        5,  6,
+                        7,  8,
+                        9, 10,
+                       11, 12});
+
+    std::vector<T> vC({  7,  10,
+                        15,  22,
+                        19,  22,
+                        43,  50,
+                        31,  34,
+                        71,  78,
+                        23,  34,
+                        31,  46,
+                        67,  78,
+                        91, 106,
+                       111, 122,
+                       151, 166});
+
+    std::vector<T> vCt({  5,  11,
+                         11,  25,
+                         17,  23,
+                         39,  53,
+                         29,  35,
+                         67,  81,
+                         17,  39,
+                         23,  53,
+                         61,  83,
+                         83, 113,
+                        105, 127,
+                        143, 173});
+
+    auto A = graph->param("A", {2, 1, 2, 2}, inits::fromVector(vA));
+    auto B = graph->param("B", {1, 3, 2, 2}, inits::fromVector(vB));
+
+    auto C  = bdot(A, B, /*transA=*/false, /*transB=*/false);
+    auto Ct = bdot(A, B, /*transA=*/false, /*transB=*/true);
+
+    graph->forward();
+
+    CHECK(C->shape()  == Shape({2, 3, 2, 2}));
+    CHECK(Ct->shape() == Shape({2, 3, 2, 2}));
+
+    C->val()->get(values);
+    CHECK(vC == values);
+
+    Ct->val()->get(values);
+    CHECK(vCt == values);
+  }
+
   SECTION("repeat") {
     graph->clear();
     values.clear();
```

```diff
diff --git a/src/training/scheduler.h b/src/training/scheduler.h
index 8d4fa30c..3cc3b207 100644
--- a/src/training/scheduler.h
+++ b/src/training/scheduler.h
@@ -511,7 +511,8 @@ public:
       state_->stalled = 0;
       state_->maxStalled = 0;
       for(const auto& validator : validators_) {
-        state_->validators[validator->type()]["stalled"] = 0;
+        if(state_->validators[validator->type()])
+          state_->validators[validator->type()]["stalled"] = 0;
       }
     }
```
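As a quick sanity check on the expected values in the new `bdot` test: with batch shapes {2,1} and {1,3} broadcast to {2,3}, output block (0,0) is simply the 2x2 product of A's first matrix with B's first. A tiny standalone verification of that first block of `vC`:

```cpp
#include <cassert>

int main() {
  // A0 = [1 2; 3 4], B0 = [1 2; 3 4] -- the first matrices of the test vectors
  int a[2][2] = {{1, 2}, {3, 4}};
  int b[2][2] = {{1, 2}, {3, 4}};
  int c[2][2] = {};
  for(int i = 0; i < 2; ++i)
    for(int j = 0; j < 2; ++j)
      for(int k = 0; k < 2; ++k)
        c[i][j] += a[i][k] * b[k][j];
  // matches the first block {7, 10, 15, 22} of vC in the unit test above
  assert(c[0][0] == 7 && c[0][1] == 10 && c[1][0] == 15 && c[1][1] == 22);
  return 0;
}
```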