From caddad90cdb06283da2f5d17a3340ca8c6387b38 Mon Sep 17 00:00:00 2001 From: Martin Junczys-Dowmunt Date: Sat, 10 Apr 2021 15:28:38 +0000 Subject: Merged PR 18505: RMSNorm on GPU Support for RMSNorm as drop-in replace for LayerNorm from _Biao Zhang; Rico Sennrich (2019). Root Mean Square Layer Normalization_. Enabled in Transformer model via `--transformer-postprocess dar` instead of `dan`. --- src/graph/expression_graph.cpp | 11 +- src/graph/expression_operators.cpp | 12 ++ src/graph/expression_operators.h | 12 ++ src/graph/node_operators_binary.h | 58 ++++++++ src/layers/generic.h | 6 + src/models/transformer.h | 4 + src/tensors/cpu/tensor_operators.cpp | 196 ++++++++++++++++++++++--- src/tensors/gpu/tensor_operators.cu | 267 +++++++++++++++++++++++++++++++++++ src/tensors/tensor_operators.h | 49 +++++++ src/tests/units/operator_tests.cpp | 43 ++++++ 10 files changed, 636 insertions(+), 22 deletions(-) (limited to 'src') diff --git a/src/graph/expression_graph.cpp b/src/graph/expression_graph.cpp index 827fb3ed..12a1195e 100644 --- a/src/graph/expression_graph.cpp +++ b/src/graph/expression_graph.cpp @@ -208,8 +208,15 @@ void ExpressionGraph::backward(bool reset, float clipValue) { } if(v->trainable() && v->marked_for_debug()) { - LOG(info, "Debug Grad: {} op={}", v->debug_message(), v->type()); - LOG(info, v->grad()->debug()); + Logger log = spdlog::get("general"); + if(log) { + LOG(info, "Debug Grad: {} op={}", v->debug_message(), v->type()); + LOG(info, v->grad()->debug()); + } + else { + std::cerr << "Debug Grad: " << v->debug_message() << " op=" << v->type() << std::endl; + std::cerr << v->grad()->debug() << std::endl; + } } if(v->trainable() && clipValue != 0) { diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp index 048c7478..6c7ef91c 100644 --- a/src/graph/expression_operators.cpp +++ b/src/graph/expression_operators.cpp @@ -749,6 +749,18 @@ Expr layerNorm(Expr x, return Expression(nodes, eps); } +Expr rmsNorm(Expr x, + Expr gamma, + Expr beta /*= nullptr*/, + float eps /*= 1e-9*/) { + + // layerNorm accumulates in float, so small eps is fine + std::vector nodes = {x, gamma}; + if(beta) + nodes.push_back(beta); + return Expression(nodes, eps); +} + Expr highway(Expr y, Expr x, Expr t) { std::vector nodes = {y, x, t}; return Expression(nodes); diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h index 81b0f5ea..f3d84eb6 100644 --- a/src/graph/expression_operators.h +++ b/src/graph/expression_operators.h @@ -915,6 +915,18 @@ Expr weighted_average(Expr in, Expr weights, int ax = 0); */ Expr layerNorm(Expr x, Expr gamma, Expr beta = nullptr, float eps = 1e-9); +/** + * Applies RMS normalization over the last dimension. + * + * See: Biao Zhang; Rico Sennrich (2019). Root Mean Square Layer Normalization. + * In Advances in Neural Information Processing Systems 32. Vancouver, Canada. + * @f[ + \frac{x}{\sqrt{\frac{1}{N}\sum x^2 + \mathrm{eps}}} \times \gamma + \beta + * @f] + * @see RMSNormalizationOp + */ +Expr rmsNorm(Expr x, Expr gamma, Expr beta = nullptr, float eps = 1e-9); + /** * Highway transformation. 
* Computes the highway tranform on @p y and @p x as gated by @p t: diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h index 55f105a9..91fc29da 100644 --- a/src/graph/node_operators_binary.h +++ b/src/graph/node_operators_binary.h @@ -1369,6 +1369,64 @@ private: float eps_; }; +// RMS norm along last axis +struct RMSNormalizationOp : public NaryNodeOp { +public: + RMSNormalizationOp(const std::vector& nodes, float eps = 1e-9) + : NaryNodeOp(nodes), eps_(eps) { + // @TODO: dimension check + } + + NodeOps forwardOps() override { + return {NodeOp( + RMSNormalization(val_, + child(0)->val(), + child(1)->val(), + (children_.size() == 3) ? child(2)->val() : nullptr, + eps_))}; + } + + // @BUGBUG: backward has not been tested for broadcasting gamma/beta + NodeOps backwardOps() override { + return {NodeOp( + RMSNormalizationGrad( + graph()->allocator(), + child(0)->grad(), + child(1)->grad(), + (children_.size() == 3) ? child(2)->grad() : nullptr, + adj_, + val_, + child(0)->val(), + child(1)->val(), + (children_.size() == 3) ? child(2)->val() : nullptr, + eps_))}; + } + + const std::string type() override { return "rms_normalization"; } + + virtual size_t hash() override { + size_t seed = NaryNodeOp::hash(); + util::hash_combine(seed, eps_); + return seed; + } + + virtual bool equal(Expr node) override { + if(!NaryNodeOp::equal(node)) + return false; + auto cnode = std::dynamic_pointer_cast(node); + if(!cnode) + return false; + if(eps_ != cnode->eps_) + return false; + return true; + } + +private: + friend class SerializationHelpers; // @TODO: use the same name for this as SqrtNodeOp + float eps_; +}; + + struct HighwayNodeOp : public NaryNodeOp { HighwayNodeOp(const std::vector& nodes) : NaryNodeOp(nodes) {} diff --git a/src/layers/generic.h b/src/layers/generic.h index 5eb93615..2746bc85 100644 --- a/src/layers/generic.h +++ b/src/layers/generic.h @@ -212,4 +212,10 @@ static inline Expr layerNorm(Expr x, std::string prefix, std::string suffix = st return marian::layerNorm(x, scale, bias, 1e-6f); } +static inline Expr rmsNorm(Expr x, std::string prefix, std::string suffix = std::string()) { + int dimModel = x->shape()[-1]; + auto scale = x->graph()->param(prefix + "_rms_scale" + suffix, {1, dimModel}, inits::ones()); + return marian::rmsNorm(x, scale, nullptr, 1e-6f); +} + } // namespace marian diff --git a/src/models/transformer.h b/src/models/transformer.h index 79b59000..1da02318 100644 --- a/src/models/transformer.h +++ b/src/models/transformer.h @@ -176,6 +176,8 @@ public: // layer normalization else if (op == 'n') output = layerNorm(output, prefix, "_pre"); + else if (op == 'r') + output = rmsNorm(output, prefix, "_pre"); else ABORT("Unknown pre-processing operation '{}'", op); } @@ -201,6 +203,8 @@ public: // layer normalization else if(op == 'n') output = layerNorm(output, prefix); + else if(op == 'r') + output = rmsNorm(output, prefix); else ABORT("Unknown pre-processing operation '{}'", op); } diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp index 1191a2be..67d993fc 100755 --- a/src/tensors/cpu/tensor_operators.cpp +++ b/src/tensors/cpu/tensor_operators.cpp @@ -977,7 +977,7 @@ float L2Norm(Tensor in, Ptr /*not used*/) { float sum = 0.f; size_t size = in->size(); const float* data = in->data(); -#pragma omp parallel for simd reduction(+ : sum) + #pragma omp parallel for simd reduction(+ : sum) for(size_t i = 0; i < size; ++i) { sum += data[i] * data[i]; } @@ -998,14 +998,14 @@ void Att(Tensor out_, Tensor va_, Tensor 
context_, Tensor state_) { int rows = m; int cols = k; -#pragma omp parallel for + #pragma omp parallel for for(int j = 0; j < rows; ++j) { const float* vaRow = va; const float* ctxRow = ctx + (j % (b * t)) * cols; const float* stateRow = state + ((j / (b * t)) * b + j % b) * cols; float sum = 0.f; -#pragma omp simd reduction(+ : sum) + #pragma omp simd reduction(+ : sum) for(int i = 0; i < cols; ++i) { float z = ctxRow[i] + stateRow[i]; sum += std::tanh(z) * vaRow[i]; @@ -1035,7 +1035,7 @@ void AttBack(Tensor gVa_, size_t k = context_->shape()[-1]; size_t n = context_->shape()[-2]; -#pragma omp parallel for reduction(+ : gState[:n * k], gVa[:k]) + #pragma omp parallel for reduction(+ : gState[:n * k], gVa[:k]) for(size_t j = 0; j < m; ++j) { float* gcRow = gContext + j * k; float* gsRow = gState + (j % n) * k; @@ -1045,7 +1045,7 @@ void AttBack(Tensor gVa_, float adj_j = adj[j]; -#pragma omp simd + #pragma omp simd for(size_t i = 0; i < k; ++i) { float z = cRow[i] + sRow[i]; @@ -1070,20 +1070,20 @@ void LayerNormalizationImpl(float* out, float eps, int rows, int cols) { -#pragma omp parallel for + #pragma omp parallel for for(int j = 0; j < rows; ++j) { float* so = out + j * cols; const float* sp = in + j * cols; float sum = 0.f; -#pragma omp simd reduction(+ : sum) + #pragma omp simd reduction(+ : sum) for(int i = 0; i < cols; ++i) { sum += sp[i]; } float mean = sum / cols; float sqSum = 0.f; -#pragma omp simd reduction(+ : sqSum) + #pragma omp simd reduction(+ : sqSum) for(int i = 0; i < cols; ++i) { float ex = sp[i] - mean; sqSum += ex * ex; @@ -1091,7 +1091,7 @@ void LayerNormalizationImpl(float* out, float sigma = std::sqrt(sqSum / cols + eps); -#pragma omp simd + #pragma omp simd for(int i = 0; i < cols; ++i) { float t = alpha[alphaStride * i] * ((sp[i] - mean) / sigma); if(hasBeta) @@ -1168,7 +1168,7 @@ void LayerNormalizationGrad(Tensor gradX_, size_t cols = y_->shape()[-1]; if(beta) { -#pragma omp parallel for reduction(+ : gradGamma[:cols], gradBeta[:cols]) + #pragma omp parallel for reduction(+ : gradGamma[:cols], gradBeta[:cols]) for(size_t j = 0; j < rows; ++j) { const float* xRow = x + j * cols; const float* yRow = y + j * cols; @@ -1180,7 +1180,7 @@ void LayerNormalizationGrad(Tensor gradX_, float sum_adj_x = 0.f; float sum_sqr = 0.f; -#pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj) + #pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj) for(size_t i = 0; i < cols; ++i) { sum_x += xRow[i]; sum_adj_x += adjRow[i] * (yRow[i] - (beta ? beta[betaStride * i] : 0.f)) / gamma[gammaStride * i]; @@ -1188,14 +1188,14 @@ void LayerNormalizationGrad(Tensor gradX_, } float mean = sum_x / cols; -#pragma omp simd reduction(+ : sum_sqr) + #pragma omp simd reduction(+ : sum_sqr) for(size_t i = 0; i < cols; ++i) { float ex = xRow[i] - mean; sum_sqr += ex * ex; } float sigma = std::sqrt(sum_sqr / cols + eps); -#pragma omp simd + #pragma omp simd for(size_t i = 0; i < cols; ++i) { float grad_x = 0.f; float x_hat = (yRow[i] - beta[betaStride * i]) / gamma[gammaStride * i]; @@ -1209,8 +1209,8 @@ void LayerNormalizationGrad(Tensor gradX_, gradBeta[betaStride * i] += adjRow[i]; } } - } else { -#pragma omp parallel for reduction(+ : gradGamma[:cols]) + } else { // @TODO: this code duplication is really ugly, but required for omp to work correctly? 
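+      // (most likely the duplication is needed because the omp reduction clause below
+      //  cannot list gradBeta[:cols] when gradBeta is null in the beta-less case)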
+ #pragma omp parallel for reduction(+ : gradGamma[:cols]) for(size_t j = 0; j < rows; ++j) { const float* xRow = x + j * cols; const float* yRow = y + j * cols; @@ -1222,23 +1222,22 @@ void LayerNormalizationGrad(Tensor gradX_, float sum_adj_x = 0.f; float sum_sqr = 0.f; -#pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj) + #pragma omp simd reduction(+ : sum_x, sum_adj_x, sum_adj) for(size_t i = 0; i < cols; ++i) { sum_x += xRow[i]; - sum_adj_x += adjRow[i] * (yRow[i] - (beta ? beta[betaStride * i] : 0.f)) / gamma[gammaStride * i]; - // @TODO: beta is NULL here ^^ + sum_adj_x += adjRow[i] * yRow[i] / gamma[gammaStride * i]; sum_adj += adjRow[i]; } float mean = sum_x / cols; -#pragma omp simd reduction(+ : sum_sqr) + #pragma omp simd reduction(+ : sum_sqr) for(size_t i = 0; i < cols; ++i) { float ex = xRow[i] - mean; sum_sqr += ex * ex; } float sigma = std::sqrt(sum_sqr / cols + eps); -#pragma omp simd + #pragma omp simd for(size_t i = 0; i < cols; ++i) { float grad_x = 0.f; float x_hat = yRow[i] / gamma[gammaStride * i]; @@ -1255,6 +1254,163 @@ void LayerNormalizationGrad(Tensor gradX_, } MARIAN_FFAST_MATH_END +MARIAN_FFAST_MATH_BEGIN +template +void RMSNormalizationImpl(float* out, + const float* in, + const float* alpha, + const float* beta, + float eps, + int rows, + int cols) { + #pragma omp parallel for + for(int j = 0; j < rows; ++j) { + float* so = out + j * cols; + const float* sp = in + j * cols; + + float sqSum = 0.f; + #pragma omp simd reduction(+ : sqSum) + for(int i = 0; i < cols; ++i) { + sqSum += sp[i] * sp[i]; + } + + float rms = std::sqrt(sqSum / cols + eps); + + #pragma omp simd + for(int i = 0; i < cols; ++i) { + float t = alpha[alphaStride * i] * (sp[i] / rms); + if(hasBeta) + t += beta[betaStride * i]; + + so[i] = t; + } + } +} +MARIAN_FFAST_MATH_END + +template +inline void RMSNormalizationDispatchBeta(float* out, + const float* in, + const float* alpha, + Tensor beta, + float eps, + int rows, + int cols) { + if (beta) { + if (beta->shape().back() > 1) { + RMSNormalizationImpl(out, in, alpha, beta->data(), eps, rows, cols); + } else { + RMSNormalizationImpl(out, in, alpha, beta->data(), eps, rows, cols); + } + } else { + RMSNormalizationImpl(out, in, alpha, nullptr, eps, rows, cols); + } +} + +void RMSNormalization(Tensor out, + Tensor in, + Tensor gamma, + Tensor beta, + float eps) { + const float* alpha = gamma->data(); + const int alphaStride = gamma->shape().back() > 1; // broadcasting for alpha and beta + + int rows = in->shape().elements() / in->shape().back(); + int cols = in->shape().back(); + if (alphaStride == 0) { + RMSNormalizationDispatchBeta<0>(out->data(), in->data(), alpha, beta, eps, rows, cols); + } else { + RMSNormalizationDispatchBeta<1>(out->data(), in->data(), alpha, beta, eps, rows, cols); + } +} + +MARIAN_FFAST_MATH_BEGIN +void RMSNormalizationGrad(Tensor gradX_, + Tensor gradGamma_, + Tensor gradBeta_, + Tensor adj_, + Tensor y_, + Tensor x_, + Tensor gamma_, + Tensor beta_, + float eps) { + float* gradX = gradX_->data(); + float* gradGamma = gradGamma_->data(); + float* gradBeta = gradBeta_ ? gradBeta_->data() : nullptr; + float* adj = adj_->data(); + float* x = x_->data(); + float* y = y_->data(); + float* gamma = gamma_->data(); + float* beta = beta_ ? beta_->data() : nullptr; + // @TODO: The CPU implementation supports scalar gamma and beta. This is a left-over, + // we should enable that in the GPU version as well. + const int gammaStride = gamma_->shape().back() > 1; // broadcasting for alpha and beta. 
0 means it's a scalar + const int betaStride = beta_ && beta_->shape().back() > 1; + + size_t rows = y_->shape().elements() / y_->shape()[-1]; + size_t cols = y_->shape()[-1]; + + if(beta) { + #pragma omp parallel for reduction(+ : gradGamma[:cols], gradBeta[:cols]) + for(size_t j = 0; j < rows; ++j) { + const float* xRow = x + j * cols; + const float* yRow = y + j * cols; + const float* adjRow = adj + j * cols; + float* gradXRow = gradX + j * cols; + + float sum_adj_r = 0.f; + float sum_sqr = 0.f; + + #pragma omp simd reduction(+ : sum_adj_r, sum_sqr) + for(size_t i = 0; i < cols; ++i) { + sum_adj_r += adjRow[i] * (yRow[i] - beta[betaStride * i]) / gamma[gammaStride * i]; + sum_sqr += xRow[i] * xRow[i]; + } + + float rms = std::sqrt(sum_sqr / cols + eps); + #pragma omp simd + for(size_t i = 0; i < cols; ++i) { + float rmsNorm = (yRow[i] - beta[betaStride * i]) / gamma[gammaStride * i]; + float gradNorm = cols * adjRow[i] - rmsNorm * sum_adj_r; + gradNorm /= cols * rms; + + gradXRow[i] += gamma[gammaStride * i] * gradNorm; + gradGamma[gammaStride * i] += adjRow[i] * rmsNorm; + gradBeta[betaStride * i] += adjRow[i]; + } + } + } else { + #pragma omp parallel for reduction(+ : gradGamma[:cols]) + for(size_t j = 0; j < rows; ++j) { + const float* xRow = x + j * cols; + const float* yRow = y + j * cols; + const float* adjRow = adj + j * cols; + float* gradXRow = gradX + j * cols; + + float sum_adj_r = 0.f; + float sum_sqr = 0.f; + + #pragma omp simd reduction(+ : sum_adj_r, sum_sqr) + for(size_t i = 0; i < cols; ++i) { + sum_adj_r += yRow[i] / gamma[gammaStride * i]; + sum_sqr += xRow[i] * xRow[i]; + } + + float rms = std::sqrt(sum_sqr / cols + eps); + #pragma omp simd + for(size_t i = 0; i < cols; ++i) { + float rmsNorm = yRow[i] / gamma[gammaStride * i]; + float gradNorm = cols * adjRow[i] - rmsNorm * sum_adj_r; + gradNorm /= cols * rms; + + gradXRow[i] += gamma[gammaStride * i] * gradNorm; + gradGamma[gammaStride * i] += adjRow[i] * rmsNorm; + } + } + } +} +MARIAN_FFAST_MATH_END + void Shift(Tensor out_, Tensor in_, marian::Shape shift, diff --git a/src/tensors/gpu/tensor_operators.cu b/src/tensors/gpu/tensor_operators.cu index 97f0cdfe..d55214bc 100644 --- a/src/tensors/gpu/tensor_operators.cu +++ b/src/tensors/gpu/tensor_operators.cu @@ -2303,6 +2303,273 @@ void LayerNormalizationGrad(Ptr allocator, allocator->free(tempOnesMemory); } +template +__global__ void gRMSNormalization(T* out, + const T* in, + const T* gamma, + const T* beta, + int rows, + int cols, + AccType eps = 1e-9) { + extern __shared__ uint8_t _sharedBytes[]; + AccType* _shareAccType = (AccType*)_sharedBytes; + + AccType N = cols; + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + T* yRow = out + j * cols; + const T* xRow = in + j * cols; + + AccType* _sqSum = _shareAccType; + + _sqSum[threadIdx.x] = (AccType)0.0f; + for(int tid = 0; tid < cols; tid += blockDim.x) { + int id = tid + threadIdx.x; + if(id < cols) { + AccType xv = (AccType)xRow[id]; + _sqSum[threadIdx.x] += xv * xv; + } + } + __syncthreads(); + int len = blockDim.x; + while(len != 1) { + __syncthreads(); + int skip = (len + 1) >> 1; + if(threadIdx.x < (len >> 1)) + _sqSum[threadIdx.x] += _sqSum[threadIdx.x + skip]; + len = (len + 1) >> 1; + } + __syncthreads(); + AccType rms = functional::Ops::sqrt(_sqSum[0] / N + eps); // all AccType + __syncthreads(); + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int id = tid + threadIdx.x; + if(id < cols) { + AccType gammav = (AccType)gamma[id]; + AccType xv = 
(AccType)xRow[id]; + AccType betav = beta ? (AccType)beta[id] : (AccType)0.f; + AccType rmsNorm = xv / rms; + AccType y = gammav * rmsNorm + betav; + yRow[id] = (T)y; + } + } + } + __syncthreads(); + } +} + +void RMSNormalization(Tensor out, + Tensor in, + Tensor gamma, + Tensor beta, + float eps) { + cudaSetDevice(out->getDeviceId().no); + + int rows = in->shape().elements() / in->shape().back(); + int cols = in->shape().back(); + + int blocks = std::min(MAX_BLOCKS, (int)rows); + int threads = std::min(MAX_THREADS, (int)cols); + int shared = threads * sizeof(float); + + if(out->type() == Type::float32) { + gRMSNormalization<<>>(out->data(), + in->data(), + gamma->data(), + beta ? beta->data() : nullptr, + rows, + cols, + eps); +#if COMPILE_FP16 + } else if (out->type() == Type::float16) { + gRMSNormalization<<>>(out->data(), + in->data(), + gamma->data(), + beta ? beta->data() : nullptr, + rows, + cols, + eps); +#endif + } else { + ABORT("RMSNormalization not implemented for type {}", out->type()); + } +} + +template +__global__ void gRMSNormalizationGrad(T* gradX, + T* gradGamma, + T* adj, + T* y, + T* x, + T* gamma, + T* beta, + int rows, + int cols, + AccType eps = 1e-9) { + extern __shared__ uint8_t sharedBytes[]; + AccType* shared = (AccType*)sharedBytes; + + AccType N = cols; + + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + AccType* sum_adj_r = shared; // sum of gradient coming in times layerNorm from value + AccType* sum_sqr = shared + blockDim.x; // sum of x^2 + + const T* xRow = x + j * cols; + const T* yRow = y + j * cols; + const T* adjRow = adj + j * cols; + + sum_adj_r[threadIdx.x] = (AccType)0.0f; + sum_sqr[threadIdx.x] = (AccType)0.0f; + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int id = tid + threadIdx.x; + if(id < cols) { + AccType xv = xRow[id]; + AccType yv = yRow[id]; + AccType betav = beta ? (AccType)beta[id] : (AccType)0.f; + AccType gammav = (AccType)gamma[id]; + AccType adjv = adjRow[id]; + AccType rv = (yv - betav) / gammav; // go back to RMSNorm(x) from scaled and shifted version for accumulation + + sum_adj_r[threadIdx.x] += adjv * rv; + sum_sqr[threadIdx.x] += xv * xv; + } + } + __syncthreads(); + int len = blockDim.x; + while(len != 1) { + __syncthreads(); + int skip = (len + 1) >> 1; + if(threadIdx.x < (len >> 1)) { + sum_adj_r[threadIdx.x] += sum_adj_r[threadIdx.x + skip]; // Accumulates in AccType + sum_sqr[threadIdx.x] += sum_sqr[threadIdx.x + skip]; // Accumulates in AccType + } + len = (len + 1) >> 1; + } + + __syncthreads(); + AccType rms = functional::Ops::sqrt(sum_sqr[0] / N + eps); + __syncthreads(); + + // Jacobian of RMS norm + // J = [ \frac{1}{N * rms} (N\delta_{ij} - RN_i RN_j) ]_{ij} + // J * a = dC/dx_i = ( N a_i - RN_i \sum_j RN_j a_j ) / (N * rms) + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int id = tid + threadIdx.x; + if(id < cols) { + + AccType xv = xRow[id]; + AccType gammav = (AccType)gamma[id]; + AccType adjv = adjRow[id]; + AccType rmsNorm = xv / rms; + + AccType gradNorm = N * adjv - rmsNorm * sum_adj_r[0]; + gradNorm /= N * rms; + + AccType gradXv = gammav * gradNorm; + + // Keep RMSN gradient between [-1000, 1000] for TensorOps, this currently used for making values fit into fp16. This wil also clip inf. + // @TODO: to be fixed and removed. + AccType sign = functional::Ops::sgn(gradXv); + AccType cutoff = (AccType)1000.f; // @TODO: expose this somehow as an option? or better: make obsolete. + gradXv = functional::Ops::abs(gradXv) > cutoff ? 
sign * cutoff : gradXv; // if gradXv is NaN the value return is NaN too because NaN > value is false. + + // @TODO: frankly, this is embarrasing and should rather be removed or optional? It does help for low precision computation though. Maybe turn into option? + gradXv = isnan(gradXv) ? 0.f : gradXv; // turn NaN into 0. + + T* gradXRow = gradX + j * cols; + gradXRow[id] += (T)(gradXv); + + T* gradGammaRow = gradGamma + j * cols; + // assignment is correct here as this gets summed up + // in the next kernel via matrix product + gradGammaRow[id] = (T)(adjv * rmsNorm); + } + } + } + __syncthreads(); + } +} + +void RMSNormalizationGrad(Ptr allocator, + Tensor gradX, + Tensor gradGamma, + Tensor gradBeta, + Tensor adj, + Tensor y, + Tensor x, + Tensor gamma, + Tensor beta, + float eps) { + cudaSetDevice(adj->getDeviceId().no); + int rows = y->shape().elements() / y->shape()[-1]; + int cols = y->shape()[-1]; + + int threads = std::min(MAX_THREADS, cols); + int blocks = std::min(MAX_BLOCKS, rows); + + auto tempGradGammaMemory = allocator->alloc(adj->memory()->size()); + Tensor tempGradGamma = TensorBase::New(tempGradGammaMemory, adj->shape(), adj->type(), adj->getBackend()); + tempGradGamma->set(0.f); + + auto tempOnesMemory = allocator->alloc(rows * sizeOf(adj->type())); + Tensor tempOnes = TensorBase::New(tempOnesMemory, Shape({1, rows}), adj->type(), adj->getBackend()); + tempOnes->set(1.f); + + if(gradX->type() == Type::float32) { + int shared = sizeof(float) * threads * 2; + gRMSNormalizationGrad<<>>( + gradX->data(), + tempGradGamma->data(), + adj->data(), + y->data(), + x->data(), + gamma->data(), + (beta) ? beta->data() : nullptr, + rows, + cols, + eps); +#if COMPILE_FP16 + } else if (gradX->type() == Type::float16) { + // accumulate in float + int shared = sizeof(float) * threads * 2; + gRMSNormalizationGrad<<>>( + gradX->data(), + tempGradGamma->data(), + adj->data(), + y->data(), + x->data(), + gamma->data(), + (beta) ? beta->data() : nullptr, + rows, + cols, + eps); +#endif + } else { + ABORT("RMSNormalizationGrad not implemented for type {}", gradX->type()); + } + + // We use this go get rid of the atomicAdd and perform a reduce of the gradients afterwards. + // This is much faster for fp16 which seems to have a broken atomicAdd implementation. + // We reduce bias gradients with a matrix multiply, but use a 32-bit compute type. + // This preserves precision with larger batches where all batch entries reduce into a single vector. 
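+  // Concretely, gradGamma is accumulated as ones(1, rows) x tempGradGamma(rows, cols),
+  // i.e. a column-wise sum of the per-row gamma-gradient contributions written by the kernel above.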
+ // See also AffineNodeOp where we do the same for biases + gpu::Prod(gradGamma, tempOnes, tempGradGamma, false, false, 1, 1, Type::float32); // beta set to one to add + + if(gradBeta) // dC/dbeta = adj - inverse broadcasting (reduction) + gpu::Prod(gradBeta, tempOnes, adj, false, false, 1, 1, Type::float32); // beta set to one to add + + allocator->free(tempGradGammaMemory); + allocator->free(tempOnesMemory); +} + + template __global__ void gShift(T* out, const T* in, diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h index af7946dd..ef485068 100644 --- a/src/tensors/tensor_operators.h +++ b/src/tensors/tensor_operators.h @@ -218,6 +218,55 @@ static inline void LayerNormalizationGrad( cpu::LayerNormalizationGrad(gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps); } +// clang-format off +DISPATCH5(RMSNormalization, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, float) + +#ifdef CUDA_FOUND +namespace gpu { +void RMSNormalizationGrad(Ptr allocator, + Tensor gradX, + Tensor gradGamma, + Tensor gradBeta, + Tensor adj, + Tensor y, + Tensor x, + Tensor gamma, + Tensor beta, + float eps); +} +#endif + +namespace cpu { +void RMSNormalizationGrad(Tensor gradX, + Tensor gradGamma, + Tensor gradBeta, + Tensor adj, + Tensor y, + Tensor x, + Tensor gamma, + Tensor beta, + float eps); +} + +static inline void RMSNormalizationGrad( + Ptr allocator, + Tensor gradX, + Tensor gradGamma, + Tensor gradBeta, + Tensor adj, + Tensor y, + Tensor x, + Tensor gamma, + Tensor beta, + float eps) { +#ifdef CUDA_FOUND + if(gradX->getBackend()->getDeviceId().type == DeviceType::gpu) + gpu::RMSNormalizationGrad(allocator, gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps); + else +#endif + cpu::RMSNormalizationGrad(gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps); +} + DISPATCH4(HighwayForward, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor) DISPATCH7(HighwayBackward, marian::Tensor, marian::Tensor, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor) diff --git a/src/tests/units/operator_tests.cpp b/src/tests/units/operator_tests.cpp index c3fd4a9e..1a18da99 100644 --- a/src/tests/units/operator_tests.cpp +++ b/src/tests/units/operator_tests.cpp @@ -300,6 +300,49 @@ void tests(DeviceType device, Type floatType = Type::float32) { } + SECTION("RMS normalization") { + graph->clear(); + values.clear(); + + std::vector init = { + 2.88794374, 4.67853451, 3.96257305, 3.28433037, + 0.37778997, 0.67662024, 4.24959183, 1.23910618, + 0.68929380, 2.00369596, 4.38251686, 1.75624943, + 4.96126175, 3.01947117, 4.72057724, 2.23017120 + }; + + auto a1 = graph->param("test1", {2, 2, 4}, inits::fromVector(init)); + auto a2 = graph->param("test2", {2, 2, 4}, inits::fromVector(init)); + auto gamma = graph->param("gamma", {1, 4}, inits::ones()); + + auto rms = rmsNorm(a1, gamma, nullptr, 1e-5f); + auto rms2 = gamma * (a2 / sqrt(mean(a2 * a2, /*axis=*/-1) + 1e-5f)); + + auto top = sum(flatten(rms + rms2)); + + graph->forward(); + graph->backward(); + + CHECK(rms->shape() == Shape({2, 2, 4})); + + std::vector values2; + + // compare values of rms and rms2 to make sure forward computation is correct + rms->val()->get(values); + rms2->val()->get(values2); + + CHECK( std::equal(values.begin(), values.end(), + values2.begin(), floatApprox) ); + + // compare adjoints of a1 and a2 (parameters) to makes sure gradient computation is correct + a1->grad()->get(values); + a2->grad()->get(values2); + + CHECK( 
std::equal(values.begin(), values.end(), + values2.begin(), floatApprox) ); + + } + SECTION("reductions") { graph->clear(); values.clear(); -- cgit v1.2.3
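For quick reference, the forward rule implemented by the new RMSNormalization kernels (see the doc comment added to expression_operators.h) is y = gamma * x / sqrt(mean(x^2) + eps) + beta, taken over the last dimension. Below is a minimal standalone C++ sketch of that rule, assuming row-major (rows x cols) data in plain std::vector rather than Marian tensors; the function name and interface are illustrative only and not part of the patch.

#include <cmath>
#include <cstddef>
#include <vector>

// Reference RMSNorm over each row of a row-major (rows x cols) matrix.
// gamma and beta hold `cols` elements; pass an empty beta to skip the shift,
// mirroring the optional beta child of RMSNormalizationOp.
std::vector<float> rmsNormRef(const std::vector<float>& x,
                              const std::vector<float>& gamma,
                              const std::vector<float>& beta,
                              std::size_t rows, std::size_t cols,
                              float eps = 1e-6f) {
  std::vector<float> y(x.size());
  for(std::size_t j = 0; j < rows; ++j) {
    const float* xRow = x.data() + j * cols;
    float* yRow = y.data() + j * cols;
    float sqSum = 0.f;
    for(std::size_t i = 0; i < cols; ++i)
      sqSum += xRow[i] * xRow[i];
    float rms = std::sqrt(sqSum / cols + eps); // root mean square of the row
    for(std::size_t i = 0; i < cols; ++i)
      yRow[i] = gamma[i] * (xRow[i] / rms) + (beta.empty() ? 0.f : beta[i]);
  }
  return y;
}

On the backward side, gRMSNormalizationGrad computes gradX_i += gamma_i * (N * adj_i - RN_i * sum_j RN_j * adj_j) / (N * rms) with RN_i = x_i / rms, following the Jacobian comment in the kernel, and the per-row gamma/beta contributions are then reduced with the ones-vector matrix products rather than atomicAdd. As stated in the commit message, the layer is selected per position of the Transformer pre-/post-processing string by using 'r' where 'n' would select LayerNorm, e.g. `--transformer-postprocess dar` instead of `dan`.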