diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-02-21 17:13:59 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-02-21 17:13:59 +0300 |
commit | e2084c1f34291b0d51f7a596bcf7bea5b981f106 (patch) | |
tree | 07ea0737213ee60a07f59349e548e5cb16bb4023 /lib | |
parent | 6d1580294f162478e3125f78460023adfc88eca5 (diff) | |
parent | 7571794989afaa2c4ca14f345a2b42b3b4d55844 (diff) |
Merge pull request #434 from bottler/master
VolumetricFractionalMaxPooling like spatial
Diffstat (limited to 'lib')
-rw-r--r-- | lib/THCUNN/VolumetricFractionalMaxPooling.cu | 120 | ||||
-rw-r--r-- | lib/THCUNN/generic/THCUNN.h | 18 | ||||
-rw-r--r-- | lib/THCUNN/generic/VolumetricFractionalMaxPooling.cu | 168 |
3 files changed, 306 insertions, 0 deletions
diff --git a/lib/THCUNN/VolumetricFractionalMaxPooling.cu b/lib/THCUNN/VolumetricFractionalMaxPooling.cu
new file mode 100644
index 0000000..e6260ce
--- /dev/null
+++ b/lib/THCUNN/VolumetricFractionalMaxPooling.cu
@@ -0,0 +1,120 @@
#include "THCUNN.h"
#include "common.h"
#include "THCDeviceTensor.cuh"
#include "THCDeviceTensorUtils.cuh"
#include "THCDeviceUtils.cuh"
#include "THCHalf.h"
#include "THCHalfAutoNumerics.cuh"
#include "THCAtomics.cuh"

#include <cfloat>

// Returns the start offset of the pooling window for output cell `index`
// along one dimension.  `sample` is the per-(batch, plane) random value
// that makes the pooling "fractional"; the last output cell is pinned to
// the end of the input so the full input range is always covered.
template <typename Dtype, typename Acctype>
__device__ inline int getInterval(Acctype sample,
                                  int index,
                                  int inputSize,
                                  int outputSize,
                                  int poolSize) {
  Acctype alpha = (Acctype)(inputSize - poolSize) / (Acctype) (outputSize - 1);
  if (index == outputSize - 1) {
    return inputSize - poolSize;
  } else {
    return (int) ((index + sample) * alpha) - (int) (sample * alpha);
  }
}

// We template on poolSizeT to allow the innermost (time) loop to be
// unrolled; PoolSizeTStatic == -1 selects the fully dynamic loop.
// (The original comment said "poolSizeW" — a leftover from the spatial
// version; the static parameter drives the t loop, not the w loop.)
// Launch layout: blockIdx.y = plane, blockIdx.z = batch; each thread
// computes one output element.
template <int PoolSizeTStatic, typename Dtype, typename Acctype>
__global__ void VolumetricFractionalMaxPooling_updateOutput(
  THCDeviceTensor<Dtype, 5> input,
  THCDeviceTensor<Dtype, 5> output,
  THCDeviceTensor<THCIndex_t, 5> indices,
  THCDeviceTensor<Dtype, 3> samples,
  int poolSizeT, int poolSizeW, int poolSizeH) {

  // Output (h, w, t) point that this thread is responsible for
  int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
  int plane = blockIdx.y;
  int batch = blockIdx.z;

  // Each thread generates a specific output point
  if (ourOutputPoint < output.getSize(2) * output.getSize(3) * output.getSize(4)){
    int outputT = ourOutputPoint % output.getSize(4);
    int outputW = (ourOutputPoint / output.getSize(4)) % output.getSize(3);
    int outputH = ourOutputPoint / (output.getSize(3)*output.getSize(4));

    // Per-dimension window starts; samples[...][0..2] are the T/W/H offsets.
    int poolT = getInterval<Dtype, Acctype>(ScalarConvert<Dtype, Acctype>::to(samples[batch][plane][0]), outputT,
                                            input.getSize(4), output.getSize(4), poolSizeT);
    int poolW = getInterval<Dtype, Acctype>(ScalarConvert<Dtype, Acctype>::to(samples[batch][plane][1]), outputW,
                                            input.getSize(3), output.getSize(3), poolSizeW);
    int poolH = getInterval<Dtype, Acctype>(ScalarConvert<Dtype, Acctype>::to(samples[batch][plane][2]), outputH,
                                            input.getSize(2), output.getSize(2), poolSizeH);

    Dtype maxVal = THCNumerics<Dtype>::min();
    int maxIndex = -1;

    for (int h = poolH; h < poolH + poolSizeH; ++h) {
      for (int w = poolW; w < poolW + poolSizeW; ++w) {
        if (PoolSizeTStatic == -1) {
          // Dynamic time loop: trip count known only at runtime.
          for (int t = poolT; t < poolT + poolSizeT; ++t) {
            Dtype val = input[batch][plane][h][w][t];
            // for consistency with THNN, favor the first max
            if (val > maxVal) {
              maxIndex = h * input.getSize(3)*input.getSize(4) + w * input.getSize(4) + t;
              maxVal = val;
            }
          }
        } else {
          // Static time loop: compile-time trip count enables full unrolling.
#pragma unroll
          for (int i = 0; i < PoolSizeTStatic; ++i) {
            int t = i + poolT;
            Dtype val = input[batch][plane][h][w][t];
            // for consistency with THNN, favor the first max
            if (val > maxVal) {
              maxIndex = h * input.getSize(3)*input.getSize(4) + w * input.getSize(4) + t;
              maxVal = val;
            }
          }
        }
      }
    }

    assert(THCNumerics<Dtype>::ne(maxVal, THCNumerics<Dtype>::min()));
    assert(maxIndex != -1);

    // +1 for Lua index
    indices[batch][plane][outputH][outputW][outputT] = maxIndex + TH_INDEX_BASE;
    output[batch][plane][outputH][outputW][outputT] = maxVal;
  }
}

// Scatters each gradOutput element back to the input location recorded in
// `indices` during the forward pass.  atomicAdd is required because
// fractional pooling windows may overlap, so several output cells can map
// to the same input cell.
template <typename Dtype>
__global__ void VolumetricFractionalMaxPooling_updateGradInput(
  THCDeviceTensor<Dtype, 5> gradInput,
  THCDeviceTensor<Dtype, 5> gradOutput,
  THCDeviceTensor<THCIndex_t, 5> indices) {
  // Output (h, w, t) point that this thread is responsible for
  int ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
  int plane = blockIdx.y;
  int batch = blockIdx.z;

  // Each thread handles a specific output point
  if (ourOutputPoint < gradOutput.getSize(2) * gradOutput.getSize(3) * gradOutput.getSize(4)) {
    int outputT = ourOutputPoint % gradOutput.getSize(4);
    int outputW = (ourOutputPoint / gradOutput.getSize(4)) % gradOutput.getSize(3);
    int outputH = ourOutputPoint / (gradOutput.getSize(3)*gradOutput.getSize(4));

    // Decode the flat (h*W*T + w*T + t) index stored by the forward kernel;
    // subtract TH_INDEX_BASE because indices were stored 1-based for Lua.
    int index = indices[batch][plane][outputH][outputW][outputT] - TH_INDEX_BASE;
    assert(index >= 0);
    int inputT = index % gradInput.getSize(4);
    int inputW = (index / gradInput.getSize(4)) % gradInput.getSize(3);
    int inputH = index / (gradInput.getSize(3) * gradInput.getSize(4));
    assert(inputH < gradInput.getSize(2));

    atomicAdd(gradInput[batch][plane][inputH][inputW][inputT].data(),
              gradOutput[batch][plane][outputH][outputW][outputT]);
  }
}

#include "generic/VolumetricFractionalMaxPooling.cu"
#include "THCGenerateFloatTypes.h"
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index 930f4de..fabb8e9 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -1113,6 +1113,24 @@ TH_API void THNN_(VolumetricDilatedMaxPooling_updateGradInput)(
                  int dilationT, int dilationW, int dilationH,
                  bool ceilMode);

TH_API void THNN_(VolumetricFractionalMaxPooling_updateOutput)(
                  THCState *state,
                  THCTensor *input,
                  THCTensor *output,
                  int outputT, int outputW, int outputH,
                  int poolSizeT, int poolSizeW, int poolSizeH,
                  THCIndexTensor *indices,
                  THCTensor *randomSamples);

TH_API void THNN_(VolumetricFractionalMaxPooling_updateGradInput)(
                  THCState *state,
                  THCTensor *input,
                  THCTensor *gradOutput,
                  THCTensor *gradInput,
                  int outputT, int outputW, int outputH,
                  int poolSizeT, int poolSizeW, int poolSizeH,
                  THCIndexTensor *indices);

TH_API void THNN_(VolumetricFullConvolution_updateOutput)(
                  THCState *state,
                  THCTensor *input,
diff --git a/lib/THCUNN/generic/VolumetricFractionalMaxPooling.cu b/lib/THCUNN/generic/VolumetricFractionalMaxPooling.cu
new file mode 100644
index 0000000..cbc9a11
--- /dev/null
+++ b/lib/THCUNN/generic/VolumetricFractionalMaxPooling.cu
@@ -0,0 +1,168 @@
#ifndef THC_GENERIC_FILE
#define THC_GENERIC_FILE "generic/VolumetricFractionalMaxPooling.cu"
#else

// Host-side wrapper: validates shapes, resizes output/indices, and launches
// the forward kernel.  Accepts 4D (C, H, W, T) or 5D (N, C, H, W, T) input.
void THNN_(VolumetricFractionalMaxPooling_updateOutput)(
           THCState *state,
           THCTensor *input,
           THCTensor *output,
           int outputT, int outputW, int outputH,
           int poolSizeT, int poolSizeW, int poolSizeH,
           THCIndexTensor *indices,
           THCTensor *randomSamples)
{
  int planeDim = 0;
  int dimh = 1;
  int dimw = 2;
  int dimt = 3;
  long numBatch = 1;

  long numInputDims = THCTensor_(nDimension)(state, input);
  THCUNN_argCheck(state, numInputDims == 4 || numInputDims == 5, 2, input,
                  "4D or 5D (batch mode) tensor expected for input, but got: %s");

  if (numInputDims == 5) {
    numBatch = THCTensor_(size)(state, input, 0);
    planeDim++;
    dimh++;
    dimw++;
    dimt++;
  }

  /* sizes */
  long numPlanes = THCTensor_(size)(state, input, planeDim);
  long inputH = THCTensor_(size)(state, input, dimh);
  long inputW = THCTensor_(size)(state, input, dimw);
  long inputT = THCTensor_(size)(state, input, dimt);

  THArgCheck(outputH + poolSizeH - 1 < inputH, 7,
             "poolSizeH (%d) too large relative to input height (%d)",
             poolSizeH, inputH);
  THArgCheck(outputW + poolSizeW - 1 < inputW, 6,
             "poolSizeW (%d) too large relative to input width (%d)",
             poolSizeW, inputW);
  /* BUG FIX: the time window must be checked against inputT (the original
     compared against inputW, a copy-paste error from the width check). */
  THArgCheck(outputT + poolSizeT - 1 < inputT, 5,
             "poolSizeT (%d) too large relative to input time (%d)",
             poolSizeT, inputT);

  THCDeviceTensor<real, 5> devInput;
  THCDeviceTensor<real, 5> devOutput;
  THCDeviceTensor<THCIndex_t, 5> devIndices;
  THCDeviceTensor<real, 3> devSamples =
    toDeviceTensor<real, 3>(state, randomSamples);

  if (numInputDims == 4) {
    /* resize output */
    THCTensor_(resize4d)(state, output, numPlanes, outputH, outputW, outputT);
    /* indices will contain the locations for each output point */
    THCIndexTensor_(resize4d)(state, indices, numPlanes, outputH, outputW, outputT);

    /* upcast non-batch tensors to 5D with a singleton batch dimension so a
       single kernel handles both layouts */
    devInput = toDeviceTensor<real, 4>(state, input).upcastOuter<5>();
    devOutput = toDeviceTensor<real, 4>(state, output).upcastOuter<5>();
    devIndices = toDeviceTensor<THCIndex_t, 4>(state, indices).upcastOuter<5>();
  } else {
    THCTensor_(resize5d)(state, output, numBatch, numPlanes, outputH, outputW, outputT);
    /* indices will contain the locations for each output point */
    THCIndexTensor_(resize5d)(state, indices, numBatch, numPlanes, outputH, outputW, outputT);

    devInput = toDeviceTensor<real, 5>(state, input);
    devOutput = toDeviceTensor<real, 5>(state, output);
    devIndices = toDeviceTensor<THCIndex_t, 5>(state, indices);
  }

  // block is limited to 4 warps
  // grid handles overflow per each plane
  int outputPlaneSize = devOutput.getSize(2) * devOutput.getSize(3) * devOutput.getSize(4);
  dim3 grid(THCCeilDiv(outputPlaneSize, 128),
            devInput.getSize(1),
            devInput.getSize(0));
  dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);

/* Macros renamed SFMP_ -> VFMP_ (this is the Volumetric file; SFMP was a
   leftover from the Spatial version). */
#define VFMP_UPDATE_OUTPUT(POOL_T)                                      \
  VolumetricFractionalMaxPooling_updateOutput<POOL_T, real, accreal>    \
    <<<grid, block, 0, THCState_getCurrentStream(state)>>>(             \
      devInput, devOutput, devIndices, devSamples, poolSizeT, poolSizeW, poolSizeH);

#define VFMP_UPDATE_OUTPUT_CASE(POOL_T)                                 \
  case POOL_T: VFMP_UPDATE_OUTPUT(POOL_T); break

  /* BUG FIX: the static template parameter unrolls the innermost *time*
     loop in the kernel, so dispatch must be on poolSizeT.  The original
     switched on poolSizeW, which reads the wrong window extent whenever
     poolSizeW != poolSizeT. */
  switch (poolSizeT) {
    VFMP_UPDATE_OUTPUT_CASE(2);
    VFMP_UPDATE_OUTPUT_CASE(3);
    VFMP_UPDATE_OUTPUT_CASE(4);
    VFMP_UPDATE_OUTPUT_CASE(5);
    VFMP_UPDATE_OUTPUT_CASE(6);
    VFMP_UPDATE_OUTPUT_CASE(7);
    default:
      // dynamic pool time extent
      VFMP_UPDATE_OUTPUT_CASE(-1);
  }
  THCudaCheck(cudaGetLastError());
}

// Host-side wrapper for the backward pass: validates gradOutput's shape,
// zero-fills gradInput, and launches the scatter kernel.
void THNN_(VolumetricFractionalMaxPooling_updateGradInput)(
           THCState *state,
           THCTensor *input,
           THCTensor *gradOutput,
           THCTensor *gradInput,
           int outputT, int outputW, int outputH,
           int poolSizeT, int poolSizeW, int poolSizeH,
           THCIndexTensor *indices)
{
  int dimh = 1;
  int dimw = 2;
  int dimt = 3;

  long numInputDims = THCTensor_(nDimension)(state, input);
  if (numInputDims == 5) {
    dimh++;
    dimw++;
    dimt++;
  }

  /* sizes */
  long inputH = THCTensor_(size)(state, input, dimh);
  long inputW = THCTensor_(size)(state, input, dimw);
  long inputT = THCTensor_(size)(state, input, dimt);

  THArgCheck(outputH == THCTensor_(size)(state, gradOutput, dimh), 3,
             "gradOutput height unexpected");
  THArgCheck(outputW == THCTensor_(size)(state, gradOutput, dimw), 3,
             "gradOutput width unexpected");
  THArgCheck(outputT == THCTensor_(size)(state, gradOutput, dimt), 3,
             "gradOutput time unexpected");

  /* resize; gradInput must start at zero because the kernel accumulates */
  THCTensor_(resizeAs)(state, gradInput, input);
  THCTensor_(zero)(state, gradInput);

  THCDeviceTensor<real, 5> devGradInput;
  THCDeviceTensor<real, 5> devGradOutput;
  THCDeviceTensor<THCIndex_t, 5> devIndices;

  /* backprop */
  if (numInputDims == 4) {
    devGradInput = toDeviceTensor<real, 4>(state, gradInput).upcastOuter<5>();
    devGradOutput = toDeviceTensor<real, 4>(state, gradOutput).upcastOuter<5>();
    devIndices = toDeviceTensor<THCIndex_t, 4>(state, indices).upcastOuter<5>();
  } else {
    devGradInput = toDeviceTensor<real, 5>(state, gradInput);
    devGradOutput = toDeviceTensor<real, 5>(state, gradOutput);
    devIndices = toDeviceTensor<THCIndex_t, 5>(state, indices);
  }

  // block is limited to 4 warps
  // grid handles overflow per each plane
  int outputPlaneSize = devGradOutput.getSize(2) * devGradOutput.getSize(3) * devGradOutput.getSize(4);
  dim3 grid(THCCeilDiv(outputPlaneSize, 128),
            devGradInput.getSize(1),
            devGradInput.getSize(0));
  dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);

  VolumetricFractionalMaxPooling_updateGradInput
    <<<grid, block, 0, THCState_getCurrentStream(state)>>>(
      devGradInput, devGradOutput, devIndices);
  THCudaCheck(cudaGetLastError());
}

#endif