diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-08-06 00:07:07 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-08-06 00:07:07 +0300 |
commit | 6409c15b2fa93bd6395a980af821ba29e288b213 (patch) | |
tree | a0fd374362ec724e2b89fe250f4bef7f00535bbe | |
parent | 7f27d6a6468f32f496863f182d3ae4510bdeaf4d (diff) | |
parent | 49fdc1ccad83049b4a9c11eada2c7b32469030cc (diff) |
Merge pull request #314 from colesbury/bn
Fix "invalid configuration" when using very large batch sizes in evaluate mode
-rw-r--r-- | lib/THCUNN/BatchNormalization.cu | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/lib/THCUNN/BatchNormalization.cu b/lib/THCUNN/BatchNormalization.cu index 5540437..8bd5982 100644 --- a/lib/THCUNN/BatchNormalization.cu +++ b/lib/THCUNN/BatchNormalization.cu @@ -165,7 +165,6 @@ __global__ void BatchNormalizationUpdateOutputInference_kernel( float epsilon) { int plane = blockIdx.x; - int batch = blockIdx.y; float invstd = 1.0f / sqrt(runningVar[plane].ldg() + epsilon); float mean = runningMean[plane].ldg(); @@ -173,9 +172,11 @@ __global__ void BatchNormalizationUpdateOutputInference_kernel( float beta = bias.numElements() > 0 ? bias[plane].ldg() : 0.0f; // Write normalized and update the output - for (int x = threadIdx.x; x < input.getSize(2); x += blockDim.x) { - float inp = input[batch][plane][x].ldg(); - output[batch][plane][x] = gamma * (inp - mean) * invstd + beta; + for (int batch = 0; batch < input.getSize(0); batch++) { + for (int x = threadIdx.x; x < input.getSize(2); x += blockDim.x) { + float inp = input[batch][plane][x].ldg(); + output[batch][plane][x] = gamma * (inp - mean) * invstd + beta; + } } } @@ -245,7 +246,7 @@ void THNN_CudaBatchNormalization_updateOutput( cudaDeviceProp *prop = THCState_getCurrentDeviceProperties(state); if (!train) { - dim3 blocks(input.getSize(1), input.getSize(0)); + dim3 blocks(input.getSize(1)); dim3 threads(getNumThreads(input.getSize(2))); BatchNormalizationUpdateOutputInference_kernel<<<blocks, threads, 0, s>>>( input, output, runningMean, runningVar, weight, bias, eps); |