
github.com/marian-nmt/marian.git
path: root/src
author    Martin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>  2021-04-10 18:28:38 +0300
committer Martin Junczys-Dowmunt <Marcin.JunczysDowmunt@microsoft.com>  2021-04-10 18:28:38 +0300
commit    caddad90cdb06283da2f5d17a3340ca8c6387b38 (patch)
tree      8a83de33e15cf9221ce2d6d75ddd18ed53fecc2e /src
parent    a05124176d8869962b717a3557c383406f8c76f4 (diff)
Merged PR 18505: RMSNorm on GPU
Support for RMSNorm as a drop-in replacement for LayerNorm, from _Biao Zhang; Rico Sennrich (2019). Root Mean Square Layer Normalization_. Enabled in the Transformer model via `--transformer-postprocess dar` instead of `dan`.
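For reference, the two normalizations differ only in the centering step: LayerNorm subtracts the per-row mean and divides by the standard deviation, while RMSNorm (see the doxygen comment added to `expression_operators.h` below) rescales by the root mean square alone:

```latex
\mathrm{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\tfrac{1}{N}\sum_i (x_i - \mu)^2 + \epsilon}} \odot \gamma + \beta,
\qquad \mu = \tfrac{1}{N}\sum_i x_i
\\[6pt]
\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\tfrac{1}{N}\sum_i x_i^2 + \epsilon}} \odot \gamma + \beta
```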
Diffstat (limited to 'src')
-rw-r--r--  src/graph/expression_graph.cpp       |  11
-rw-r--r--  src/graph/expression_operators.cpp   |  12
-rw-r--r--  src/graph/expression_operators.h     |  12
-rw-r--r--  src/graph/node_operators_binary.h    |  58
-rw-r--r--  src/layers/generic.h                 |   6
-rw-r--r--  src/models/transformer.h             |   4
-rwxr-xr-x  src/tensors/cpu/tensor_operators.cpp | 196
-rw-r--r--  src/tensors/gpu/tensor_operators.cu  | 267
-rw-r--r--  src/tensors/tensor_operators.h       |  49
-rw-r--r--  src/tests/units/operator_tests.cpp   |  43
10 files changed, 636 insertions, 22 deletions
diff --git a/src/graph/expression_graph.cpp b/src/graph/expression_graph.cpp
index 827fb3ed..12a1195e 100644
--- a/src/graph/expression_graph.cpp
+++ b/src/graph/expression_graph.cpp
@@ -208,8 +208,15 @@ void ExpressionGraph::backward(bool reset, float clipValue) {
}
if(v->trainable() && v->marked_for_debug()) {
- LOG(info, "Debug Grad: {} op={}", v->debug_message(), v->type());
- LOG(info, v->grad()->debug());
+ Logger log = spdlog::get("general");
+ if(log) {
+ LOG(info, "Debug Grad: {} op={}", v->debug_message(), v->type());
+ LOG(info, v->grad()->debug());
+ }
+ else {
+ std::cerr << "Debug Grad: " << v->debug_message() << " op=" << v->type() << std::endl;
+ std::cerr << v->grad()->debug() << std::endl;
+ }
}
if(v->trainable() && clipValue != 0) {
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 048c7478..6c7ef91c 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -749,6 +749,18 @@ Expr layerNorm(Expr x,
return Expression<LayerNormalizationOp>(nodes, eps);
}
+Expr rmsNorm(Expr x,
+ Expr gamma,
+ Expr beta /*= nullptr*/,
+ float eps /*= 1e-9*/) {
+
+ // as with layerNorm, accumulation is done in float, so a small eps is fine
+ std::vector<Expr> nodes = {x, gamma};
+ if(beta)
+ nodes.push_back(beta);
+ return Expression<RMSNormalizationOp>(nodes, eps);
+}
+
Expr highway(Expr y, Expr x, Expr t) {
std::vector<Expr> nodes = {y, x, t};
return Expression<HighwayNodeOp>(nodes);
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 81b0f5ea..f3d84eb6 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -916,6 +916,18 @@ Expr weighted_average(Expr in, Expr weights, int ax = 0);
Expr layerNorm(Expr x, Expr gamma, Expr beta = nullptr, float eps = 1e-9);
/**
+ * Applies RMS normalization over the last dimension.
+ *
+ * See: Biao Zhang; Rico Sennrich (2019). Root Mean Square Layer Normalization.
+ * In Advances in Neural Information Processing Systems 32. Vancouver, Canada.
+ * @f[
+ \frac{x}{\sqrt{\frac{1}{N}\sum x^2 + \mathrm{eps}}} \times \gamma + \beta
+ * @f]
+ * @see RMSNormalizationOp
+ */
+Expr rmsNorm(Expr x, Expr gamma, Expr beta = nullptr, float eps = 1e-9);
+
+/**
* Highway transformation.
* Computes the highway transform on @p y and @p x as gated by @p t:
* @f$ \operatorname{sigmoid}(t) y + (1-\operatorname{sigmoid}(t)) x @f$
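A minimal usage sketch of the new graph operator declared above; names, shapes, and initializers are illustrative, not taken from the diff (the unit test at the bottom of this diff exercises the same pattern against a reference expression):

```cpp
// Hypothetical sketch: apply the new rmsNorm() operator to a 3-D input.
// Normalization runs over the last axis; gamma (and the optional beta)
// is broadcast across the leading axes.
auto graph = New<ExpressionGraph>();
graph->setDevice({0, DeviceType::cpu});
graph->reserveWorkspaceMB(16);

auto x     = graph->constant({2, 2, 4}, inits::glorotUniform());
auto gamma = graph->param("gamma", {1, 4}, inits::ones());
auto y     = rmsNorm(x, gamma, /*beta=*/nullptr, /*eps=*/1e-6f);

graph->forward();  // y->val() now holds the row-wise RMS-normalized input
```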
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 55f105a9..91fc29da 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -1369,6 +1369,64 @@ private:
float eps_;
};
+// RMS norm along last axis
+struct RMSNormalizationOp : public NaryNodeOp {
+public:
+ RMSNormalizationOp(const std::vector<Expr>& nodes, float eps = 1e-9)
+ : NaryNodeOp(nodes), eps_(eps) {
+ // @TODO: dimension check
+ }
+
+ NodeOps forwardOps() override {
+ return {NodeOp(
+ RMSNormalization(val_,
+ child(0)->val(),
+ child(1)->val(),
+ (children_.size() == 3) ? child(2)->val() : nullptr,
+ eps_))};
+ }
+
+ // @BUGBUG: backward has not been tested for broadcasting gamma/beta
+ NodeOps backwardOps() override {
+ return {NodeOp(
+ RMSNormalizationGrad(
+ graph()->allocator(),
+ child(0)->grad(),
+ child(1)->grad(),
+ (children_.size() == 3) ? child(2)->grad() : nullptr,
+ adj_,
+ val_,
+ child(0)->val(),
+ child(1)->val(),
+ (children_.size() == 3) ? child(2)->val() : nullptr,
+ eps_))};
+ }
+
+ const std::string type() override { return "rms_normalization"; }
+
+ virtual size_t hash() override {
+ size_t seed = NaryNodeOp::hash();
+ util::hash_combine(seed, eps_);
+ return seed;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<RMSNormalizationOp>(node);
+ if(!cnode)
+ return false;
+ if(eps_ != cnode->eps_)
+ return false;
+ return true;
+ }
+
+private:
+ friend class SerializationHelpers; // @TODO: use the same name for this as SqrtNodeOp
+ float eps_;
+};
+
+
struct HighwayNodeOp : public NaryNodeOp {
HighwayNodeOp(const std::vector<Expr>& nodes) : NaryNodeOp(nodes) {}
diff --git a/src/layers/generic.h b/src/layers/generic.h
index 5eb93615..2746bc85 100644
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -212,4 +212,10 @@ static inline Expr layerNorm(Expr x, std::string prefix, std::string suffix = st
return marian::layerNorm(x, scale, bias, 1e-6f);
}
+static inline Expr rmsNorm(Expr x, std::string prefix, std::string suffix = std::string()) {
+ int dimModel = x->shape()[-1];
+ auto scale = x->graph()->param(prefix + "_rms_scale" + suffix, {1, dimModel}, inits::ones());
+ return marian::rmsNorm(x, scale, nullptr, 1e-6f);
+}
+
} // namespace marian
diff --git a/src/models/transformer.h b/src/models/transformer.h
index 79b59000..1da02318 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -176,6 +176,8 @@ public:
// layer normalization
else if (op == 'n')
output = layerNorm(output, prefix, "_pre");
+ else if (op == 'r')
+ output = rmsNorm(output, prefix, "_pre");
else
ABORT("Unknown pre-processing operation '{}'", op);
}
@@ -201,6 +203,8 @@ public:
// layer normalization
else if(op == 'n')
output = layerNorm(output, prefix);
+ else if(op == 'r')
+ output = rmsNorm(output, prefix);
else
ABORT("Unknown pre-processing operation '{}'", op);
}
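The two new branches above register 'r' as an op letter next to 'n'. Below is a simplified, self-contained sketch of how a process string such as `dar` is interpreted, which is why the commit message's `--transformer-postprocess dar` swaps the trailing LayerNorm of `dan` for an RMSNorm; the meanings of 'd' and 'a' are assumptions about surrounding code not shown in this diff:

```cpp
#include <iostream>
#include <string>

// Simplified sketch (not Marian's actual pre/post-process code): each
// character of the process string selects one operation, applied left to
// right over the layer output.
int main() {
  for(char op : std::string("dar")) {
    switch(op) {
      case 'd': std::cout << "dropout\n";                   break; // assumed meaning
      case 'a': std::cout << "add residual connection\n";   break; // assumed meaning
      case 'n': std::cout << "layerNorm(output, prefix)\n"; break;
      case 'r': std::cout << "rmsNorm(output, prefix)\n";   break; // added in this PR
      default:  std::cout << "ABORT: unknown operation\n";  break;
    }
  }
}
```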
diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp
index 1191a2be..67d993fc 100755
--- a/src/tensors/cpu/tensor_operators.cpp
+++ b/src/tensors/cpu/tensor_operators.cpp
@@ -977,7 +977,7 @@ float L2Norm(Tensor in, Ptr<Allocator> /*not used*/) {
float sum = 0.f;
size_t size = in->size();
const float* data = in->data();
-#pragma omp parallel for simd reduction(+ : sum)
+ #pragma omp parallel for simd reduction(+ : sum)
for(size_t i = 0; i < size; ++i) {
sum += data[i] * data[i];
}
@@ -998,14 +998,14 @@ void Att(Tensor out_, Tensor va_, Tensor context_, Tensor state_) {
int rows = m;
int cols = k;
-#pragma omp parallel for
+ #pragma omp parallel for
for(int j = 0; j < rows; ++j) {
const float* vaRow = va;
const float* ctxRow = ctx + (j % (b * t)) * cols;
const float* stateRow = state + ((j / (b * t)) * b + j % b) * cols;
float sum = 0.f;
-#pragma omp simd reduction(+ : sum)
+ #pragma omp simd reduction(+ : sum)
for(int i = 0; i < cols; ++i) {
float z = ctxRow[i] + stateRow[i];
sum += std::tanh(z) * vaRow[i];
@@ -1035,7 +1035,7 @@ void AttBack(Tensor gVa_,
size_t k = context_->shape()[-1];
size_t n = context_->shape()[-2];
-#pragma omp parallel for reduction(+ : gState[:n * k], gVa[:k])
+ #pragma omp parallel for reduction(+ : gState[:n * k], gVa[:k])
for(size_t j = 0; j < m; ++j) {
float* gcRow = gContext + j * k;
float* gsRow = gState + (j % n) * k;
@@ -1045,7 +1045,7 @@ void AttBack(Tensor gVa_,
float adj_j = adj[j];
-#pragma omp simd
+ #pragma omp simd
for(size_t i = 0; i < k; ++i) {
float z = cRow[i] + sRow[i];
@@ -1070,20 +1070,20 @@ void LayerNormalizationImpl(float* out,
float eps,
int rows,
int cols) {
-#pragma omp parallel for
+ #pragma omp parallel for
for(int j = 0; j < rows; ++j) {
float* so = out + j * cols;
const float* sp = in + j * cols;
float sum = 0.f;
-#pragma omp simd reduction(+ : sum)
+ #pragma omp simd reduction(+ : sum)
for(int i = 0; i < cols; ++i) {
sum += sp[i];
}
float mean = sum / cols;
float sqSum = 0.f;
-#pragma omp simd reduction(+ : sqSum)
+ #pragma omp simd reduction(+ : sqSum)
for(int i = 0; i < cols; ++i) {
float ex = sp[i] - mean;
sqSum += ex * ex;
@@ -1091,7 +1091,7 @@ void LayerNormalizationImpl(float* out,
float sigma = std::sqrt(sqSum / cols + eps);
-#pragma omp simd
+ #pragma omp simd
for(int i = 0; i < cols; ++i) {
float t = alpha[alphaStride * i] * ((sp[i] - mean) / sigma);
if(hasBeta)
@@ -1168,7 +1168,7 @@ void LayerNormalizationGrad(Tensor gradX_,
size_t cols = y_->shape()[-1];
if(beta) {
-#pragma omp parallel for reduction(+ : gradGamma[:cols], gradBeta[:cols])
+ #pragma omp parallel for reduction(+ : gradGamma[:cols], gradBeta[:cols])
for(size_t j = 0; j < rows; ++j) {
const float* xRow = x + j * cols;
const float* yRow = y + j * cols;
@@ -1180,7 +1180,7 @@ void LayerNormalizationGrad(Tensor gradX_,
float sum_adj_x = 0.f;
float sum_sqr = 0.f;
-#pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj)
+ #pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj)
for(size_t i = 0; i < cols; ++i) {
sum_x += xRow[i];
sum_adj_x += adjRow[i] * (yRow[i] - (beta ? beta[betaStride * i] : 0.f)) / gamma[gammaStride * i];
@@ -1188,14 +1188,14 @@ void LayerNormalizationGrad(Tensor gradX_,
}
float mean = sum_x / cols;
-#pragma omp simd reduction(+ : sum_sqr)
+ #pragma omp simd reduction(+ : sum_sqr)
for(size_t i = 0; i < cols; ++i) {
float ex = xRow[i] - mean;
sum_sqr += ex * ex;
}
float sigma = std::sqrt(sum_sqr / cols + eps);
-#pragma omp simd
+ #pragma omp simd
for(size_t i = 0; i < cols; ++i) {
float grad_x = 0.f;
float x_hat = (yRow[i] - beta[betaStride * i]) / gamma[gammaStride * i];
@@ -1209,8 +1209,8 @@ void LayerNormalizationGrad(Tensor gradX_,
gradBeta[betaStride * i] += adjRow[i];
}
}
- } else {
-#pragma omp parallel for reduction(+ : gradGamma[:cols])
+ } else { // @TODO: this code duplication is really ugly, but required for omp to work correctly?
+ #pragma omp parallel for reduction(+ : gradGamma[:cols])
for(size_t j = 0; j < rows; ++j) {
const float* xRow = x + j * cols;
const float* yRow = y + j * cols;
@@ -1222,23 +1222,22 @@ void LayerNormalizationGrad(Tensor gradX_,
float sum_adj_x = 0.f;
float sum_sqr = 0.f;
-#pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj)
+ #pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj)
for(size_t i = 0; i < cols; ++i) {
sum_x += xRow[i];
- sum_adj_x += adjRow[i] * (yRow[i] - (beta ? beta[betaStride * i] : 0.f)) / gamma[gammaStride * i];
- // @TODO: beta is NULL here ^^
+ sum_adj_x += adjRow[i] * yRow[i] / gamma[gammaStride * i];
sum_adj += adjRow[i];
}
float mean = sum_x / cols;
-#pragma omp simd reduction(+ : sum_sqr)
+ #pragma omp simd reduction(+ : sum_sqr)
for(size_t i = 0; i < cols; ++i) {
float ex = xRow[i] - mean;
sum_sqr += ex * ex;
}
float sigma = std::sqrt(sum_sqr / cols + eps);
-#pragma omp simd
+ #pragma omp simd
for(size_t i = 0; i < cols; ++i) {
float grad_x = 0.f;
float x_hat = yRow[i] / gamma[gammaStride * i];
@@ -1255,6 +1254,163 @@ void LayerNormalizationGrad(Tensor gradX_,
}
MARIAN_FFAST_MATH_END
+MARIAN_FFAST_MATH_BEGIN
+template <int alphaStride, int betaStride, bool hasBeta>
+void RMSNormalizationImpl(float* out,
+ const float* in,
+ const float* alpha,
+ const float* beta,
+ float eps,
+ int rows,
+ int cols) {
+ #pragma omp parallel for
+ for(int j = 0; j < rows; ++j) {
+ float* so = out + j * cols;
+ const float* sp = in + j * cols;
+
+ float sqSum = 0.f;
+ #pragma omp simd reduction(+ : sqSum)
+ for(int i = 0; i < cols; ++i) {
+ sqSum += sp[i] * sp[i];
+ }
+
+ float rms = std::sqrt(sqSum / cols + eps);
+
+ #pragma omp simd
+ for(int i = 0; i < cols; ++i) {
+ float t = alpha[alphaStride * i] * (sp[i] / rms);
+ if(hasBeta)
+ t += beta[betaStride * i];
+
+ so[i] = t;
+ }
+ }
+}
+MARIAN_FFAST_MATH_END
+
+template <int alphaStride>
+inline void RMSNormalizationDispatchBeta(float* out,
+ const float* in,
+ const float* alpha,
+ Tensor beta,
+ float eps,
+ int rows,
+ int cols) {
+ if (beta) {
+ if (beta->shape().back() > 1) {
+ RMSNormalizationImpl<alphaStride, 1, true>(out, in, alpha, beta->data(), eps, rows, cols);
+ } else {
+ RMSNormalizationImpl<alphaStride, 0, true>(out, in, alpha, beta->data(), eps, rows, cols);
+ }
+ } else {
+ RMSNormalizationImpl<alphaStride, 0, false>(out, in, alpha, nullptr, eps, rows, cols);
+ }
+}
+
+void RMSNormalization(Tensor out,
+ Tensor in,
+ Tensor gamma,
+ Tensor beta,
+ float eps) {
+ const float* alpha = gamma->data();
+ const int alphaStride = gamma->shape().back() > 1; // broadcasting for alpha and beta
+
+ int rows = in->shape().elements() / in->shape().back();
+ int cols = in->shape().back();
+ if (alphaStride == 0) {
+ RMSNormalizationDispatchBeta<0>(out->data(), in->data(), alpha, beta, eps, rows, cols);
+ } else {
+ RMSNormalizationDispatchBeta<1>(out->data(), in->data(), alpha, beta, eps, rows, cols);
+ }
+}
+
+MARIAN_FFAST_MATH_BEGIN
+void RMSNormalizationGrad(Tensor gradX_,
+ Tensor gradGamma_,
+ Tensor gradBeta_,
+ Tensor adj_,
+ Tensor y_,
+ Tensor x_,
+ Tensor gamma_,
+ Tensor beta_,
+ float eps) {
+ float* gradX = gradX_->data();
+ float* gradGamma = gradGamma_->data();
+ float* gradBeta = gradBeta_ ? gradBeta_->data() : nullptr;
+ float* adj = adj_->data();
+ float* x = x_->data();
+ float* y = y_->data();
+ float* gamma = gamma_->data();
+ float* beta = beta_ ? beta_->data() : nullptr;
+ // @TODO: The CPU implementation supports scalar gamma and beta. This is a left-over,
+ // we should enable that in the GPU version as well.
+ const int gammaStride = gamma_->shape().back() > 1; // broadcasting for alpha and beta. 0 means it's a scalar
+ const int betaStride = beta_ && beta_->shape().back() > 1;
+
+ size_t rows = y_->shape().elements() / y_->shape()[-1];
+ size_t cols = y_->shape()[-1];
+
+ if(beta) {
+ #pragma omp parallel for reduction(+ : gradGamma[:cols], gradBeta[:cols])
+ for(size_t j = 0; j < rows; ++j) {
+ const float* xRow = x + j * cols;
+ const float* yRow = y + j * cols;
+ const float* adjRow = adj + j * cols;
+ float* gradXRow = gradX + j * cols;
+
+ float sum_adj_r = 0.f;
+ float sum_sqr = 0.f;
+
+ #pragma omp simd reduction(+ : sum_adj_r, sum_sqr)
+ for(size_t i = 0; i < cols; ++i) {
+ sum_adj_r += adjRow[i] * (yRow[i] - beta[betaStride * i]) / gamma[gammaStride * i];
+ sum_sqr += xRow[i] * xRow[i];
+ }
+
+ float rms = std::sqrt(sum_sqr / cols + eps);
+ #pragma omp simd
+ for(size_t i = 0; i < cols; ++i) {
+ float rmsNorm = (yRow[i] - beta[betaStride * i]) / gamma[gammaStride * i];
+ float gradNorm = cols * adjRow[i] - rmsNorm * sum_adj_r;
+ gradNorm /= cols * rms;
+
+ gradXRow[i] += gamma[gammaStride * i] * gradNorm;
+ gradGamma[gammaStride * i] += adjRow[i] * rmsNorm;
+ gradBeta[betaStride * i] += adjRow[i];
+ }
+ }
+ } else {
+ #pragma omp parallel for reduction(+ : gradGamma[:cols])
+ for(size_t j = 0; j < rows; ++j) {
+ const float* xRow = x + j * cols;
+ const float* yRow = y + j * cols;
+ const float* adjRow = adj + j * cols;
+ float* gradXRow = gradX + j * cols;
+
+ float sum_adj_r = 0.f;
+ float sum_sqr = 0.f;
+
+ #pragma omp simd reduction(+ : sum_adj_r, sum_sqr)
+ for(size_t i = 0; i < cols; ++i) {
+ sum_adj_r += yRow[i] / gamma[gammaStride * i];
+ sum_sqr += xRow[i] * xRow[i];
+ }
+
+ float rms = std::sqrt(sum_sqr / cols + eps);
+ #pragma omp simd
+ for(size_t i = 0; i < cols; ++i) {
+ float rmsNorm = yRow[i] / gamma[gammaStride * i];
+ float gradNorm = cols * adjRow[i] - rmsNorm * sum_adj_r;
+ gradNorm /= cols * rms;
+
+ gradXRow[i] += gamma[gammaStride * i] * gradNorm;
+ gradGamma[gammaStride * i] += adjRow[i] * rmsNorm;
+ }
+ }
+ }
+}
+MARIAN_FFAST_MATH_END
+
void Shift(Tensor out_,
Tensor in_,
marian::Shape shift,
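For reference, a sketch of the backward formulas the CPU gradient code above implements (the CUDA kernel in the next file follows the same scheme). Per row of N columns, with incoming adjoint $a_i = \partial C / \partial y_i$ and $\hat{x}_i$ recovered from the stored output as $(y_i - \beta_i)/\gamma_i$:

```latex
\mathrm{rms} = \sqrt{\tfrac{1}{N}\sum_j x_j^2 + \epsilon}, \qquad
\hat{x}_i = \frac{x_i}{\mathrm{rms}}, \qquad
y_i = \gamma_i \hat{x}_i + \beta_i
\\[6pt]
\frac{\partial C}{\partial x_i} \mathrel{+}= \gamma_i\,
  \frac{N a_i - \hat{x}_i \sum_j \hat{x}_j a_j}{N \cdot \mathrm{rms}}, \qquad
\frac{\partial C}{\partial \gamma_i} \mathrel{+}= \sum_{\text{rows}} a_i \hat{x}_i, \qquad
\frac{\partial C}{\partial \beta_i} \mathrel{+}= \sum_{\text{rows}} a_i
```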
diff --git a/src/tensors/gpu/tensor_operators.cu b/src/tensors/gpu/tensor_operators.cu
index 97f0cdfe..d55214bc 100644
--- a/src/tensors/gpu/tensor_operators.cu
+++ b/src/tensors/gpu/tensor_operators.cu
@@ -2303,6 +2303,273 @@ void LayerNormalizationGrad(Ptr<Allocator> allocator,
allocator->free(tempOnesMemory);
}
+template <typename T, typename AccType = float>
+__global__ void gRMSNormalization(T* out,
+ const T* in,
+ const T* gamma,
+ const T* beta,
+ int rows,
+ int cols,
+ AccType eps = 1e-9) {
+ extern __shared__ uint8_t _sharedBytes[];
+ AccType* _shareAccType = (AccType*)_sharedBytes;
+
+ AccType N = cols;
+ for(int bid = 0; bid < rows; bid += gridDim.x) {
+ int j = bid + blockIdx.x;
+ if(j < rows) {
+ T* yRow = out + j * cols;
+ const T* xRow = in + j * cols;
+
+ AccType* _sqSum = _shareAccType;
+
+ _sqSum[threadIdx.x] = (AccType)0.0f;
+ for(int tid = 0; tid < cols; tid += blockDim.x) {
+ int id = tid + threadIdx.x;
+ if(id < cols) {
+ AccType xv = (AccType)xRow[id];
+ _sqSum[threadIdx.x] += xv * xv;
+ }
+ }
+ __syncthreads();
+ int len = blockDim.x;
+ while(len != 1) {
+ __syncthreads();
+ int skip = (len + 1) >> 1;
+ if(threadIdx.x < (len >> 1))
+ _sqSum[threadIdx.x] += _sqSum[threadIdx.x + skip];
+ len = (len + 1) >> 1;
+ }
+ __syncthreads();
+ AccType rms = functional::Ops<AccType>::sqrt(_sqSum[0] / N + eps); // all AccType
+ __syncthreads();
+
+ for(int tid = 0; tid < cols; tid += blockDim.x) {
+ int id = tid + threadIdx.x;
+ if(id < cols) {
+ AccType gammav = (AccType)gamma[id];
+ AccType xv = (AccType)xRow[id];
+ AccType betav = beta ? (AccType)beta[id] : (AccType)0.f;
+ AccType rmsNorm = xv / rms;
+ AccType y = gammav * rmsNorm + betav;
+ yRow[id] = (T)y;
+ }
+ }
+ }
+ __syncthreads();
+ }
+}
+
+void RMSNormalization(Tensor out,
+ Tensor in,
+ Tensor gamma,
+ Tensor beta,
+ float eps) {
+ cudaSetDevice(out->getDeviceId().no);
+
+ int rows = in->shape().elements() / in->shape().back();
+ int cols = in->shape().back();
+
+ int blocks = std::min(MAX_BLOCKS, (int)rows);
+ int threads = std::min(MAX_THREADS, (int)cols);
+ int shared = threads * sizeof(float);
+
+ if(out->type() == Type::float32) {
+ gRMSNormalization<float, float><<<blocks, threads, shared>>>(out->data<float>(),
+ in->data<float>(),
+ gamma->data<float>(),
+ beta ? beta->data<float>() : nullptr,
+ rows,
+ cols,
+ eps);
+#if COMPILE_FP16
+ } else if (out->type() == Type::float16) {
+ gRMSNormalization<half, float><<<blocks, threads, shared>>>(out->data<half>(),
+ in->data<half>(),
+ gamma->data<half>(),
+ beta ? beta->data<half>() : nullptr,
+ rows,
+ cols,
+ eps);
+#endif
+ } else {
+ ABORT("RMSNormalization not implemented for type {}", out->type());
+ }
+}
+
+template <typename T, typename AccType = float>
+__global__ void gRMSNormalizationGrad(T* gradX,
+ T* gradGamma,
+ T* adj,
+ T* y,
+ T* x,
+ T* gamma,
+ T* beta,
+ int rows,
+ int cols,
+ AccType eps = 1e-9) {
+ extern __shared__ uint8_t sharedBytes[];
+ AccType* shared = (AccType*)sharedBytes;
+
+ AccType N = cols;
+
+ for(int bid = 0; bid < rows; bid += gridDim.x) {
+ int j = bid + blockIdx.x;
+ if(j < rows) {
+ AccType* sum_adj_r = shared; // sum of incoming gradient times RMSNorm(x), recovered from the output value
+ AccType* sum_sqr = shared + blockDim.x; // sum of x^2
+
+ const T* xRow = x + j * cols;
+ const T* yRow = y + j * cols;
+ const T* adjRow = adj + j * cols;
+
+ sum_adj_r[threadIdx.x] = (AccType)0.0f;
+ sum_sqr[threadIdx.x] = (AccType)0.0f;
+
+ for(int tid = 0; tid < cols; tid += blockDim.x) {
+ int id = tid + threadIdx.x;
+ if(id < cols) {
+ AccType xv = xRow[id];
+ AccType yv = yRow[id];
+ AccType betav = beta ? (AccType)beta[id] : (AccType)0.f;
+ AccType gammav = (AccType)gamma[id];
+ AccType adjv = adjRow[id];
+ AccType rv = (yv - betav) / gammav; // go back to RMSNorm(x) from scaled and shifted version for accumulation
+
+ sum_adj_r[threadIdx.x] += adjv * rv;
+ sum_sqr[threadIdx.x] += xv * xv;
+ }
+ }
+ __syncthreads();
+ int len = blockDim.x;
+ while(len != 1) {
+ __syncthreads();
+ int skip = (len + 1) >> 1;
+ if(threadIdx.x < (len >> 1)) {
+ sum_adj_r[threadIdx.x] += sum_adj_r[threadIdx.x + skip]; // Accumulates in AccType
+ sum_sqr[threadIdx.x] += sum_sqr[threadIdx.x + skip]; // Accumulates in AccType
+ }
+ len = (len + 1) >> 1;
+ }
+
+ __syncthreads();
+ AccType rms = functional::Ops<AccType>::sqrt(sum_sqr[0] / N + eps);
+ __syncthreads();
+
+ // Jacobian of RMS norm
+ // J = [ \frac{1}{N * rms} (N\delta_{ij} - RN_i RN_j) ]_{ij}
+ // J * a = dC/dx_i = ( N a_i - RN_i \sum_j RN_j a_j ) / (N * rms)
+
+ for(int tid = 0; tid < cols; tid += blockDim.x) {
+ int id = tid + threadIdx.x;
+ if(id < cols) {
+
+ AccType xv = xRow[id];
+ AccType gammav = (AccType)gamma[id];
+ AccType adjv = adjRow[id];
+ AccType rmsNorm = xv / rms;
+
+ AccType gradNorm = N * adjv - rmsNorm * sum_adj_r[0];
+ gradNorm /= N * rms;
+
+ AccType gradXv = gammav * gradNorm;
+
+ // Keep the RMSNorm gradient between [-1000, 1000] for TensorOps; this is currently used for making values fit into fp16. This will also clip inf.
+ // @TODO: to be fixed and removed.
+ AccType sign = functional::Ops<AccType>::sgn(gradXv);
+ AccType cutoff = (AccType)1000.f; // @TODO: expose this somehow as an option? or better: make obsolete.
+ gradXv = functional::Ops<AccType>::abs(gradXv) > cutoff ? sign * cutoff : gradXv; // if gradXv is NaN the value returned is NaN too because NaN > value is false.
+
+ // @TODO: frankly, this is embarrassing and should rather be removed or made optional? It does help for low-precision computation though. Maybe turn into an option?
+ gradXv = isnan(gradXv) ? 0.f : gradXv; // turn NaN into 0.
+
+ T* gradXRow = gradX + j * cols;
+ gradXRow[id] += (T)(gradXv);
+
+ T* gradGammaRow = gradGamma + j * cols;
+ // assignment is correct here as this gets summed up
+ // in the next kernel via matrix product
+ gradGammaRow[id] = (T)(adjv * rmsNorm);
+ }
+ }
+ }
+ __syncthreads();
+ }
+}
+
+void RMSNormalizationGrad(Ptr<Allocator> allocator,
+ Tensor gradX,
+ Tensor gradGamma,
+ Tensor gradBeta,
+ Tensor adj,
+ Tensor y,
+ Tensor x,
+ Tensor gamma,
+ Tensor beta,
+ float eps) {
+ cudaSetDevice(adj->getDeviceId().no);
+ int rows = y->shape().elements() / y->shape()[-1];
+ int cols = y->shape()[-1];
+
+ int threads = std::min(MAX_THREADS, cols);
+ int blocks = std::min(MAX_BLOCKS, rows);
+
+ auto tempGradGammaMemory = allocator->alloc(adj->memory()->size());
+ Tensor tempGradGamma = TensorBase::New(tempGradGammaMemory, adj->shape(), adj->type(), adj->getBackend());
+ tempGradGamma->set(0.f);
+
+ auto tempOnesMemory = allocator->alloc(rows * sizeOf(adj->type()));
+ Tensor tempOnes = TensorBase::New(tempOnesMemory, Shape({1, rows}), adj->type(), adj->getBackend());
+ tempOnes->set(1.f);
+
+ if(gradX->type() == Type::float32) {
+ int shared = sizeof(float) * threads * 2;
+ gRMSNormalizationGrad<float, float><<<blocks, threads, shared>>>(
+ gradX->data<float>(),
+ tempGradGamma->data<float>(),
+ adj->data<float>(),
+ y->data<float>(),
+ x->data<float>(),
+ gamma->data<float>(),
+ (beta) ? beta->data<float>() : nullptr,
+ rows,
+ cols,
+ eps);
+#if COMPILE_FP16
+ } else if (gradX->type() == Type::float16) {
+ // accumulate in float
+ int shared = sizeof(float) * threads * 2;
+ gRMSNormalizationGrad<half, float><<<blocks, threads, shared>>>(
+ gradX->data<half>(),
+ tempGradGamma->data<half>(),
+ adj->data<half>(),
+ y->data<half>(),
+ x->data<half>(),
+ gamma->data<half>(),
+ (beta) ? beta->data<half>() : nullptr,
+ rows,
+ cols,
+ eps);
+#endif
+ } else {
+ ABORT("RMSNormalizationGrad not implemented for type {}", gradX->type());
+ }
+
+ // We use this to get rid of the atomicAdd and perform a reduce of the gradients afterwards.
+ // This is much faster for fp16 which seems to have a broken atomicAdd implementation.
+ // We reduce bias gradients with a matrix multiply, but use a 32-bit compute type.
+ // This preserves precision with larger batches where all batch entries reduce into a single vector.
+ // See also AffineNodeOp where we do the same for biases
+ gpu::Prod(gradGamma, tempOnes, tempGradGamma, false, false, 1, 1, Type::float32); // beta set to one to add
+
+ if(gradBeta) // dC/dbeta = adj - inverse broadcasting (reduction)
+ gpu::Prod(gradBeta, tempOnes, adj, false, false, 1, 1, Type::float32); // beta set to one to add
+
+ allocator->free(tempGradGammaMemory);
+ allocator->free(tempOnesMemory);
+}
+
+
template <bool add, typename T>
__global__ void gShift(T* out,
const T* in,
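As the comments above note, the per-row gamma gradients are staged in `tempGradGamma` and reduced with a ones vector via `gpu::Prod` rather than accumulated with atomicAdd. In matrix form, with $R$ rows, $G$ the staged matrix of per-row entries $a_i \hat{x}_i$ and $A$ the adjoint:

```latex
\frac{\partial C}{\partial \gamma} = \mathbf{1}_{1 \times R}\, G, \qquad
\frac{\partial C}{\partial \beta} = \mathbf{1}_{1 \times R}\, A
```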
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h
index af7946dd..ef485068 100644
--- a/src/tensors/tensor_operators.h
+++ b/src/tensors/tensor_operators.h
@@ -218,6 +218,55 @@ static inline void LayerNormalizationGrad(
cpu::LayerNormalizationGrad(gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps);
}
+// clang-format off
+DISPATCH5(RMSNormalization, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, float)
+
+#ifdef CUDA_FOUND
+namespace gpu {
+void RMSNormalizationGrad(Ptr<Allocator> allocator,
+ Tensor gradX,
+ Tensor gradGamma,
+ Tensor gradBeta,
+ Tensor adj,
+ Tensor y,
+ Tensor x,
+ Tensor gamma,
+ Tensor beta,
+ float eps);
+}
+#endif
+
+namespace cpu {
+void RMSNormalizationGrad(Tensor gradX,
+ Tensor gradGamma,
+ Tensor gradBeta,
+ Tensor adj,
+ Tensor y,
+ Tensor x,
+ Tensor gamma,
+ Tensor beta,
+ float eps);
+}
+
+static inline void RMSNormalizationGrad(
+ Ptr<Allocator> allocator,
+ Tensor gradX,
+ Tensor gradGamma,
+ Tensor gradBeta,
+ Tensor adj,
+ Tensor y,
+ Tensor x,
+ Tensor gamma,
+ Tensor beta,
+ float eps) {
+#ifdef CUDA_FOUND
+ if(gradX->getBackend()->getDeviceId().type == DeviceType::gpu)
+ gpu::RMSNormalizationGrad(allocator, gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps);
+ else
+#endif
+ cpu::RMSNormalizationGrad(gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps);
+}
+
DISPATCH4(HighwayForward, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor)
DISPATCH7(HighwayBackward, marian::Tensor, marian::Tensor, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor)
diff --git a/src/tests/units/operator_tests.cpp b/src/tests/units/operator_tests.cpp
index c3fd4a9e..1a18da99 100644
--- a/src/tests/units/operator_tests.cpp
+++ b/src/tests/units/operator_tests.cpp
@@ -300,6 +300,49 @@ void tests(DeviceType device, Type floatType = Type::float32) {
}
+ SECTION("RMS normalization") {
+ graph->clear();
+ values.clear();
+
+ std::vector<T> init = {
+ 2.88794374, 4.67853451, 3.96257305, 3.28433037,
+ 0.37778997, 0.67662024, 4.24959183, 1.23910618,
+ 0.68929380, 2.00369596, 4.38251686, 1.75624943,
+ 4.96126175, 3.01947117, 4.72057724, 2.23017120
+ };
+
+ auto a1 = graph->param("test1", {2, 2, 4}, inits::fromVector(init));
+ auto a2 = graph->param("test2", {2, 2, 4}, inits::fromVector(init));
+ auto gamma = graph->param("gamma", {1, 4}, inits::ones());
+
+ auto rms = rmsNorm(a1, gamma, nullptr, 1e-5f);
+ auto rms2 = gamma * (a2 / sqrt(mean(a2 * a2, /*axis=*/-1) + 1e-5f));
+
+ auto top = sum(flatten(rms + rms2));
+
+ graph->forward();
+ graph->backward();
+
+ CHECK(rms->shape() == Shape({2, 2, 4}));
+
+ std::vector<T> values2;
+
+ // compare values of rms and rms2 to make sure forward computation is correct
+ rms->val()->get(values);
+ rms2->val()->get(values2);
+
+ CHECK( std::equal(values.begin(), values.end(),
+ values2.begin(), floatApprox) );
+
+ // compare adjoints of a1 and a2 (parameters) to make sure gradient computation is correct
+ a1->grad()->get(values);
+ a2->grad()->get(values2);
+
+ CHECK( std::equal(values.begin(), values.end(),
+ values2.begin(), floatApprox) );
+
+ }
+
SECTION("reductions") {
graph->clear();
values.clear();