author | Soumith Chintala <soumith@gmail.com> | 2017-02-21 17:15:12 +0300
---|---|---
committer | GitHub <noreply@github.com> | 2017-02-21 17:15:12 +0300
commit | 0d01aaad4435081262212b740f12fa534b23c526 (patch) |
tree | 299aeb514bd8fa54bb0f86823c34a1bf6bccfce5 /lib |
parent | e2084c1f34291b0d51f7a596bcf7bea5b981f106 (diff) |
parent | 8a553b7ac6fc0bf8309dcd9bd494ac6039852593 (diff) |
Merge pull request #418 from ruotianluo/adaptiveAverage
Add SpatialAdaptiveAveragePooling.
Diffstat (limited to 'lib')
-rw-r--r-- | lib/THCUNN/SpatialAdaptiveAveragePooling.cu | 200
-rw-r--r-- | lib/THCUNN/generic/SpatialAdaptiveAveragePooling.cu | 173
-rw-r--r-- | lib/THCUNN/generic/THCUNN.h | 13
3 files changed, 386 insertions, 0 deletions
diff --git a/lib/THCUNN/SpatialAdaptiveAveragePooling.cu b/lib/THCUNN/SpatialAdaptiveAveragePooling.cu
new file mode 100644
index 0000000..b1e5e5c
--- /dev/null
+++ b/lib/THCUNN/SpatialAdaptiveAveragePooling.cu
@@ -0,0 +1,200 @@
+#include "THCUNN.h"
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
+#include "THCAtomics.cuh"
+
+#define START_IND(a,b,c) (int)floor((float)(a * c) / b)
+#define END_IND(a,b,c) (int)ceil((float)((a + 1) * c) / b)
+// #define START_IND(a,b,c) a * c / b
+// #define END_IND(a,b,c) (a + 1) * c / b + ((a + 1) * c % b > 0)?1:0
+
+#define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit
+
+/*
+ * Description:
+ *    this function adaptively average pools an input 4D tensor along dimensions 2 and 3
+ *    4D input, 4D output
+ */
+template <typename T>
+__global__ void adaptiveaveragepool(T *input, T *output,
+                                    int input_n, int input_h, int input_w,
+                                    int output_h, int output_w,
+                                    int strideh, int stridew,
+                                    int strided)
+{
+  // iterators
+  int xx, yy;
+
+  // compute offsets based on thread/block ID
+  int o = blockIdx.x;
+  int i = o;
+  //int k = blockIdx.x % input_n;
+
+  int xx_start = threadIdx.x;
+  int xx_end = output_w;
+  const int xx_step = blockDim.x;
+
+  int yy_start = blockDim.y*blockIdx.y + threadIdx.y;
+  int yy_end = output_h;
+  const int yy_step = blockDim.y*gridDim.y;
+
+  // select input/output plane
+  output = output + o*output_w*output_h;
+  input = input + i*strided;
+
+  // For all output pixels...
+  for(yy = yy_start; yy < yy_end; yy+=yy_step) {
+
+    int y_start = START_IND(yy, output_h, input_h);
+    int y_end = END_IND(yy, output_h, input_h);
+    int kH = y_end-y_start;
+
+    for(xx = xx_start; xx < xx_end; xx+=xx_step) {
+
+      int x_start = START_IND(xx, output_w, input_w);
+      int x_end = END_IND(xx, output_w, input_w);
+      int kW = x_end-x_start;
+
+      // Compute the average pooling
+      T *ptr_input = input + y_start*strideh + x_start*stridew;
+      T *ptr_output = output + yy*output_w + xx;
+      T sum = ScalarConvert<int, T>::to(0);
+      int kx, ky;
+      for(ky = 0; ky < kH; ++ky) {
+        for(kx = 0; kx < kW; ++kx) {
+          T val = ptr_input[kx*stridew];
+          sum += val;
+        }
+        ptr_input += strideh; // next input line
+      }
+      // Update output
+      *ptr_output = sum / kH / kW;
+    }
+  }
+}
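The `START_IND`/`END_IND` macros are what makes this pooling adaptive: output cell `a` of `b` cells over an input extent `c` covers `[floor(a*c/b), ceil((a+1)*c/b))`, so the cells tile the input exactly even when `c` is not divisible by `b`, with adjacent cells overlapping by at most one row or column. A minimal host-side sketch of that partition (illustrative only, not part of this commit):

```cpp
#include <cstdio>
#include <cmath>

// Same index arithmetic as the START_IND/END_IND macros in the kernel above.
static int start_ind(int a, int b, int c) { return (int)floorf((float)(a * c) / b); }
static int end_ind(int a, int b, int c)   { return (int)ceilf((float)((a + 1) * c) / b); }

int main() {
  const int input_h = 10, output_h = 4;  // 10 rows pooled into 4: not evenly divisible
  for (int yy = 0; yy < output_h; ++yy) {
    int y0 = start_ind(yy, output_h, input_h);
    int y1 = end_ind(yy, output_h, input_h);
    // Prints [0,3), [2,5), [5,8), [7,10): every input row is covered, and
    // neighboring cells may share one boundary row.
    printf("output row %d pools input rows [%d, %d), kH = %d\n", yy, y0, y1, y1 - y0);
  }
  return 0;
}
```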
+
+/*
+ * Description:
+ *    this function computes the gradInput from gradOutput
+ */
+template <typename T>
+__global__ void adaptiveaveragegradinput(
+  T *gradInput, T *gradOutput,
+  int input_n, int input_h, int input_w, int output_h, int output_w
+)
+{
+  // iterators
+  int x, y;
+
+  // compute offsets based on thread/block ID
+  int o = blockIdx.x;
+  int i = o;
+
+  int x_start = threadIdx.x;
+  int x_end = input_w;
+  int x_step = blockDim.x;
+
+  int y_start = blockDim.y*blockIdx.y + threadIdx.y;
+  int y_end = input_h;
+  int y_step = blockDim.y*gridDim.y;
+
+  // select input/output plane
+  gradOutput = gradOutput + o*output_w*output_h;
+  gradInput = gradInput + i*input_w*input_h;
+
+  // compute gradInput
+  for(y = y_start; y < y_end; y+=y_step) {
+
+    int yy_start = START_IND(y, input_h, output_h);
+    int yy_end = END_IND(y, input_h, output_h);
+    int kH = yy_end-yy_start;
+
+    for(x = x_start; x < x_end; x+=x_step) {
+
+      int xx_start = START_IND(x, input_w, output_w);
+      int xx_end = END_IND(x, input_w, output_w);
+      int kW = xx_end-xx_start;
+
+      // Compute the gradients
+      T *ptr_gradInput = gradInput + y*input_w + x;
+      T *ptr_gradOutput = gradOutput + yy_start*output_w + xx_start;
+
+      int kx, ky;
+      for(ky = 0; ky < kH; ++ky) {
+        int yy = yy_start + ky;
+        int kkH = START_IND(yy, output_h, input_h) - END_IND(yy, output_h, input_h);
+        for(kx = 0; kx < kW; ++kx) {
+          int xx = xx_start + kx;
+          int kkW = START_IND(xx, output_w, input_w) - END_IND(xx, output_w, input_w);
+          T z = ptr_gradOutput[kx + ky*output_w] / kkW / kkH;
+          *ptr_gradInput += z;
+        }
+      }
+    }
+  }
+}
+
+/*
+ * Description:
+ *    this function computes the gradInput from gradOutput
+ *    (uses atomic add)
+ */
+template <typename T>
+__global__ void atomicadaptiveaveragegradinput(
+  T *gradInput, T *gradOutput,
+  int input_n, int input_h, int input_w, int output_h, int output_w
+)
+{
+  // iterators
+  int xx, yy;
+
+  // compute offsets based on thread/block ID
+  int o = blockIdx.x;
+  int i = o;
+
+  int xx_start = threadIdx.x;
+  int xx_end = output_w;
+  int xx_step = blockDim.x;
+
+  int yy_start = blockDim.y*blockIdx.y + threadIdx.y;
+  int yy_end = output_h;
+  int yy_step = blockDim.y*gridDim.y;
+
+  // select input/output plane
+  gradOutput = gradOutput + o*output_w*output_h;
+  gradInput = gradInput + i*input_w*input_h;
+
+  // compute gradInput
+  for(yy = yy_start; yy < yy_end; yy+=yy_step) {
+
+    int y_start = START_IND(yy, output_h, input_h);
+    int y_end = END_IND(yy, output_h, input_h);
+    int kH = y_end-y_start;
+
+    for(xx = xx_start; xx < xx_end; xx+=xx_step) {
+
+      int x_start = START_IND(xx, output_w, input_w);
+      int x_end = END_IND(xx, output_w, input_w);
+      int kW = x_end-x_start;
+
+      // Compute the gradients
+      T *ptr_gradInput = gradInput + y_start*input_w + x_start;
+      T *ptr_gradOutput = gradOutput + yy*output_w + xx;
+      T z = *ptr_gradOutput / kW / kH;
+      int kx, ky;
+      for(ky = 0; ky < kH; ++ky) {
+        for(kx = 0; kx < kW; ++kx) {
+          // atomic add since different threads could update same variable
+          atomicAdd(&(ptr_gradInput[kx + ky*input_w]), z);
+        }
+      }
+    }
+  }
+}
+
+#include "generic/SpatialAdaptiveAveragePooling.cu"
+#include "THCGenerateFloatTypes.h"
+
+#undef CUDA_MAX_THREADS
+#undef START_IND
+#undef END_IND
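The two backward kernels differ only in parallelization strategy: `adaptiveaveragegradinput` assigns threads to *input* pixels and gathers from every output cell whose region covers the pixel (no write conflicts), while `atomicadaptiveaveragegradinput` assigns threads to *output* cells and scatters into `gradInput`, colliding wherever pooling regions overlap; hence the `atomicAdd`. A small host-side check that counts those collisions (hypothetical, not in the commit):

```cpp
#include <cstdio>
#include <cmath>

static int start_ind(int a, int b, int c) { return (int)floorf((float)(a * c) / b); }
static int end_ind(int a, int b, int c)   { return (int)ceilf((float)((a + 1) * c) / b); }

int main() {
  const int input_h = 10, output_h = 4;
  // Count how many output cells write each input row in the scatter (atomic) kernel.
  for (int y = 0; y < input_h; ++y) {
    int writers = 0;
    for (int yy = 0; yy < output_h; ++yy)
      if (y >= start_ind(yy, output_h, input_h) && y < end_ind(yy, output_h, input_h))
        ++writers;
    printf("input row %d is written by %d output cell(s)\n", y, writers);
  }
  // Rows 2 and 7 report two writers, so an unsynchronized scatter would lose updates.
  return 0;
}
```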
diff --git a/lib/THCUNN/generic/SpatialAdaptiveAveragePooling.cu b/lib/THCUNN/generic/SpatialAdaptiveAveragePooling.cu
new file mode 100644
index 0000000..e444d1f
--- /dev/null
+++ b/lib/THCUNN/generic/SpatialAdaptiveAveragePooling.cu
@@ -0,0 +1,173 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/SpatialAdaptiveAveragePooling.cu"
+#else
+
+#include "../common.h"
+
+void THNN_(SpatialAdaptiveAveragePooling_updateOutput)(
+           THCState *state,
+           THCTensor *input,
+           THCTensor *output,
+           int nOutputCols,
+           int nOutputRows)
+{
+  THCUNN_assertSameGPU(state, 2, input, output);
+
+  real *output_data;
+  real *input_data;
+
+  THCUNN_argCheck(state, input->nDimension == 3 || input->nDimension == 4, 2, input,
+                  "3D or 4D (batch mode) tensor expected for input, but got: %s");
+
+  if (input->nDimension == 3) {
+    long nInputCols = input->size[2];
+    long nInputRows = input->size[1];
+    long nInputPlane = input->size[0];
+
+    long istride_d = input->stride[0];
+    long istride_h = input->stride[1];
+    long istride_w = input->stride[2];
+
+    input_data = THCTensor_(data)(state, input);
+
+    THCTensor_(resize3d)(state, output, nInputPlane, nOutputRows, nOutputCols);
+
+    output_data = THCTensor_(data)(state, output);
+
+    // cuda blocks & threads:
+    int yblocks = (int)(16L / nInputPlane);
+    yblocks = yblocks < 1 ? 1 : yblocks;
+    dim3 blocks(nInputPlane,yblocks);
+    dim3 threads(32,8);
+
+    // run averagepool kernel
+    adaptiveaveragepool <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (
+      input_data, output_data,
+      nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
+      istride_h, istride_w, istride_d);
+    THCudaCheck(cudaGetLastError());
+
+  } else {
+    input = THCTensor_(newContiguous)(state, input);
+    long nInputCols = input->size[3];
+    long nInputRows = input->size[2];
+    long nInputPlane = input->size[1];
+    long nbatch = input->size[0];
+
+    long istride_d = input->stride[1];
+    long istride_h = input->stride[2];
+    long istride_w = input->stride[3];
+
+    input_data = THCTensor_(data)(state, input);
+
+    THCTensor_(resize4d)(state, output, nbatch, nInputPlane, nOutputRows, nOutputCols);
+
+    output_data = THCTensor_(data)(state, output);
+
+    // cuda blocks & threads:
+    int yblocks = (int)(16L / nInputPlane);
+    yblocks = yblocks < 1 ? 1 : yblocks;
+    dim3 blocks(nInputPlane*nbatch,yblocks);
+    dim3 threads(32,8);
+
+    // run averagepool kernel
+    adaptiveaveragepool <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (
+      input_data, output_data,
+      nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols,
+      istride_h, istride_w, istride_d);
+    THCudaCheck(cudaGetLastError());
+    // clean
+    THCTensor_(free)(state, input);
+  }
+}
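The launch configuration in `updateOutput` gives each (batch, plane) pair its own `blockIdx.x` and a fixed 32x8 thread tile that strides across the whole output plane; `yblocks` only adds blocks along y when there are few planes, to keep the GPU occupied. A sketch of the same arithmetic with illustrative sizes (not from the commit):

```cpp
#include <cstdio>

int main() {
  // Mirrors the grid/block computation in updateOutput (illustrative values).
  long nbatch = 16, nInputPlane = 3;
  int yblocks = (int)(16L / nInputPlane);  // more y-blocks when there are few planes
  yblocks = yblocks < 1 ? 1 : yblocks;
  long xblocks = nInputPlane * nbatch;     // one x-block per (batch, plane) pair
  printf("grid = (%ld, %d), block = (32, 8)\n", xblocks, yblocks);  // grid = (48, 5)
  // Each block's 32x8 threads stride over the full output plane,
  // so any nOutputRows x nOutputCols size is covered.
  return 0;
}
```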
+
+void THNN_(SpatialAdaptiveAveragePooling_updateGradInput)(
+           THCState *state,
+           THCTensor *input,
+           THCTensor *gradOutput,
+           THCTensor *gradInput)
+{
+  bool atomic = true; // suboptimal, but without atomic it doesn't pass the tests
+
+  THCUNN_assertSameGPU(state, 3, input, gradOutput, gradInput);
+
+  real *gradInput_data;
+  real *gradOutput_data;
+
+  gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+
+  if (input->nDimension == 3) {
+    long nInputCols = input->size[2];
+    long nInputRows = input->size[1];
+    long nInputPlane = input->size[0];
+    long nOutputCols = gradOutput->size[2];
+    long nOutputRows = gradOutput->size[1];
+
+    //bool atomic = (nInputCols%nOutputCols != 0) || (nInputRows%nOutputRows != 0);
+
+    THCTensor_(resizeAs)(state, gradInput, input);
+    THCTensor_(zero)(state, gradInput);
+
+    gradOutput_data = THCTensor_(data)(state, gradOutput);
+    gradInput_data = THCTensor_(data)(state, gradInput);
+
+    // cuda blocks & threads:
+    int yblocks = (int)(16L / nInputPlane);
+    yblocks = yblocks < 1 ? 1 : yblocks;
+    dim3 blocks(nInputPlane,yblocks);
+    dim3 threads(32,8);
+
+    if(atomic)
+    {
+      // run updateGradInput kernel, accumulate gradients atomically
+      atomicadaptiveaveragegradinput <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (
+        gradInput_data, gradOutput_data,
+        nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols);
+    }
+    else
+    {
+      // run updateGradInput kernel
+      adaptiveaveragegradinput <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (
+        gradInput_data, gradOutput_data,
+        nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols);
+    }
+    THCudaCheck(cudaGetLastError());
+  } else {
+    long nInputCols = input->size[3];
+    long nInputRows = input->size[2];
+    long nInputPlane = input->size[1];
+    long nbatch = input->size[0];
+    long nOutputCols = gradOutput->size[3];
+    long nOutputRows = gradOutput->size[2];
+
+    //bool atomic = (nInputCols%nOutputCols != 0) || (nInputRows%nOutputRows != 0);
+
+    THCTensor_(resizeAs)(state, gradInput, input);
+    THCTensor_(zero)(state, gradInput);
+
+    gradOutput_data = THCTensor_(data)(state, gradOutput);
+    gradInput_data = THCTensor_(data)(state, gradInput);
+
+    // cuda blocks & threads:
+    int yblocks = (int)(16L / nInputPlane);
+    yblocks = yblocks < 1 ? 1 : yblocks;
+    dim3 blocks(nInputPlane*nbatch,yblocks);
+    dim3 threads(32,8);
+
+    if(atomic)
+    {
+      // run updateGradInput kernel, accumulate gradients atomically
+      atomicadaptiveaveragegradinput <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (
+        gradInput_data, gradOutput_data,
+        nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols);
+    }
+    else
+    {
+      // run updateGradInput kernel
+      adaptiveaveragegradinput <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (
+        gradInput_data, gradOutput_data,
+        nInputPlane, nInputRows, nInputCols, nOutputRows, nOutputCols);
+    }
+    THCudaCheck(cudaGetLastError());
+  }
+
+  // clean
+  THCTensor_(free)(state, gradOutput);
+}
+
+#endif
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index fabb8e9..a71209a 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -394,6 +394,19 @@ TH_API void THNN_(SpatialAdaptiveMaxPooling_updateGradInput)(
           THCTensor *gradInput,
           THCIndexTensor *indices);
 
+TH_API void THNN_(SpatialAdaptiveAveragePooling_updateOutput)(
+           THCState *state,
+           THCTensor *input,
+           THCTensor *output,
+           int nOutputCols,
+           int nOutputRows);
+
+TH_API void THNN_(SpatialAdaptiveAveragePooling_updateGradInput)(
+           THCState *state,
+           THCTensor *input,
+           THCTensor *gradOutput,
+           THCTensor *gradInput);
+
 TH_API void THNN_(SpatialAveragePooling_updateOutput)(
            THCState *state,
            THCTensor *input,
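The declarations above are type-generic; the `THCGenerateFloatTypes.h` pass expands `THNN_(...)` into one concrete symbol per floating type (e.g. `THNN_CudaSpatialAdaptiveAveragePooling_updateOutput` for float tensors, following the usual THNN name-concatenation scheme). A hedged caller sketch; the wrapper function and output size are hypothetical:

```cpp
#include "THCUNN.h"

// Hypothetical caller (not part of this commit). Pools a 3D (plane x row x col)
// or 4D (batch x plane x row x col) CUDA float tensor down to a fixed 7x7 grid;
// the call resizes `output` itself.
void pool_to_7x7(THCState *state, THCudaTensor *input, THCudaTensor *output) {
  THNN_CudaSpatialAdaptiveAveragePooling_updateOutput(state, input, output,
                                                      /*nOutputCols=*/7,
                                                      /*nOutputRows=*/7);
}
```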