github.com/marian-nmt/marian.git
author    rhenry-nv <72179960+rhenry-nv@users.noreply.github.com>  2021-04-09 07:46:27 +0300
committer GitHub <noreply@github.com>                              2021-04-09 07:46:27 +0300
commit    fddd0e0661eb13b2a132e401e268315a35e468f7 (patch)
tree      279578bcb7a4dd3cddfeb7cb4a7fa193574b5aaf /src
parent    0223ce90b1d6afaa047f2baeb4d0689f87d7ae81 (diff)
Adds better Affine support for GPUs when using CUDA 11. Introduces a new bias addition kernel for CUDA < 11 (#778)
Co-authored-by: Marcin Junczys-Dowmunt <marcinjd@microsoft.com>
Diffstat (limited to 'src')
-rw-r--r--  src/CMakeLists.txt                 |   1
-rw-r--r--  src/graph/expression_operators.cpp |  12
-rw-r--r--  src/graph/expression_operators.h   |  12
-rw-r--r--  src/graph/node_operators_binary.h  | 124
-rw-r--r--  src/layers/generic.h               |  25
-rw-r--r--  src/layers/output.cpp              |   2
-rw-r--r--  src/models/transformer.h           |  27
-rwxr-xr-x  src/tensors/cpu/prod.cpp           |  17
-rw-r--r--  src/tensors/dispatch.h             |  42
-rwxr-xr-x  src/tensors/gpu/prod.cpp           | 217
-rw-r--r--  src/tensors/gpu/prod.cu            |  69
-rw-r--r--  src/tensors/gpu/prod.h             |  15
-rw-r--r--  src/tensors/tensor_operators.h     |   2
-rw-r--r--  src/tests/units/operator_tests.cpp |  21
14 files changed, 536 insertions, 50 deletions
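
Note: the headline change is a fused GEMM + bias + ReLU path for GPU inference, exposed as a new graph operator affineWithRelu(). The sketch below is not part of the patch; it only illustrates how a caller would use the new entry point, relying on the fallback to relu(affine(...)) that the patch provides for training and for CPU.

#include "graph/expression_operators.h"  // declares affine() and the new affineWithRelu()

using namespace marian;

// Minimal usage sketch (hypothetical helper): a dense layer with ReLU activation.
// During GPU inference this becomes a single AffineWithReluNodeOp (cublasLt epilogue
// on CUDA >= 11, plain GEMM plus fused bias/ReLU kernel otherwise); elsewhere it is
// simply relu(affine(x, W, b)).
static Expr ffnReluLayer(Expr x, Expr W, Expr b) {
  return affineWithRelu(x, W, b, /*transA=*/false, /*transB=*/false, /*scalar=*/1.f);
}
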
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 64b86a69..cf276137 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -175,6 +175,7 @@ if(CUDA_FOUND)
tensors/gpu/device.cu
tensors/gpu/algorithm.cu
tensors/gpu/prod.cpp
+ tensors/gpu/prod.cu
tensors/gpu/prod_sparse.cpp
tensors/gpu/topk.cu
tensors/gpu/element.cu
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index f354caab..048c7478 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -1,4 +1,5 @@
#include "graph/expression_operators.h"
+#include "common/definitions.h"
#include "layers/constructors.h"
#include "graph/node_operators.h"
@@ -518,7 +519,7 @@ Expr bdot(Expr a, Expr b, bool transA, bool transB, float scale) {
return Expression<DotBatchedNodeOp>(a, b, transA, transB, scale);
}
-static Expr affineDefault(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
+Expr affineDefault(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
// general version, MKL, CBlas or CUDA
int rows = a->shape().elements() / a->shape()[-1];
@@ -577,6 +578,15 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
}
}
+Expr affineWithRelu(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
+ auto graph = a->graph();
+
+ if(graph->isInference() && graph->getDeviceId().type == DeviceType::gpu)
+ return Expression<AffineWithReluNodeOp>(a, b, bias, transA, transB, scale);
+ else
+ return relu(affine(a, b, bias, transA, transB, scale));
+}
+
// @TODO: Not a great place to check this
#if CUDA_VERSION < 11000
// multiply a CSR matrix A with a matrix B
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index ca0739e4..81b0f5ea 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -488,12 +488,22 @@ Expr bdot(Expr a,
*/
Expr affine(Expr a,
Expr b,
- Expr c,
+ Expr bias,
bool transA = false,
bool transB = false,
float scalar = 1.f);
/**
+ * As above, but efficiently applies relu transformation to output. For inference only.
+ */
+Expr affineWithRelu(Expr a,
+ Expr b,
+ Expr bias,
+ bool transA = false,
+ bool transB = false,
+ float scalar = 1.f);
+
+/**
* Computes the dot product of CSR-tensor @p A with @p B.
*/
Expr csr_dot(const Shape& A_shape, Expr Avalues, Expr Aindices, Expr Aoffsets, Expr B, bool transA = false);
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 261885ec..55f105a9 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -266,17 +266,18 @@ public:
NodeOps forwardOps() override {
using namespace functional;
-
+
return {
- NodeOp(
- Prod(val_,
- child(0)->val(),
- child(1)->val(),
- transA_,
- transB_,
- 0.f,
- scalar_);
- Prod(val_, child(3)->val(), child(2)->val(), false, false, 1.f, 1.f))
+ NodeOp(Affine(val_,
+ graph()->allocator(),
+ child(0)->val(),
+ child(1)->val(),
+ child(2)->val(),
+ transA_,
+ transB_,
+ 0.f,
+ scalar_,
+ /*doRelu=*/false))
};
}
@@ -323,8 +324,7 @@ public:
false,
1.0,
scalar_, computeTypeB)),
- NodeOp(Prod(
- child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
+ NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
};
if(transA_ && !transB_)
@@ -343,8 +343,7 @@ public:
false,
1.0,
scalar_, computeTypeB)),
- NodeOp(Prod(
- child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
+ NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
};
if(transA_ && transB_)
@@ -363,8 +362,7 @@ public:
true,
1.0,
scalar_, computeTypeB)),
- NodeOp(Prod(
- child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
+ NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
};
return {
@@ -382,8 +380,7 @@ public:
false,
1.0,
scalar_, computeTypeB)),
- NodeOp(Prod(
- child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
+ NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC))
};
}
@@ -414,6 +411,97 @@ public:
};
+class AffineWithReluNodeOp : public NaryNodeOp {
+private:
+ friend class SerializationHelpers;
+ bool transA_;
+ bool transB_;
+ float scalar_;
+
+public:
+ AffineWithReluNodeOp(Expr a,
+ Expr b,
+ Expr bias,
+ bool transA,
+ bool transB,
+ float scalar)
+ : NaryNodeOp({a, b, bias}, newShape(a, b, transA, transB)),
+ transA_(transA),
+ transB_(transB),
+ scalar_(scalar) {
+ ABORT_IF(!graph()->isInference() || graph()->getDeviceId().type != DeviceType::gpu,
+ "AffineWithReluNodeOp currently only supported for inference on GPU");
+ }
+
+ Shape newShape(Expr a, Expr b, bool transA, bool transB) {
+ auto shapeA = a->shape();
+ if(transA) {
+ shapeA.set(shapeA.size() - 2, a->shape()[shapeA.size() - 1]);
+ shapeA.set(shapeA.size() - 1, a->shape()[shapeA.size() - 2]);
+ }
+
+ auto shapeB = b->shape();
+ if(transB) {
+ shapeB.set(shapeB.size() - 2, b->shape()[shapeB.size() - 1]);
+ shapeB.set(shapeB.size() - 1, b->shape()[shapeB.size() - 2]);
+ }
+
+ Shape outShape = shapeA;
+ outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
+ ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
+ "Matrix product requires inner dimensions to match in {}{} * {}{}", std::string(shapeA), transA, std::string(shapeB), transB);
+ return outShape;
+ }
+
+ NodeOps forwardOps() override {
+ ABORT_IF(!graph()->isInference() || graph()->getDeviceId().type != DeviceType::gpu,
+ "AffineWithReluNodeOp currently only supported for inference on GPU");
+
+ return {
+ NodeOp(Affine(val_,
+ graph()->allocator(),
+ child(0)->val(),
+ child(1)->val(),
+ child(2)->val(),
+ transA_,
+ transB_,
+ 0.f,
+ scalar_,
+ /*doRelu=*/true))
+ };
+ }
+
+ NodeOps backwardOps() override {
+ ABORT("AffineWithReluNodeOp cannot be used for training??");
+ return {};
+ }
+
+ const std::string type() override { return "affineWithRelu"; }
+
+ virtual size_t hash() override {
+ size_t seed = NaryNodeOp::hash();
+ util::hash_combine(seed, transA_);
+ util::hash_combine(seed, transB_);
+ util::hash_combine(seed, scalar_);
+ return seed;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<AffineWithReluNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(transA_ != cnode->transA_)
+ return false;
+ if(transB_ != cnode->transB_)
+ return false;
+ if(scalar_ != cnode->scalar_)
+ return false;
+ return true;
+ }
+};
+
class DotBatchedNodeOp : public NaryNodeOp {
private:
friend class SerializationHelpers;
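
Note: the shape rule in newShape() above is the standard matmul one: a transposed operand has its two trailing dimensions swapped, the inner dimensions must agree, and the last output dimension is taken from B. A standalone illustration of the same rule, using plain std::vector<int> instead of marian::Shape (an assumption made for brevity):

#include <cassert>
#include <utility>
#include <vector>

// out = shape of (optionally transposed) A with its last dim replaced by the last dim of (optionally transposed) B
static std::vector<int> affineOutShape(std::vector<int> a, std::vector<int> b, bool transA, bool transB) {
  if(transA) std::swap(a[a.size() - 2], a[a.size() - 1]);
  if(transB) std::swap(b[b.size() - 2], b[b.size() - 1]);
  assert(a.back() == b[b.size() - 2] && "inner dimensions must match");
  std::vector<int> out = a;
  out.back() = b.back();
  return out;
}

int main() {
  auto s = affineOutShape({4, 3}, {3, 2}, false, false);  // {4, 3} x {3, 2} -> {4, 2}
  assert(s == (std::vector<int>{4, 2}));
  auto t = affineOutShape({4, 3}, {2, 3}, false, true);   // transB: {2, 3} acts as {3, 2} -> {4, 2}
  assert(t == (std::vector<int>{4, 2}));
  return 0;
}
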
diff --git a/src/layers/generic.h b/src/layers/generic.h
index 89f5c1e9..5eb93615 100644
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -1,5 +1,7 @@
#pragma once
+#include "common/definitions.h"
+#include "graph/expression_operators.h"
#include "marian.h"
#include "data/shortlist.h"
@@ -168,22 +170,37 @@ public:
// --- a few layers with built-in parameters created on the fly, without proper object
// @TODO: change to a proper layer object
+static inline std::function<Expr(Expr)> activationByName(const std::string& actName) {
+ if (actName == "relu")
+ return (ActivationFunction*)relu;
+ else if (actName == "swish")
+ return (ActivationFunction*)swish;
+ else if (actName == "gelu")
+ return (ActivationFunction*)gelu;
+ else if (actName == "") // return identity function if activation name is empty
+ return [](Expr x) { return x; };
+ ABORT("Invalid activation name '{}'", actName);
+}
+
// like affine() but with built-in parameters, activation, and dropout
static inline Expr denseInline(Expr x,
std::string prefix,
std::string suffix,
int outDim,
Ptr<inits::NodeInitializer> initFn = inits::glorotUniform(),
- const std::function<Expr(Expr)>& actFn = nullptr,
+ std::string actName = "",
float dropProb = 0.0f) {
auto graph = x->graph();
auto W = graph->param(prefix + "_W" + suffix, {x->shape()[-1], outDim}, inits::glorotUniform());
auto b = graph->param(prefix + "_b" + suffix, {1, outDim}, inits::zeros());
- x = affine(x, W, b);
- if(actFn)
- x = actFn(x);
+ if(actName == "relu") {
+ x = affineWithRelu(x, W, b); // speed optimization for inference, @TODO: handle better in future layer framework
+ } else {
+ x = affine(x, W, b);
+ x = activationByName(actName)(x);
+ }
x = dropout(x, dropProb); // @TODO: check for infernce?
return x;
}
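
Note: denseInline now takes the activation by name: "relu" is routed through the fused affineWithRelu(), anything else goes through affine() followed by activationByName(). A hypothetical call site, shown as a fragment only (the input x, prefix, dimensions and dropout values are made up for illustration):

// Fused path on GPU inference: one node for GEMM + bias + ReLU.
auto h1 = denseInline(x, /*prefix=*/"ffn", /*suffix=*/"1", /*outDim=*/2048,
                      inits::glorotUniform(), /*actName=*/"relu", /*dropProb=*/0.1f);
// Generic path: affine() followed by the named activation.
auto h2 = denseInline(h1, /*prefix=*/"ffn", /*suffix=*/"2", /*outDim=*/512,
                      inits::glorotUniform(), /*actName=*/"swish");
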
diff --git a/src/layers/output.cpp b/src/layers/output.cpp
index 1d9c7b4b..4c34bdce 100644
--- a/src/layers/output.cpp
+++ b/src/layers/output.cpp
@@ -170,7 +170,7 @@ Logits Output::applyAsLogits(Expr input) /*override final*/ {
/*suffix=*/"1",
ffnDim,
inits::glorotUniform(),
- (ActivationFunction*)relu,
+ "relu",
ffnDropProb);
f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim);
// add & norm
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 6368cc6a..79b59000 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -396,18 +396,6 @@ public:
opt<int>("transformer-heads"), /*cache=*/false);
}
- static inline
- std::function<Expr(Expr)> activationByName(const std::string& actName)
- {
- if (actName == "relu")
- return (ActivationFunction*)relu;
- else if (actName == "swish")
- return (ActivationFunction*)swish;
- else if (actName == "gelu")
- return (ActivationFunction*)gelu;
- ABORT("Invalid activation name '{}'", actName);
- }
-
Expr LayerFFN(std::string prefix, Expr input) const {
int dimModel = input->shape()[-1];
@@ -415,9 +403,9 @@ public:
auto opsPre = opt<std::string>("transformer-preprocess");
auto output = preProcess(prefix + "_ffn", opsPre, input, dropProb);
+ auto actName = opt<std::string>("transformer-ffn-activation");
int dimFfn = opt<int>("transformer-dim-ffn");
int depthFfn = opt<int>("transformer-ffn-depth");
- auto actFn = activationByName(opt<std::string>("transformer-ffn-activation"));
float ffnDropProb
= inference_ ? 0 : opt<float>("transformer-dropout-ffn");
@@ -427,12 +415,11 @@ public:
// the stack of FF layers
for(int i = 1; i < depthFfn; ++i)
- output = denseInline(output, prefix, /*suffix=*/std::to_string(i), dimFfn, initFn, actFn, ffnDropProb);
+ output = denseInline(output, prefix, /*suffix=*/std::to_string(i), dimFfn, initFn, actName, ffnDropProb);
output = denseInline(output, prefix, /*suffix=*/std::to_string(depthFfn), dimModel, initFn);
auto opsPost = opt<std::string>("transformer-postprocess");
- output
- = postProcess(prefix + "_ffn", opsPost, output, input, dropProb);
+ output = postProcess(prefix + "_ffn", opsPost, output, input, dropProb);
return output;
}
@@ -450,21 +437,21 @@ public:
// FFN
int dimAan = opt<int>("transformer-dim-aan");
int depthAan = opt<int>("transformer-aan-depth");
- auto actFn = activationByName(opt<std::string>("transformer-aan-activation"));
+ auto actName = opt<std::string>("transformer-aan-activation");
float aanDropProb = inference_ ? 0 : opt<float>("transformer-dropout-ffn");
auto initFn = inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f);
// the stack of AAN layers
for(int i = 1; i < depthAan; ++i)
- y = denseInline(y, prefix, /*suffix=*/std::to_string(i), dimAan, initFn, actFn, aanDropProb);
+ y = denseInline(y, prefix, /*suffix=*/std::to_string(i), dimAan, initFn, actName, aanDropProb);
if(y->shape()[-1] != dimModel) // bring it back to the desired dimension if needed
y = denseInline(y, prefix, std::to_string(depthAan), dimModel, initFn);
bool noGate = opt<bool>("transformer-aan-nogate");
if(!noGate) {
- auto gi = denseInline(x, prefix, /*suffix=*/"i", dimModel, initFn, (ActivationFunction*)sigmoid);
- auto gf = denseInline(y, prefix, /*suffix=*/"f", dimModel, initFn, (ActivationFunction*)sigmoid);
+ auto gi = denseInline(x, prefix, /*suffix=*/"i", dimModel, initFn, "sigmoid");
+ auto gf = denseInline(y, prefix, /*suffix=*/"f", dimModel, initFn, "sigmoid");
y = gi * x + gf * y;
}
diff --git a/src/tensors/cpu/prod.cpp b/src/tensors/cpu/prod.cpp
index f77337d6..6e28bdd2 100755
--- a/src/tensors/cpu/prod.cpp
+++ b/src/tensors/cpu/prod.cpp
@@ -212,6 +212,23 @@ void ProdWithBias(marian::Tensor C,
cpu::integer::AddBias(C, bias);
}
+void Affine(marian::Tensor C,
+ Ptr<Allocator> /*allocator*/,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
+ const marian::Tensor& bias,
+ bool transA,
+ bool transB,
+ float beta,
+ float scalar,
+ bool reluPostprocess) {
+ using namespace functional;
+ ProdWithBias(C, A, B, bias, transA, transB, beta, scalar);
+ if(reluPostprocess)
+ cpu::Element(_1 = ReLU(_1), C); // @TODO: also fuse with AddBias
+}
+
+
void CSRProd(marian::Tensor C,
Ptr<Allocator> /*allocator*/,
const marian::Tensor& S_values,
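
Note: on CPU the fused op stays two steps, ProdWithBias() followed by an element-wise ReLU over C. Semantically, Affine() computes scalar * A * B plus a broadcast bias row, optionally clamped at zero. A plain reference version of that contract, as a sketch only (row-major, no transposes, beta assumed 0 as in the forward ops above, not the MKL/BLAS path actually used):

#include <algorithm>
#include <vector>

// C(m x n) = scalar * A(m x k) * B(k x n) + bias(1 x n), then optional ReLU
static void affineRef(std::vector<float>& C, const std::vector<float>& A, const std::vector<float>& B,
                      const std::vector<float>& bias, int m, int k, int n, float scalar, bool doRelu) {
  for(int i = 0; i < m; ++i)
    for(int j = 0; j < n; ++j) {
      float acc = 0.f;
      for(int p = 0; p < k; ++p)
        acc += A[i * k + p] * B[p * n + j];
      float v = scalar * acc + bias[j];
      C[i * n + j] = doRelu ? std::max(v, 0.f) : v;
    }
}
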
diff --git a/src/tensors/dispatch.h b/src/tensors/dispatch.h
index 094f156c..f7154351 100644
--- a/src/tensors/dispatch.h
+++ b/src/tensors/dispatch.h
@@ -152,6 +152,30 @@
cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9); \
}
+#define DISPATCH10( \
+ Function, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9, Arg10) \
+namespace gpu { \
+void Function(Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9, Arg10); \
+} \
+namespace cpu { \
+void Function(Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9, Arg10); \
+} \
+static inline void Function(Arg1 arg1, \
+ Arg2 arg2, \
+ Arg3 arg3, \
+ Arg4 arg4, \
+ Arg5 arg5, \
+ Arg6 arg6, \
+ Arg7 arg7, \
+ Arg8 arg8, \
+ Arg9 arg9, \
+ Arg10 arg10) { \
+ if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu) \
+ gpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10); \
+ else \
+ cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10); \
+}
+
#else
#define DISPATCH1(Function, Arg1) \
@@ -248,4 +272,22 @@
cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9); \
}
+#define DISPATCH10( \
+ Function, Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9, Arg10) \
+ namespace cpu { \
+ void Function(Arg1, Arg2, Arg3, Arg4, Arg5, Arg6, Arg7, Arg8, Arg9, Arg10); \
+ } \
+ static inline void Function(Arg1 arg1, \
+ Arg2 arg2, \
+ Arg3 arg3, \
+ Arg4 arg4, \
+ Arg5 arg5, \
+ Arg6 arg6, \
+ Arg7 arg7, \
+ Arg8 arg8, \
+ Arg9 arg9, \
+ Arg10 arg10) { \
+ cpu::Function(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10); \
+ }
+
#endif
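
Note: DISPATCH10 follows the existing DISPATCH1-DISPATCH9 pattern: declare the function in both the gpu and cpu namespaces and pick one at runtime from the device type of the first argument. Expanded by hand for the Affine entry added to tensor_operators.h (CUDA build), it is roughly:

namespace gpu { void Affine(marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&,
                            const marian::Tensor&, bool, bool, float, float, bool); }
namespace cpu { void Affine(marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&,
                            const marian::Tensor&, bool, bool, float, float, bool); }

static inline void Affine(marian::Tensor arg1, Ptr<Allocator> arg2, const marian::Tensor& arg3,
                          const marian::Tensor& arg4, const marian::Tensor& arg5,
                          bool arg6, bool arg7, float arg8, float arg9, bool arg10) {
  // route to the backend matching the output tensor's device
  if(arg1->getBackend()->getDeviceId().type == DeviceType::gpu)
    gpu::Affine(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10);
  else
    cpu::Affine(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10);
}
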
diff --git a/src/tensors/gpu/prod.cpp b/src/tensors/gpu/prod.cpp
index bf7e5512..8cfa78ca 100755
--- a/src/tensors/gpu/prod.cpp
+++ b/src/tensors/gpu/prod.cpp
@@ -11,10 +11,34 @@
#include "tensors/gpu/cuda_helpers.h"
// clang-format on
+#if CUDA_VERSION >= 11000
+#include <cublasLt.h>
+#endif
+
namespace marian {
namespace gpu {
+// It seems that the bias must be 8 byte aligned for the cublasLt epilogue to work. Therefore,
+// if the bias pointer is not 8 byte aligned, we do a normal matmul in cublasLt and invoke a
+// custom epilogue kernel.
+static constexpr int REQUIRED_BIAS_ALIGNMENT = 8;
+
+// Used to set preferences for cublasLt to filter out algos if matrices do not meet the default 256 byte alignment
+int getAlignmentUpTo256(const void *ptr) {
+ uintptr_t addr = (uintptr_t)ptr;
+ int trailingZeros = 0;
+
+ for(int shiftAmt = 8, mask = 0xFF; shiftAmt > 0; shiftAmt /= 2, mask >>=shiftAmt) {
+ if ((addr & mask) == 0) {
+ trailingZeros += shiftAmt;
+ addr >>= shiftAmt;
+ }
+ }
+
+ return std::min(256, 1 << trailingZeros);
+}
+
// The explicit version of matmult like cublasGemmEx choose their math mode based on the algorithm that
// has been passed into the function call and seem to ignore setMathMode. Here we query the used math mode
// to choose the algorithm.
@@ -412,5 +436,198 @@ void ProdBatched(marian::Tensor C,
}
}
+#if CUDA_VERSION >= 11000 // Earlier versions of cublasLT do not support bias addition for fp32 and fp16.
+
+static cublasStatus_t cublasLtAffineHelper(cublasLtHandle_t ltHandle, cublasOperation_t transA, cublasOperation_t transB,
+ cudaDataType matrixType,
+ int m, int n, int k, const void *alpha, const void *A, int lda, const void *B,
+ int ldb, const void *beta, void *C, int ldc, const void* bias,
+ void* workspace, size_t workspaceSize, bool do_relu, cudaStream_t stream) {
+
+ cublasLtMatmulDesc_t operationDesc = NULL;
+ cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
+ cublasLtMatmulPreference_t preference = NULL;
+
+ int returnedResults = 0;
+ cublasLtMatmulHeuristicResult_t heuristicResult = {};
+
+ cublasLtEpilogue_t epilogue = do_relu? CUBLASLT_EPILOGUE_RELU_BIAS: CUBLASLT_EPILOGUE_BIAS;
+ cublasComputeType_t computeType = matrixType == CUDA_R_32F? CUBLAS_COMPUTE_32F_FAST_16F: CUBLAS_COMPUTE_16F;
+
+ // If the bias is not aligned, just matmul and invoke custom epilogue later.
+ // cublas fails with a misalignment error if this condition is not true.
+ if((uintptr_t)bias % REQUIRED_BIAS_ALIGNMENT != 0) {
+ epilogue = CUBLASLT_EPILOGUE_DEFAULT;
+ }
+
+ CUBLAS_CHECK(cublasLtMatmulDescCreate(&operationDesc, computeType, matrixType));
+ CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transA, sizeof(transA)));
+ CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transB, sizeof(transB)));
+ CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
+ CUBLAS_CHECK(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)));
+
+ CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&Adesc, matrixType, transA == CUBLAS_OP_N ? m : k, transA == CUBLAS_OP_N ? k : m, lda));
+ CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&Bdesc, matrixType, transB == CUBLAS_OP_N ? k : n, transB == CUBLAS_OP_N ? n : k, ldb));
+ CUBLAS_CHECK(cublasLtMatrixLayoutCreate(&Cdesc, matrixType, m, n, ldc));
+
+ // I think we need to do this since we can slice matrices...
+ // The allocator always allocates on 256 byte boundaries but we have no guarantees about the alignment of a matrix slice so we filter out
+ // algorithms that would not work with matrices not aligned to 256 bytes.
+ int alignmentA = getAlignmentUpTo256(A);
+ int alignmentB = getAlignmentUpTo256(B);
+ int alignmentC = getAlignmentUpTo256(C);
+
+ CUBLAS_CHECK(cublasLtMatmulPreferenceCreate(&preference));
+ CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)));
+ CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &alignmentA, sizeof(alignmentA)));
+ CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, &alignmentB, sizeof(alignmentB)));
+ CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &alignmentC, sizeof(alignmentC)));
+ CUBLAS_CHECK(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &alignmentC, sizeof(alignmentC)));
+ CUBLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc, preference, 1, &heuristicResult, &returnedResults));
+
+ cublasStatus_t opStatus = cublasLtMatmul(ltHandle, operationDesc, alpha, A, Adesc, B, Bdesc, beta, C, Cdesc, C, Cdesc,
+ &heuristicResult.algo, workspace, workspaceSize, stream);
+
+ if (preference) CUBLAS_CHECK(cublasLtMatmulPreferenceDestroy(preference));
+ if (Cdesc) CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(Cdesc));
+ if (Bdesc) CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(Bdesc));
+ if (Adesc) CUBLAS_CHECK(cublasLtMatrixLayoutDestroy(Adesc));
+ if (operationDesc) CUBLAS_CHECK(cublasLtMatmulDescDestroy(operationDesc));
+
+ return opStatus;
+}
+
+static cublasStatus_t cublasLtAffineTyped(cublasLtHandle_t ltHandle, cublasOperation_t transA, cublasOperation_t transB,
+ int m, int n, int k, const half *alpha, const half *A, int lda, const half *B,
+ int ldb, const half *beta, half *C, int ldc, const half* bias,
+ half* workspace, size_t workspaceSizeBytes, bool do_relu, cudaStream_t stream) {
+ return cublasLtAffineHelper(ltHandle, transA, transB, CUDA_R_16F, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, bias,
+ workspace, workspaceSizeBytes, do_relu, stream);
+}
+
+static cublasStatus_t cublasLtAffineTyped(cublasLtHandle_t ltHandle, cublasOperation_t transA, cublasOperation_t transB,
+ int m, int n, int k, const float *alpha, const float *A, int lda, const float *B,
+ int ldb, const float *beta, float *C, int ldc, const float* bias,
+ float* workspace, size_t workspaceSizeBytes,bool do_relu, cudaStream_t stream) {
+
+ return cublasLtAffineHelper(ltHandle, transA, transB, CUDA_R_32F, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, bias,
+ workspace, workspaceSizeBytes, do_relu, stream);
+}
+
+template <typename T>
+void affineTyped(marian::Tensor C, Ptr<Allocator> allocator, const marian::Tensor& A, const marian::Tensor& B, const marian::Tensor& bias,
+ bool transA, bool transB, T beta, T scalar, bool do_relu) {
+
+ CUDA_CHECK(cudaSetDevice((int)C->getDeviceId().no));
+ T alpha = scalar;
+
+ int m = A->shape().elements() / A->shape().back();
+ int k = A->shape().back();
+ if(transA)
+ std::swap(m, k);
+
+ int l = B->shape().elements() / B->shape().back();
+ int n = B->shape().back();
+ if(transB)
+ std::swap(l, n);
+
+ int lda = A->shape().back();
+ int ldb = B->shape().back();
+ int ldc = B->shape().back();
+
+ size_t bias_size = bias->shape().elements();
+ ABORT_IF(n != bias_size, "The number of elements in the bias must match the number of columns in C");
+
+ if(transB)
+ ldc = B->shape().elements() / B->shape().back();
+
+ cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
+ cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+ auto backend = std::static_pointer_cast<gpu::Backend>(C->getBackend());
+ auto cublasHandle = backend->getCublasHandle();
+ auto ltHandle = (cublasLtHandle_t)backend->getCublasHandle(); // A cublas handle encapsulates an lt handle
+
+ size_t numWorkSpaceElts = 8192; // Allows for cublasLt to perform split-K gemms. This is chosen to be at least
+ // 16 KiB for float16 which is large enough to prevent alloc failed errors
+ size_t workspaceSizeBytes = numWorkSpaceElts * sizeof(T);
+ IPtr<MemoryPiece> workspace = allocator->alloc<T>(numWorkSpaceElts);
+
+ cudaStream_t stream = 0;
+ CUBLAS_CHECK(cublasGetStream(cublasHandle, &stream));
+
+
+ CUBLAS_CHECK(cublasLtAffineTyped(ltHandle,
+ opB,
+ opA,
+ n,
+ m,
+ k,
+ &alpha,
+ B->data<T>(),
+ ldb,
+ A->data<T>(),
+ lda,
+ &beta,
+ C->data<T>(),
+ ldc,
+ bias->data<T>(),
+ workspace->data<T>(),
+ workspaceSizeBytes,
+ do_relu,
+ stream));
+
+ allocator->free(workspace);
+}
+
+// This version is needed so that Windows doesn't complain when compiling CUDA < 11. Otherwise, the ifdef could be inside of one
+// definition of Affine.
+void Affine(marian::Tensor C,
+ Ptr<Allocator> allocator,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
+ const marian::Tensor& bias,
+ bool transA, bool transB, float beta, float scalar, bool do_relu) {
+ // There is a bug in CUDA 11 where the bias pointer needs to be 8 byte aligned. This bug will be fixed in a subsequent release. For now,
+ // we launch a custom epilogue if the bias does not meet the alignment requirement.
+ if(C->type() == Type::float32) {
+ affineTyped<float>(C, allocator, A, B, bias, transA, transB, beta, scalar, do_relu);
+ if((uintptr_t)bias->data<float>() % REQUIRED_BIAS_ALIGNMENT != 0) {
+ BiasAdd(C, bias, do_relu);
+ }
+#if COMPILE_FP16
+ } else if(C->type() == Type::float16) {
+ affineTyped<half>(C, allocator, A, B, bias, transA, transB, __float2half(beta), __float2half(scalar), do_relu);
+ if((uintptr_t)bias->data<half>() % REQUIRED_BIAS_ALIGNMENT != 0) {
+ BiasAdd(C, bias, do_relu);
+ }
+#endif
+ } else {
+ ABORT("Affine not implemented for type {}", C->type());
+ }
+}
+
+#else
+
+void Affine(marian::Tensor C,
+ Ptr<Allocator> /*allocator*/,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
+ const marian::Tensor& bias,
+ bool transA, bool transB, float beta, float scalar, bool do_relu) {
+
+ if(C->type() == Type::float32) {
+ ProdTyped<float>(C, A, B, transA, transB, beta, scalar);
+#if COMPILE_FP16
+ } else if(C->type() == Type::float16) {
+ ProdTyped<half>(C, A, B, transA, transB, __float2half(beta), __float2half(scalar));
+#endif
+ } else {
+ ABORT("Prod not implemented for type {}", C->type());
+ }
+ BiasAdd(C, bias, do_relu);
+}
+#endif
+
} // namespace gpu
} // namespace marian
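
Note: the alignment probe feeding the cublasLt preferences just measures the largest power of two, capped at 256, that divides a pointer value. Copied out of the patch and wrapped in a tiny check harness (a sketch; the sample addresses are arbitrary):

#include <algorithm>
#include <cassert>
#include <cstdint>

static int getAlignmentUpTo256(const void* ptr) {
  uintptr_t addr = (uintptr_t)ptr;
  int trailingZeros = 0;
  // peel trailing zero bits off the address in chunks of 8, 4, 2, 1
  for(int shiftAmt = 8, mask = 0xFF; shiftAmt > 0; shiftAmt /= 2, mask >>= shiftAmt) {
    if((addr & mask) == 0) {
      trailingZeros += shiftAmt;
      addr >>= shiftAmt;
    }
  }
  return std::min(256, 1 << trailingZeros);
}

int main() {
  assert(getAlignmentUpTo256((void*)0x1000) == 256);  // 4 KiB boundary, capped at 256
  assert(getAlignmentUpTo256((void*)0x1040) == 64);   // 64-byte aligned slice
  assert(getAlignmentUpTo256((void*)0x1001) == 1);    // odd address
  return 0;
}
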
diff --git a/src/tensors/gpu/prod.cu b/src/tensors/gpu/prod.cu
new file mode 100644
index 00000000..ec01d57e
--- /dev/null
+++ b/src/tensors/gpu/prod.cu
@@ -0,0 +1,69 @@
+#include <stdint.h>
+#include "tensors/tensor.h"
+#include "tensors/gpu/cuda_helpers.h"
+#include "tensors/gpu/backend.h"
+
+namespace marian {
+namespace gpu {
+
+template <typename T, typename ActFunc>
+__global__ static void gBiasAddFused(T* tensor, T* bias, size_t tensor_size, size_t bias_size, ActFunc f) {
+ const size_t row_start = blockIdx.x * bias_size;
+ for(int bias_offset = threadIdx.x; bias_offset < bias_size; bias_offset+=blockDim.x) {
+ size_t offset_into_tensor = row_start + bias_offset;
+ if(offset_into_tensor < tensor_size) {
+ T added_bias = tensor[offset_into_tensor] + bias[bias_offset];
+ tensor[offset_into_tensor] = f(added_bias);
+ }
+ }
+}
+
+struct identity {
+ template <typename T>
+ __device__ constexpr T&& operator() (T&& t) const noexcept {
+ return std::forward<T>(t);
+ }
+};
+
+struct reluAct {
+ template <typename T>
+ __device__ T operator() (T t) const noexcept {
+ return t > (T) 0? t : (T) 0;
+ }
+};
+
+void BiasAdd(marian::Tensor C, const marian::Tensor& bias, bool do_relu) {
+ auto backend = std::static_pointer_cast<gpu::Backend>(C->getBackend());
+ CUDA_CHECK(cudaSetDevice(backend->getDeviceId().no));
+
+ size_t size = C->shape().elements();
+ size_t bias_size = bias->shape().elements();
+
+ int m = C->shape().elements() / C->shape().back();
+ int n = C->shape().back();
+
+ ABORT_IF(n != bias_size, "The number of elements in the bias must match the number of columns in C");
+
+ int threads_per_block = std::min(MAX_THREADS, n);
+ int blocks = m;
+
+ if(C->type() == Type::float32) {
+ if (do_relu)
+ gBiasAddFused<<<blocks, threads_per_block>>>(C->data<float>(), bias->data<float>(), size, bias_size, reluAct());
+ else
+ gBiasAddFused<<<blocks, threads_per_block>>>(C->data<float>(), bias->data<float>(), size, bias_size, identity());
+
+#if COMPILE_FP16
+ } else if(C->type() == Type::float16) {
+ if (do_relu)
+ gBiasAddFused<<<blocks, threads_per_block>>>(C->data<half>(), bias->data<half>(), size, bias_size, reluAct());
+ else
+ gBiasAddFused<<<blocks, threads_per_block>>>(C->data<half>(), bias->data<half>(), size, bias_size, identity());
+#endif
+ } else {
+ ABORT("Prod not implemented for type {}", C->type());
+ }
+}
+
+}
+}
\ No newline at end of file
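
Note: for CUDA < 11, or when the bias pointer misses the 8-byte alignment the cublasLt epilogue needs, the bias addition (and optional ReLU) runs in this separate kernel: one block per row of C, with threads striding over the bias/column dimension. A serial reference of what it computes (sketch, float only):

#include <algorithm>
#include <cstddef>

// the outer loop plays the role of blockIdx.x, the inner loop of threadIdx.x + k * blockDim.x
static void biasAddRef(float* C, const float* bias, size_t rows, size_t cols, bool doRelu) {
  for(size_t row = 0; row < rows; ++row)
    for(size_t col = 0; col < cols; ++col) {
      float v = C[row * cols + col] + bias[col];
      C[row * cols + col] = doRelu ? std::max(v, 0.f) : v;
    }
}
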
diff --git a/src/tensors/gpu/prod.h b/src/tensors/gpu/prod.h
index 63b9192a..aec8cb73 100644
--- a/src/tensors/gpu/prod.h
+++ b/src/tensors/gpu/prod.h
@@ -6,6 +6,21 @@
namespace marian {
namespace gpu {
+void BiasAdd(marian::Tensor C,
+ const marian::Tensor& bias,
+ bool do_relu = false);
+
+void Affine(marian::Tensor C,
+ Ptr<Allocator> allocator,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
+ const marian::Tensor& bias,
+ bool transA,
+ bool transB,
+ float beta = 0,
+ float scalar = 1,
+ bool do_relu = false);
+
void Prod(marian::Tensor C,
const marian::Tensor& A,
const marian::Tensor& B,
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h
index 83bce819..af7946dd 100644
--- a/src/tensors/tensor_operators.h
+++ b/src/tensors/tensor_operators.h
@@ -106,6 +106,8 @@ DISPATCH8(Prod, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bo
DISPATCH8(ProdBatched, marian::Tensor, Ptr<Allocator>, const marian::Tensor, const marian::Tensor, bool, bool, float, float)
DISPATCH9(CSRProd, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float)
+DISPATCH10(Affine, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float, bool)
+
DISPATCH2(Softmax, marian::Tensor, marian::Tensor)
DISPATCH3(SoftmaxGrad, marian::Tensor, marian::Tensor, marian::Tensor)
diff --git a/src/tests/units/operator_tests.cpp b/src/tests/units/operator_tests.cpp
index 27ccf139..c3fd4a9e 100644
--- a/src/tests/units/operator_tests.cpp
+++ b/src/tests/units/operator_tests.cpp
@@ -32,6 +32,8 @@ void tests(DeviceType device, Type floatType = Type::float32) {
Config::seed = 1234;
auto graph = New<ExpressionGraph>();
+
+ graph->setInference(true);
graph->setDefaultElementType(floatType);
graph->setDevice({0, device});
graph->reserveWorkspaceMB(16);
@@ -539,15 +541,19 @@ void tests(DeviceType device, Type floatType = Type::float32) {
values.clear();
std::vector<T> vA({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
- std::vector<T> vB({1, 2, 3, 4, 5, 6});
- std::vector<T> vAff({24, 30, 51, 66, 78, 102, 105, 138});
+ std::vector<T> vB({1, -2, 3, 4, -5, 6});
+ std::vector<T> vAff({-6, 26, -9, 50, -12, 74, -15, 98});
+ std::vector<T> vAffRelu({0, 26, 0, 50, 0, 74, 0, 98});
auto A = graph->param("A", {4, 3}, inits::fromVector(vA));
auto B = graph->param("B", {3, 2}, inits::fromVector(vB));
- auto C = graph->param("C", {4, 2}, inits::fromValue(2));
+ auto bias = graph->param("C", {1, 2}, inits::fromValue(2));
+
+ auto aff1 = affine(A, B, bias);
+ auto aff2 = dot(A, B) + bias;
- auto aff1 = affine(A, B, C);
- auto aff2 = dot(A, B) + C;
+ auto affRelu1 = affineWithRelu(A, B, bias);
+ auto affRelu2 = relu(dot(A, B) + bias);
graph->forward();
@@ -559,6 +565,11 @@ void tests(DeviceType device, Type floatType = Type::float32) {
CHECK(aff2->shape() == aff1->shape());
aff2->val()->get(values2);
CHECK(values2 == values);
+
+ affRelu1->val()->get(values);
+ affRelu2->val()->get(values2);
+ CHECK(values2 == vAffRelu);
+ CHECK(values2 == values);
}
SECTION("repeat") {