Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@gmail.com>2016-06-06 02:32:22 +0300
committersoumith <soumith@gmail.com>2016-06-06 02:32:22 +0300
commitd95c58ec284cb44656e63eead6d481ce3018e3a4 (patch)
tree9fecb321806b42bb9d907597ff6aba57ebc58b26
parente4e0a8c94fcc20b832f1f74e58966b0b4e344252 (diff)
fixing Volumetric Average and Max Pooling for large inputs
-rw-r--r--lib/THCUNN/VolumetricAveragePooling.cu136
-rw-r--r--lib/THCUNN/VolumetricMaxPooling.cu100
2 files changed, 136 insertions, 100 deletions
diff --git a/lib/THCUNN/VolumetricAveragePooling.cu b/lib/THCUNN/VolumetricAveragePooling.cu
index 470ff50..9542232 100644
--- a/lib/THCUNN/VolumetricAveragePooling.cu
+++ b/lib/THCUNN/VolumetricAveragePooling.cu
@@ -6,12 +6,12 @@
__global__ void cuda_VolumetricAveragePooling_updateOutput(
THCDeviceTensor<float, 4> input, THCDeviceTensor<float, 4> output,
- int kT, int kH, int kW, int dT, int dH, int dW, float normFactor)
+ int kT, int kH, int kW, int dT, int dH, int dW, float normFactor, int offsetZ)
{
int oCol = blockIdx.x * blockDim.x + threadIdx.x;
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
- int oFrame = blockIdx.z % output.getSize(1); // output frame/time
- int slice = blockIdx.z / output.getSize(1); // output slice/feature
+ int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
+ int slice = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
if (oRow < output.getSize(2) && oCol < output.getSize(3))
{
@@ -52,12 +52,12 @@ __global__ void cuda_VolumetricAveragePooling_updateOutput(
template<int KERNEL_WIDTH>
__global__ void cuda_VolumetricAveragePooling_updateOutput(
THCDeviceTensor<float, 4> input, THCDeviceTensor<float, 4> output,
- int kT, int kH, int dT, int dH, int dW, float normFactor)
+ int kT, int kH, int dT, int dH, int dW, float normFactor, int offsetZ)
{
int oCol = blockIdx.x * blockDim.x + threadIdx.x;
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
- int oFrame = blockIdx.z % output.getSize(1); // output frame/time
- int slice = blockIdx.z / output.getSize(1); // output slice/feature
+ int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
+ int slice = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
if (oRow < output.getSize(2) && oCol < output.getSize(3))
{
@@ -94,7 +94,7 @@ __global__ void cuda_VolumetricAveragePooling_updateOutput(
#define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \
cuda_VolumetricAveragePooling_updateOutput<KW><<<grid, block>>>( \
- cudaInput, cudaOutput, kT, kH, dT, dH, dW, normFactor); \
+ cudaInput, cudaOutput, kT, kH, dT, dH, dW, normFactor, offsetZ); \
break
@@ -178,45 +178,51 @@ void THNN_CudaVolumetricAveragePooling_updateOutput(
cudaOutput = toDeviceTensor<float, 5>(state, output).downcastOuter<4>();
}
+ int totalZ = outputTime * inputSlices * batchSize;
+ int offsetZ = 0;
dim3 block(32, 8);
- dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
- THCCeilDiv(outputHeight, static_cast<int>(block.y)),
- outputTime * inputSlices * batchSize);
+ while (totalZ > 0) {
+ dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
+ THCCeilDiv(outputHeight, static_cast<int>(block.y)),
+ totalZ > 65535 ? 65535 : totalZ);
- float normFactor = 1.0f / static_cast<float>(kT * kH * kW);
- switch (kW)
- {
- LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(1);
- LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(2);
- LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(3);
- LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(4);
- LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(5);
- LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(6);
- LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(7);
- default:
- cuda_VolumetricAveragePooling_updateOutput<<<grid, block>>>(
- cudaInput,
- cudaOutput,
- kT, kH, kW,
- dT, dH, dW,
- normFactor
- );
- break;
+ float normFactor = 1.0f / static_cast<float>(kT * kH * kW);
+ switch (kW)
+ {
+ LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(1);
+ LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(2);
+ LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(3);
+ LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(4);
+ LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(5);
+ LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(6);
+ LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(7);
+ default:
+ cuda_VolumetricAveragePooling_updateOutput<<<grid, block>>>(
+ cudaInput,
+ cudaOutput,
+ kT, kH, kW,
+ dT, dH, dW,
+ normFactor,
+ offsetZ
+ );
+ break;
+ }
+ totalZ -= 65535;
+ offsetZ += 65535;
+ THCudaCheck(cudaGetLastError());
}
- THCudaCheck(cudaGetLastError());
-
THCudaTensor_free(state, input);
}
__global__ void cuda_VolumetricAveragePooling_updateGradInput_Stride1(
THCDeviceTensor<float, 4> gradOutput,
THCDeviceTensor<float, 4> gradInput,
- int kT, int kH, int kW, float normFactor)
+ int kT, int kH, int kW, float normFactor, int offsetZ)
{
int iCol = blockIdx.x * blockDim.x + threadIdx.x;
int iRow = blockIdx.y * blockDim.y + threadIdx.y;
- int iFrame = blockIdx.z % gradInput.getSize(1); // input frame/time
- int slice = blockIdx.z / gradInput.getSize(1); // input slice/feature
+ int iFrame = (blockIdx.z + offsetZ) % gradInput.getSize(1); // input frame/time
+ int slice = (blockIdx.z + offsetZ) / gradInput.getSize(1); // input slice/feature
// guard against over-tiled threads
if (iRow < gradInput.getSize(2) && iCol < gradInput.getSize(3))
@@ -253,12 +259,12 @@ __global__ void cuda_VolumetricAveragePooling_updateGradInput_Stride1(
__global__ void cuda_VolumetricAveragePooling_updateGradInput_atomicAdd(
THCDeviceTensor<float, 4> gradOutput,
THCDeviceTensor<float, 4> gradInput,
- int kT, int kH, int kW, int dT, int dH, int dW)
+ int kT, int kH, int kW, int dT, int dH, int dW, int offsetZ)
{
int oCol = blockIdx.x * blockDim.x + threadIdx.x;
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
- int oFrame = blockIdx.z % gradOutput.getSize(1); // gradOutput frame/time
- int slice = blockIdx.z / gradOutput.getSize(1); // gradOutput slice/feature
+ int oFrame = (blockIdx.z + offsetZ) % gradOutput.getSize(1); // gradOutput frame/time
+ int slice = (blockIdx.z + offsetZ) / gradOutput.getSize(1); // gradOutput slice/feature
// guard against over-tiled threads
if (oRow < gradOutput.getSize(2) && oCol < gradOutput.getSize(3))
@@ -281,12 +287,12 @@ __global__ void cuda_VolumetricAveragePooling_updateGradInput(
THCDeviceTensor<float, 4> gradOutput,
THCDeviceTensor<float, 4> gradInput,
int kT, int kH, int kW,
- int dT, int dH, int dW)
+ int dT, int dH, int dW, int offsetZ)
{
int oCol = blockIdx.x * blockDim.x + threadIdx.x;
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
- int oFrame = blockIdx.z % gradOutput.getSize(1); // gradOutput frame/time
- int slice = blockIdx.z / gradOutput.getSize(1); // gradOutput slice/feature
+ int oFrame = (blockIdx.z + offsetZ) % gradOutput.getSize(1); // gradOutput frame/time
+ int slice = (blockIdx.z + offsetZ) / gradOutput.getSize(1); // gradOutput slice/feature
// guard against over-tiled threads
if (oRow < gradOutput.getSize(2) && oCol < gradOutput.getSize(3))
@@ -378,29 +384,43 @@ void THNN_CudaVolumetricAveragePooling_updateGradInput(
// specialization yields 3x speedup over the atomicAdd implementation.
if (dT == 1 && dH == 1 && dW == 1)
{
- dim3 grid(THCCeilDiv(inputWidth, static_cast<int>(block.x)),
- THCCeilDiv(inputHeight, static_cast<int>(block.y)),
- inputTime * inputSlices * batchSize);
- cuda_VolumetricAveragePooling_updateGradInput_Stride1<<<grid, block>>>(
- cudaGradOutput, cudaGradInput, kT, kH, kW, 1.0f/(kT * kH * kW));
+ int totalZ = inputTime * inputSlices * batchSize;
+ int offsetZ = 0;
+ while (totalZ > 0) {
+ dim3 grid(THCCeilDiv(inputWidth, static_cast<int>(block.x)),
+ THCCeilDiv(inputHeight, static_cast<int>(block.y)),
+ totalZ > 65535 ? 65535 : totalZ);
+ cuda_VolumetricAveragePooling_updateGradInput_Stride1<<<grid, block>>>(
+ cudaGradOutput, cudaGradInput, kT, kH, kW, 1.0f/(kT * kH * kW), offsetZ);
+ THCudaCheck(cudaGetLastError());
+ totalZ -= 65535;
+ offsetZ += 65535;
+ }
}
else
{
- dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
- THCCeilDiv(outputHeight, static_cast<int>(block.y)),
- outputTime * inputSlices * batchSize);
- if (kernelsOverlap)
- {
- cuda_VolumetricAveragePooling_updateGradInput_atomicAdd<<<grid, block>>>(
- cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW);
- }
- else
- {
- cuda_VolumetricAveragePooling_updateGradInput<<<grid, block>>>(
- cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW);
+ int totalZ = outputTime * inputSlices * batchSize;
+ int offsetZ = 0;
+ while (totalZ > 0) {
+
+ dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
+ THCCeilDiv(outputHeight, static_cast<int>(block.y)),
+ totalZ > 65535 ? 65535 : totalZ);
+ if (kernelsOverlap)
+ {
+ cuda_VolumetricAveragePooling_updateGradInput_atomicAdd<<<grid, block>>>(
+ cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW, offsetZ);
+ }
+ else
+ {
+ cuda_VolumetricAveragePooling_updateGradInput<<<grid, block>>>(
+ cudaGradOutput, cudaGradInput, kT, kH, kW, dT, dH, dW, offsetZ);
+ }
+ THCudaCheck(cudaGetLastError());
+ totalZ -= 65535;
+ offsetZ += 65535;
}
}
- THCudaCheck(cudaGetLastError());
THCudaTensor_free(state, gradOutput);
}
diff --git a/lib/THCUNN/VolumetricMaxPooling.cu b/lib/THCUNN/VolumetricMaxPooling.cu
index 91baf5c..31c68b6 100644
--- a/lib/THCUNN/VolumetricMaxPooling.cu
+++ b/lib/THCUNN/VolumetricMaxPooling.cu
@@ -12,12 +12,12 @@ __global__ void cuda_VolumetricMaxPooling_updateOutput(
THCDeviceTensor<float, 4> output,
int kT, int kH, int kW,
int dT, int dH, int dW,
- int padT, int padH, int padW)
+ int padT, int padH, int padW, int offsetZ)
{
int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
- int oFrame = blockIdx.z % output.getSize(1); // output frame/time
- int slice = blockIdx.z / output.getSize(1); // output slice/feature
+ int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
+ int slice = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
if (oRow < output.getSize(2) && oColumn < output.getSize(3))
{
@@ -74,12 +74,12 @@ __global__ void cuda_VolumetricMaxPooling_updateOutput(
THCDeviceTensor<float, 4> output,
int kT, int kH,
int dT, int dH, int dW,
- int padT, int padH, int padW)
+ int padT, int padH, int padW, int offsetZ)
{
int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
- int oFrame = blockIdx.z % output.getSize(1); // output frame/time
- int slice = blockIdx.z / output.getSize(1); // output slice/feature
+ int oFrame = (blockIdx.z + offsetZ) % output.getSize(1); // output frame/time
+ int slice = (blockIdx.z + offsetZ) / output.getSize(1); // output slice/feature
if (oRow < output.getSize(2) && oColumn < output.getSize(3))
{
@@ -130,10 +130,10 @@ __global__ void cuda_VolumetricMaxPooling_updateOutput(
}
}
-#define UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \
- cuda_VolumetricMaxPooling_updateOutput<KW><<<grid, block, \
- 0, THCState_getCurrentStream(state)>>>( \
- cudaInput, cudaIndices, cudaOutput, kT, kH, dT, dH, dW, padT, padH, padW); \
+#define UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \
+ cuda_VolumetricMaxPooling_updateOutput<KW><<<grid, block, \
+ 0, THCState_getCurrentStream(state)>>>( \
+ cudaInput, cudaIndices, cudaOutput, kT, kH, dT, dH, dW, padT, padH, padW, offsetZ); \
break
@@ -269,26 +269,35 @@ void THNN_CudaVolumetricMaxPooling_updateOutput(
THCDeviceTensor<float, 4> cudaIndices =
toDeviceTensor<float, 4>(state, indices1);
+ int totalZ = outputTime * inputSlices * batchSize;
+ int offsetZ = 0;
dim3 block(32, 8);
- dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
- THCCeilDiv(outputHeight, static_cast<int>(block.y)),
- outputTime * inputSlices * batchSize);
- switch (kW)
- {
- UPDATE_OUTPUT_KERNEL_WIDTH(1);
- UPDATE_OUTPUT_KERNEL_WIDTH(2);
- UPDATE_OUTPUT_KERNEL_WIDTH(3);
- UPDATE_OUTPUT_KERNEL_WIDTH(4);
- UPDATE_OUTPUT_KERNEL_WIDTH(5);
- UPDATE_OUTPUT_KERNEL_WIDTH(6);
- UPDATE_OUTPUT_KERNEL_WIDTH(7);
- default:
- cuda_VolumetricMaxPooling_updateOutput<<<grid, block,
- 0, THCState_getCurrentStream(state)>>>(
- cudaInput, cudaIndices, cudaOutput, kT, kH, kW, dT, dH, dW, padT, padH, padW);
+ while (totalZ > 0) {
+ dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
+ THCCeilDiv(outputHeight, static_cast<int>(block.y)),
+ totalZ > 65535 ? 65535 : totalZ);
+
+ switch (kW)
+ {
+ UPDATE_OUTPUT_KERNEL_WIDTH(1);
+ UPDATE_OUTPUT_KERNEL_WIDTH(2);
+ UPDATE_OUTPUT_KERNEL_WIDTH(3);
+ UPDATE_OUTPUT_KERNEL_WIDTH(4);
+ UPDATE_OUTPUT_KERNEL_WIDTH(5);
+ UPDATE_OUTPUT_KERNEL_WIDTH(6);
+ UPDATE_OUTPUT_KERNEL_WIDTH(7);
+ default:
+ cuda_VolumetricMaxPooling_updateOutput<<<grid, block,
+ 0, THCState_getCurrentStream(state)>>>(
+ cudaInput, cudaIndices, cudaOutput,
+ kT, kH, kW, dT, dH, dW,
+ padT, padH, padW, offsetZ);
+ }
+ THCudaCheck(cudaGetLastError());
+ totalZ -= 65535;
+ offsetZ += 65535;
}
- THCudaCheck(cudaGetLastError());
THCudaTensor_free(state, input);
THCudaTensor_free(state, indices1);
@@ -301,12 +310,12 @@ __global__ void cuda_VolumetricMaxPooling_updateGradInput(
THCDeviceTensor<float, 4> indices,
THCDeviceTensor<float, 4> gradInput,
int dT, int dH, int dW,
- int padT, int padH, int padW)
+ int padT, int padH, int padW, int offsetZ)
{
int oColumn = blockIdx.x * blockDim.x + threadIdx.x;
int oRow = blockIdx.y * blockDim.y + threadIdx.y;
- int oFrame = blockIdx.z % gradOutput.getSize(1); // output frame/time
- int slice = blockIdx.z / gradOutput.getSize(1); // output slice/feature
+ int oFrame = (blockIdx.z + offsetZ) % gradOutput.getSize(1); // output frame/time
+ int slice = (blockIdx.z + offsetZ) / gradOutput.getSize(1); // output slice/feature
if (oRow < gradOutput.getSize(2) && oColumn < gradOutput.getSize(3))
{
@@ -387,19 +396,26 @@ void THNN_CudaVolumetricMaxPooling_updateGradInput(
THCDeviceTensor<float, 4> cudaIndices =
toDeviceTensor<float, 4>(state, indices1);
+ int totalZ = outputTime * inputSlices * batchSize;
+ int offsetZ = 0;
dim3 block(32, 8);
- dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
- THCCeilDiv(outputHeight, static_cast<int>(block.y)),
- outputTime * inputSlices * batchSize);
-
- cuda_VolumetricMaxPooling_updateGradInput<<<grid, block,
- 0, THCState_getCurrentStream(state)>>>(
- cudaGradOutput,
- cudaIndices,
- cudaGradInput,
- dT, dH, dW,
- padT, padH, padW);
- THCudaCheck(cudaGetLastError());
+
+ while (totalZ > 0) {
+ dim3 grid(THCCeilDiv(outputWidth, static_cast<int>(block.x)),
+ THCCeilDiv(outputHeight, static_cast<int>(block.y)),
+ totalZ > 65535 ? 65535 : totalZ);
+
+ cuda_VolumetricMaxPooling_updateGradInput<<<grid, block,
+ 0, THCState_getCurrentStream(state)>>>(
+ cudaGradOutput,
+ cudaIndices,
+ cudaGradInput,
+ dT, dH, dW,
+ padT, padH, padW, offsetZ);
+ THCudaCheck(cudaGetLastError());
+ totalZ -= 65535;
+ offsetZ += 65535;
+ }
// cleanup
THCudaTensor_free(state, gradOutput);