
github.com/torch/cunn.git
author    Soumith Chintala <soumith@gmail.com>     2016-08-06 00:07:07 +0300
committer GitHub <noreply@github.com>              2016-08-06 00:07:07 +0300
commit    6409c15b2fa93bd6395a980af821ba29e288b213 (patch)
tree      a0fd374362ec724e2b89fe250f4bef7f00535bbe
parent    7f27d6a6468f32f496863f182d3ae4510bdeaf4d (diff)
parent    49fdc1ccad83049b4a9c11eada2c7b32469030cc (diff)
Merge pull request #314 from colesbury/bn
Fix "invalid configuration" when using very large batch sizes in evaluate mode
-rw-r--r--  lib/THCUNN/BatchNormalization.cu | 11
1 file changed, 6 insertions(+), 5 deletions(-)
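
Background on the failure mode (a sketch for context, not part of the original commit message): before this change, the inference kernel was launched with a two-dimensional grid, one block per (plane, batch) pair, i.e. dim3 blocks(numPlanes, batchSize). CUDA caps gridDim.y at 65535, so evaluate-mode batches larger than that made the launch itself fail with cudaErrorInvalidConfiguration, the "invalid configuration argument" error referenced in the title. A minimal, self-contained reproduction, using a hypothetical no-op kernel:

    #include <cstdio>
    #include <cuda_runtime.h>

    // Hypothetical no-op kernel; only the launch configuration matters here.
    __global__ void noop_kernel() {}

    int main() {
      // Mirrors the old launch shape: blocks(numPlanes, batchSize).
      // gridDim.y is limited to 65535, so a batch size of 70000 is
      // rejected before the kernel ever runs.
      dim3 blocks(64, 70000);
      dim3 threads(256);
      noop_kernel<<<blocks, threads>>>();

      cudaError_t err = cudaGetLastError();
      // Expected output: "invalid configuration argument"
      printf("launch status: %s\n", cudaGetErrorString(err));
      return 0;
    }
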
diff --git a/lib/THCUNN/BatchNormalization.cu b/lib/THCUNN/BatchNormalization.cu
index 5540437..8bd5982 100644
--- a/lib/THCUNN/BatchNormalization.cu
+++ b/lib/THCUNN/BatchNormalization.cu
@@ -165,7 +165,6 @@ __global__ void BatchNormalizationUpdateOutputInference_kernel(
     float epsilon) {
 
   int plane = blockIdx.x;
-  int batch = blockIdx.y;
 
   float invstd = 1.0f / sqrt(runningVar[plane].ldg() + epsilon);
   float mean = runningMean[plane].ldg();
@@ -173,9 +172,11 @@ __global__ void BatchNormalizationUpdateOutputInference_kernel(
   float beta = bias.numElements() > 0 ? bias[plane].ldg() : 0.0f;
 
   // Write normalized and update the output
-  for (int x = threadIdx.x; x < input.getSize(2); x += blockDim.x) {
-    float inp = input[batch][plane][x].ldg();
-    output[batch][plane][x] = gamma * (inp - mean) * invstd + beta;
+  for (int batch = 0; batch < input.getSize(0); batch++) {
+    for (int x = threadIdx.x; x < input.getSize(2); x += blockDim.x) {
+      float inp = input[batch][plane][x].ldg();
+      output[batch][plane][x] = gamma * (inp - mean) * invstd + beta;
+    }
   }
 }
 
@@ -245,7 +246,7 @@ void THNN_CudaBatchNormalization_updateOutput(
   cudaDeviceProp *prop = THCState_getCurrentDeviceProperties(state);
 
   if (!train) {
-    dim3 blocks(input.getSize(1), input.getSize(0));
+    dim3 blocks(input.getSize(1));
     dim3 threads(getNumThreads(input.getSize(2)));
     BatchNormalizationUpdateOutputInference_kernel<<<blocks, threads, 0, s>>>(
       input, output, runningMean, runningVar, weight, bias, eps);
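
The shape of the fix, restated outside the THCUNN tensor wrappers (a hedged sketch with plain float pointers and illustrative names, not the library code): the grid drops to one dimension, one block per feature plane, and each block walks the batch dimension in a loop, so the batch size never enters the launch configuration. The per-plane statistics are still loaded once per block and reused across batches.

    #include <cuda_runtime.h>

    // Sketch of the post-fix pattern. Assumes a contiguous layout
    // flattened to (batch, plane, planeSize); names are illustrative.
    __global__ void normalizeInference(
        const float* input, float* output,
        const float* mean, const float* invstd,
        int numBatches, int numPlanes, int planeSize) {
      int plane = blockIdx.x;     // 1-D grid: one block per plane
      float m = mean[plane];      // loaded once, reused for every batch
      float s = invstd[plane];

      // The batch dimension lives inside the kernel, so it is bounded
      // only by integer range, not by any grid-dimension limit.
      for (int batch = 0; batch < numBatches; batch++) {
        long long base = ((long long)batch * numPlanes + plane)
                         * (long long)planeSize;
        for (int x = threadIdx.x; x < planeSize; x += blockDim.x) {
          output[base + x] = (input[base + x] - m) * s;
        }
      }
    }

    // Launch shape after the fix: only the plane count appears in the
    // grid, and gridDim.x allows up to 2^31 - 1 blocks on compute
    // capability 3.0 and later, so it is effectively unconstrained:
    //   normalizeInference<<<numPlanes, 256>>>(in, out, mean, invstd,
    //                                          N, C, HW);

The trade-off is that one block now does batch-many passes serially instead of the grid spreading batches across blocks; for inference-time batch normalization, where the per-element work is a single fused multiply-add, that is an acceptable cost for removing the hard 65535-batch ceiling.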