diff options
Diffstat (limited to 'src/tensors/gpu/tensor_operators.cu')
-rw-r--r-- | src/tensors/gpu/tensor_operators.cu | 147 |
1 files changed, 110 insertions, 37 deletions
diff --git a/src/tensors/gpu/tensor_operators.cu b/src/tensors/gpu/tensor_operators.cu index d55214bc..1347c3bb 100644 --- a/src/tensors/gpu/tensor_operators.cu +++ b/src/tensors/gpu/tensor_operators.cu @@ -16,15 +16,12 @@ namespace gpu { namespace atomics { static inline __device__ void atomicAdd(float *address, float val) { - //*address += val; ::atomicAdd(address, val); } #if COMPILE_FP16 // @TODO: copied from CuTorch, adapt this better, give credit. static inline __device__ void atomicAdd(half *address, half val) { - //*address += val; - #if __CUDA_ARCH__ >= 700 && CUDA_VERSION >= 10000 // compute capability 70 and higher with CUDA 10 ::atomicAdd(address, val); #else // __CUDA_ARCH__ < 700 @@ -50,7 +47,8 @@ static inline __device__ void atomicAdd(half *address, half val) { } while (assumed != old); #endif // __CUDA_ARCH__ } -#endif +#endif // COMPILE_FP16 + } @@ -96,6 +94,81 @@ void IsNaN(const Tensor in, Ptr<Allocator> allocator, bool& isNaN, bool& isInf) cudaStreamSynchronize(0); } +template <typename T> +__global__ void gSanitizeGradient(T* in, int length, + bool* isNaN, bool* isInf, + bool pruneNaN, bool clipInf, + float forNaN = 0.f, float forInf = 65504.f, float forInfNeg = -65504.f) { + for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) { + int index = bid + blockDim.x * blockIdx.x + threadIdx.x; + if(index < length) { + float v = (float)in[index]; + // handle NaN + if(isnan(v)) { + if(pruneNaN) { + in[index] = (T)forNaN; + } else { + *isNaN = true; + } + } + // handle +/- Inf + if(isinf(v)) { + if(clipInf) { + in[index] = v > 0 ? (T)forInf : (T)forInfNeg; + } else { + *isInf = true; + } + } + } + } +} + +// This function is meant to clean gradients, i.e. clip infinities and prune NaNs if required. +// If all NaNs and Infs have been removed we return `true` for indicating a sane gradient. +// If `clipInf` is set, infinities are replaced with the maximum/minimum non-inf value for the tensor. +// In that case infinities do not result in a bad gradient, since they get clipped. +// If `pruneNaN` is set, NaNs are replaced with 0. Since NaNs get removed now they do not result +// in a bad gradient. +// If NaNs or infinities are detected but not removed (either because of `pruneNaN=false` or `clipInf=false`), +// we return `false` indicating a bad gradient. +bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf) { + cudaSetDevice(in->getDeviceId().no); + + int length = in->size(); + + int threads = std::min(MAX_THREADS, length); + int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0)); + + auto mem = allocator->alloc<bool>(2); + bool* dIsNaN = &mem->data<bool>()[0]; + bool* dIsInf = &mem->data<bool>()[1]; + fill(in->getBackend(), dIsNaN, dIsNaN + 2, false); + + float forNaN = 0.f; + float forInf = NumericLimits<float>(in->type()).max; + float forInfNeg = NumericLimits<float>(in->type()).lowest; + + if(in->type() == Type::float32) { + gSanitizeGradient<<<blocks, threads>>>(in->data<float>(), length, dIsNaN, dIsInf, pruneNaN, clipInf, forNaN, forInf, forInfNeg); +#if COMPILE_FP16 + } else if(in->type() == Type::float16) { + gSanitizeGradient<<<blocks, threads>>>(in->data<half>(), length, dIsNaN, dIsInf, pruneNaN, clipInf, forNaN, forInf, forInfNeg); +#endif + } else { + ABORT("gSanitizeGradient for type {} not implemented", in->type()); + } + + bool isNaN, isInf; + CudaCopy(dIsNaN, dIsNaN + 1, &isNaN); + CudaCopy(dIsInf, dIsInf + 1, &isInf); + + allocator->free(mem); + + cudaStreamSynchronize(0); + + return !isNaN && !isInf; +} + template <bool add, typename To, typename From> __global__ void gCopyCastTo(To* out, const From* in, int length) { for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) { @@ -1090,7 +1163,7 @@ void PasteRows(Tensor out, size_t rowsToCopy = indices->size(); int threads = std::min(MAX_THREADS, (int)cols); -#if 1 // @TODO: make this configurable with a 'deterministic' flag +#if 0 // @TODO: make this configurable with a 'deterministic' flag // If we only use one block, then each core operates on a different column, // hence the summation becomes deterministic. // However, we only use e.g. 512 cores out of possibly 3000+, so this will be @@ -1355,7 +1428,7 @@ __global__ void gGRUFastForward(T* out, for(int bid = 0; bid < rows; bid += gridDim.x) { int j = bid + blockIdx.x; if(j < rows) { - T m = !mask || mask[j]; + float m = !mask || mask[j]; T* rowOut = out + j * cols; const T* rowState = state + j * cols; @@ -1365,21 +1438,21 @@ __global__ void gGRUFastForward(T* out, for(int tid = 0; tid < cols; tid += blockDim.x) { int i = tid + threadIdx.x; if(i < cols) { - T r = functional::Ops<T>::sigmoid(xWrow[i] + sUrow[i] + b[i]); + float r = functional::Ops<float>::sigmoid((float)xWrow[i] + (float)sUrow[i] + (float)b[i]); int k = i + cols; - T z = functional::Ops<T>::sigmoid(xWrow[k] + sUrow[k] + b[k]); + float z = functional::Ops<float>::sigmoid((float)xWrow[k] + (float)sUrow[k] + (float)b[k]); int l = i + 2 * cols; - T h; + float h; if(final) - h = functional::Ops<T>::tanh(xWrow[l] + (sUrow[l] + b[l]) * r); + h = functional::Ops<float>::tanh((float)xWrow[l] + ((float)sUrow[l] + (float)b[l]) * r); else - h = functional::Ops<T>::tanh(xWrow[l] + sUrow[l] * r + b[l]); + h = functional::Ops<float>::tanh((float)xWrow[l] + (float)sUrow[l] * r + (float)b[l]); - T out = ((T)1.f - z) * h + z * rowState[i]; - rowOut[i] = m * out + ((T)1.f - m) * rowState[i]; + float out = (1.f - z) * h + z * (float)rowState[i]; + rowOut[i] = (T)(m * out + (1.f - m) * (float)rowState[i]); } } } @@ -1441,7 +1514,7 @@ __global__ void gGRUFastBackward(T* outState, for(int bid = 0; bid < rows; bid += gridDim.x) { int j = bid + blockIdx.x; if(j < rows) { - T m = !mask || mask[j]; + float m = !mask || mask[j]; T* rowOutState = outState + j * cols; T* rowOutXW = outXW + j * cols * 3; @@ -1459,56 +1532,56 @@ __global__ void gGRUFastBackward(T* outState, int k = i + cols; int l = i + 2 * cols; - T r = functional::Ops<T>::sigmoid(rowXW[i] + rowSU[i] + b[i]); - T z = functional::Ops<T>::sigmoid(rowXW[k] + rowSU[k] + b[k]); + float r = functional::Ops<float>::sigmoid((float)rowXW[i] + (float)rowSU[i] + (float)b[i]); + float z = functional::Ops<float>::sigmoid((float)rowXW[k] + (float)rowSU[k] + (float)b[k]); - T h; + float h; if(final) - h = functional::Ops<T>::tanh(rowXW[l] + (rowSU[l] + b[l]) * r); + h = functional::Ops<float>::tanh((float)rowXW[l] + ((float)rowSU[l] + (float)b[l]) * r); else - h = functional::Ops<T>::tanh(rowXW[l] + rowSU[l] * r + b[l]); + h = functional::Ops<float>::tanh((float)rowXW[l] + (float)rowSU[l] * r + (float)b[l]); - T adj = rowAdj[i]; + float adj = rowAdj[i]; - T t = ((T)1.f - z) * ((T)1.f - h * h); + float t = (1.f - z) * (1.f - h * h); // df/ds if(outState) - rowOutState[i] += (m * z - m + (T)1.f) * adj; + rowOutState[i] += (T)((m * z - m + 1.f) * adj); // df/d(xW_r) ... - T dfdxW_r = m * r * ((T)1.f - r) * t * adj; + float dfdxW_r = m * r * (1.f - r) * t * adj; if(final) - dfdxW_r *= rowSU[l] + b[l]; + dfdxW_r *= (float)rowSU[l] + (float)b[l]; else - dfdxW_r *= rowSU[l]; + dfdxW_r *= (float)rowSU[l]; if(outXW) - rowOutXW[i] += dfdxW_r; + rowOutXW[i] += (T)dfdxW_r; if(outSU) - rowOutSU[i] += dfdxW_r; + rowOutSU[i] += (T)dfdxW_r; if(outB) - rowOutB[i] += dfdxW_r; + rowOutB[i] += (T)dfdxW_r; // df/d(xW_z) ... - T dfdxW_z = m * ((T)1.f - z) * z * (rowState[i] - h) * adj; + float dfdxW_z = m * (1.f - z) * z * ((float)rowState[i] - h) * adj; if(outXW) - rowOutXW[k] += dfdxW_z; + rowOutXW[k] += (T)dfdxW_z; if(outSU) - rowOutSU[k] += dfdxW_z; + rowOutSU[k] += (T)dfdxW_z; if(outB) - rowOutB[k] += dfdxW_z; + rowOutB[k] += (T)dfdxW_z; // df/d(xW_x) ... - T dfdxW_x = m * t * adj; + float dfdxW_x = m * t * adj; if(outXW) - rowOutXW[l] += dfdxW_x; + rowOutXW[l] += (T)dfdxW_x; if(outSU) - rowOutSU[l] += dfdxW_x * r; + rowOutSU[l] += (T)(dfdxW_x * r); if(outB) if(final) - rowOutB[l] += dfdxW_x * r; + rowOutB[l] += (T)(dfdxW_x * r); else - rowOutB[l] += dfdxW_x; + rowOutB[l] += (T)dfdxW_x; } } } |