author     Soumith Chintala <soumith@gmail.com>      2017-07-26 00:01:59 +0300
committer  GitHub <noreply@github.com>               2017-07-26 00:01:59 +0300
commit     8d9e9562beb792e17c6614ec2b515094f9663776 (patch)
tree       7e528a1e2674254b1f22cba2939c5b5d4818af80
parent     b336dc940c513d0b42d2ef2940bec9199b4377cf (diff)
parent     2c84b988883c98b5617dd24d06b633951afa0de6 (diff)
Merge pull request #477 from wickedfoo/feature_lp_pooling
GPU implementation of L_p feature pooling
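For reference, these are the formulas the kernels below implement (restated from the in-code comments); each output feature pools `width` consecutive input features with step `stride`:

$$f(x_1,\ldots,x_w) = \Big(\sum_{i=1}^{w} x_i^{p}\Big)^{1/p}, \qquad \frac{\partial f}{\partial x_i} = \Big(\frac{x_i}{f(x_1,\ldots,x_w)}\Big)^{p-1}.$$

For p = 2 the gradient reduces to x_i / f, which is why the patch dispatches to the specialized power2/root2/powerGrad2 ops for that case.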
-rw-r--r--  lib/THCUNN/FeatureLPPooling.cu          653
-rw-r--r--  lib/THCUNN/generic/FeatureLPPooling.cu  267
-rw-r--r--  lib/THCUNN/generic/THCUNN.h              20
-rw-r--r--  test.lua                                143
4 files changed, 1083 insertions, 0 deletions
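As an orientation aid (editorial, not part of the patch), here is a minimal standalone CUDA sketch of the forward semantics for a flat [batch][feature] layout. The kernel name and layout are illustrative assumptions; the real kernels below additionally template on Width/Stride and reuse loaded registers across overlapping windows.

#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>

// Naive feature L_p pooling forward: one thread per (batch, output feature).
// out[b][f] = (sum_{i < width} in[b][f * stride + i]^p)^(1/p)
__global__ void featureLpPoolForwardNaive(const float* in, float* out,
                                          int batch, int inFeatures,
                                          int outFeatures, int width,
                                          int stride, float power) {
  int b = blockIdx.y;
  int f = blockIdx.x * blockDim.x + threadIdx.x;
  if (b >= batch || f >= outFeatures) return;

  float acc = 0.f;
  for (int i = 0; i < width; ++i) {
    acc += powf(in[b * inFeatures + f * stride + i], power);
  }
  out[b * outFeatures + f] = powf(acc, 1.f / power);
}

int main() {
  const int batch = 1, inFeatures = 8, width = 3, stride = 2;
  const float power = 2.f;
  // "Valid only" pooling: same formula as lpPoolingOutputSize in the patch.
  const int outFeatures = (inFeatures - width) / stride + 1;

  float hIn[inFeatures] = {1, 2, 3, 4, 5, 6, 7, 8}, hOut[outFeatures];
  float *dIn, *dOut;
  cudaMalloc(&dIn, sizeof(hIn));
  cudaMalloc(&dOut, sizeof(hOut));
  cudaMemcpy(dIn, hIn, sizeof(hIn), cudaMemcpyHostToDevice);

  dim3 grid((outFeatures + 31) / 32, batch);
  featureLpPoolForwardNaive<<<grid, 32>>>(dIn, dOut, batch, inFeatures,
                                          outFeatures, width, stride, power);
  cudaMemcpy(hOut, dOut, sizeof(hOut), cudaMemcpyDeviceToHost);

  // Expect sqrt(1+4+9), sqrt(9+16+25), sqrt(25+36+49) for this input.
  for (int f = 0; f < outFeatures; ++f) printf("%f\n", hOut[f]);
  cudaFree(dIn);
  cudaFree(dOut);
  return 0;
}

The patch's kernels avoid this version's repeated global loads: when Stride < Width, each thread keeps its current window in registers and shifts it left by Stride per output feature (RegisterUtils::shiftLeft), so only Stride new values are loaded per step.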
diff --git a/lib/THCUNN/FeatureLPPooling.cu b/lib/THCUNN/FeatureLPPooling.cu
new file mode 100644
index 0000000..4ad190f
--- /dev/null
+++ b/lib/THCUNN/FeatureLPPooling.cu
@@ -0,0 +1,653 @@
+#include "THCUNN.h"
+#include "THCAtomics.cuh"
+#include "THCDeviceTensor.cuh"
+#include "THCDeviceTensorUtils.cuh"
+#include "THCDeviceUtils.cuh"
+#include "THCNumerics.cuh"
+#include "THCTensorTypeUtils.cuh"
+
+#define OUTPUT_FEATURES_PER_THREAD 32
+#define MAX_WARPS_PER_RUN 4
+
+namespace detail {
+
+/// Various utilities for dealing with arrays of values which are
+/// maintained in thread-local registers. All accesses are done in such
+/// a way that the index is statically known, which preserves the
+/// compiler's ability to allocate the values to registers, as opposed
+/// to local memory.
+template <typename T, int N>
+struct RegisterUtils {
+  /// Register shifting: move elements towards the beginning of the
+  /// array (towards 0) by `Shift` places:
+  ///   arr[i] = arr[i + Shift]
+  /// The `Shift` elements at the end are left unchanged.
+  template <int Shift>
+  __device__ __forceinline__ static void shiftLeft(T arr[N]) {
+    // e.g., N = 5, Shift = 2:
+    // 0 1 2 3 4 becomes =>
+    // 2 3 4 3 4 (last are unchanged)
+#pragma unroll
+    for (int i = 0; i < N - Shift; ++i) {
+      arr[i] = arr[i + Shift];
+    }
+  }
+};
+
+template <typename T>
+__device__ __forceinline__
+int getDim1Point(const THCDeviceTensor<T, 4>& input) {
+  int threadPoint = blockIdx.x * blockDim.x + threadIdx.x;
+  return threadPoint / input.getSize(3);
+}
+
+template <typename T>
+__device__ __forceinline__
+int getDim2Point(const THCDeviceTensor<T, 4>& input) {
+  int threadPoint = blockIdx.x * blockDim.x + threadIdx.x;
+  return threadPoint % input.getSize(3);
+}
+
+__device__ __forceinline__
+int getStartOutputFeature() {
+  return blockIdx.y * OUTPUT_FEATURES_PER_THREAD;
+}
+
+template <typename T>
+__device__ __forceinline__
+int getEndOutputFeature(const THCDeviceTensor<T, 4>& output) {
+  return min((blockIdx.y + 1) * OUTPUT_FEATURES_PER_THREAD, output.getSize(1));
+}
+
+__device__ __forceinline__
+int getBatch() {
+  return blockIdx.z;
+}
+
+// All of these functions that follow are MathOps; they are template
+// parameters so L2 can be more efficiently implemented
+// template <typename T>
+// typedef T (*MathOp)(const T in, const T arg);
+
+template <typename T>
+__device__ __forceinline__ T power2(const T in, const T power) {
+  return THCNumerics<T>::mul(in, in);
+}
+
+template <typename T>
+__device__ __forceinline__ T root2(const T in, const T power) {
+  return THCNumerics<T>::sqrt(in);
+}
+
+template <typename T>
+__device__ __forceinline__ T powerGrad2(const T in, const T power) {
+  return in;
+}
+
+template <typename T>
+__device__ __forceinline__ T powerN(const T in, const T power) {
+  return THCNumerics<T>::pow(in, power);
+}
+
+template <typename T>
+__device__ __forceinline__ T rootN(const T in, const T power) {
+  const T invPower = THCNumerics<T>::cinv(power);
+  return THCNumerics<T>::pow(in, invPower);
+}
+
+template <typename T>
+__device__ __forceinline__ T powerGradN(const T in, const T power) {
+  return THCNumerics<T>::pow(in,
+                             THCNumerics<T>::sub(power,
+                                                 ScalarConvert<int, T>::to(1)));
+}
+
+// Input is of the form:
+// [batch][feature dim][optional dim 1][optional dim 2]
+template <typename T,
+          int Width,
+          int Stride,
+          T (*PowerFunc)(T in, T power),
+          T (*RootFunc)(T in, T power)>
+__global__ void
+featureLPPoolingUpdateOutput(const THCDeviceTensor<T, 4> input,
+                             THCDeviceTensor<T, 4> output,
+                             T power) {
+  // What non-feature points is this thread handling?
+  int dim1Point = getDim1Point(input);
+  int dim2Point = getDim2Point(input);
+
+  if (dim1Point >= input.getSize(2) || dim2Point >= input.getSize(3)) {
+    // This thread in the warp is out of bounds
+    return;
+  }
+
+  // What feature points is this thread handling?
+  int startOutputFeature = getStartOutputFeature();
+  int endOutputFeature = getEndOutputFeature(output);
+  int startInputFeature = startOutputFeature * Stride;
+
+  // What batch points is this thread handling?
+  int batch = getBatch();
+
+  // If stride >= width, then there is no loaded data reuse.
+  // If stride > 1 and stride < width, then shift by stride, since we
+  // can reuse Width - Stride elements from the previous round.
+  // e.g., width = 5, stride = 2,
+  // output 0 uses input 0 1 2 3 4
+  // output 1 uses input 2 3 4 5 6 (inputs 2 - 4 are reused, i.e., 5 -
+  // 2 elements are reused, and we have to shift the array by 2)
+  //
+  // e.g., width = 5, stride = 3,
+  // output 0 uses input 0 1 2 3 4
+  // output 1 uses input 3 4 5 6 7 (inputs 3 - 4 are reused, i.e., 5 - 3
+  // elements are reused, and we have to shift the array by 3)
+
+  // Valid only pooling: load Width elements from input (Width -
+  // Stride is handled here, at the top of the loop we handle the
+  // remaining Stride elements). We already verified that the input is
+  // larger than the width.
+  // `in` will contain the input values ^ power.
+  T in[Width];
+
+#pragma unroll
+  for (int i = 0; i < Width - Stride; ++i) {
+    const T data =
+      input[batch][startInputFeature + i][dim1Point][dim2Point];
+    in[i] = PowerFunc(data, power);
+  }
+
+  for (int outputFeature = startOutputFeature;
+       outputFeature < endOutputFeature;
+       ++outputFeature) {
+    // If Stride < Width, we're loading Stride new values starting at
+    // Width - Stride
+    // If Stride >= Width, we're loading Width new values starting at 0
+    if (Stride < Width) {
+      int nextInputFeature = outputFeature * Stride + Width - Stride;
+
+#pragma unroll
+      for (int i = 0; i < Stride; ++i) {
+        const T data =
+          input[batch][nextInputFeature + i][dim1Point][dim2Point];
+        in[Width - Stride + i] = PowerFunc(data, power);
+      }
+    } else {
+      int nextInputFeature = outputFeature * Stride;
+
+#pragma unroll
+      for (int i = 0; i < Width; ++i) {
+        T data = input[batch][nextInputFeature + i][dim1Point][dim2Point];
+        in[i] = PowerFunc(data, power);
+      }
+    }
+
+    // Calculate the new output feature
+    T val = ScalarConvert<int, T>::to(0);
+    for (int i = 0; i < Width; ++i) {
+      val = THCNumerics<T>::add(val, in[i]);
+    }
+
+    val = RootFunc(val, power);
+    output[batch][outputFeature][dim1Point][dim2Point] = val;
+
+    if (Stride < Width) {
+      // Shift registers for calculating the next point
+      RegisterUtils<T, Width>::shiftLeft<Stride>(in);
+    }
+  }
+}
+
+// forward pass: f(a, ..., z) = (a^p + ... + z^p)^(1 / p)
+// for bprop:
+//   partial df(a, ... z)/da = a^(p - 1) * (a^p + ... + z^p)^((1 / p) - 1) =
+//   a^(p - 1) * 1/(f(a, ..., z)^(p - 1)) = (a / f(a, ..., z))^(p - 1)
+//
+// example: for p = 2, df(a, ..., z)/da = a / f(a, ..., z)
+// example: for p = 3, df(a, ..., z)/da = (a / f(a, ..., z))^2
+//
+// PowerGradFunc implements x^(p - 1)
+template <typename T,
+          int Width,
+          int Stride,
+          T (*PowerGradFunc)(T in, T arg)>
+__global__ void
+featureLPPoolingUpdateGradInput(const THCDeviceTensor<T, 4> gradOutput,
+                                const THCDeviceTensor<T, 4> input,
+                                const THCDeviceTensor<T, 4> output,
+                                THCDeviceTensor<T, 4> gradInput,
+                                T power) {
+  // What non-feature points is this thread handling?
+  int dim1Point = getDim1Point(input);
+  int dim2Point = getDim2Point(input);
+
+  if (dim1Point >= input.getSize(2) || dim2Point >= input.getSize(3)) {
+    // This thread in the warp is out of bounds
+    return;
+  }
+
+  // What feature points is this thread handling? [start, end)
+  int startOutputFeature = getStartOutputFeature();
+  int endOutputFeature = getEndOutputFeature(output);
+
+  // What is the first input point that the output features depend
+  // upon? [start, end)
+  int startInputFeature = startOutputFeature * Stride;
+  int endInputFeature = endOutputFeature * Stride;
+
+  // What batch points is this thread handling?
+  int batch = getBatch();
+
+  // atomicAdd into gradInput is slow, avoid it where possible.
+  // We can do this because there is a range of gradInput elements
+  // that we are updating exclusively. This is how we find it:
+  //
+  // width = 3 stride = 1 example:
+  // ------------------------------
+  //      startOutputFeature for this thread
+  //        |
+  //        |
+  // previous thread's output feature
+  //   |    |
+  //   |    |                gradOutput
+  // __v____v___________________
+  // |    |    |    |    |    |
+  // ---------------------------
+  //   |\ \_____
+  //   | \__    \            gradInput
+  // __v____v____v_____________
+  // |    |    |    |    |    |
+  // ---------------------------
+  //   A    A
+  //   |    |
+  //   startInputFeature
+  //        |
+  //        exclusiveStartInputFeature
+  //
+  // exclusiveStartInputFeature is the first input feature that we can
+  // write into exclusively; the one right before it overlaps with
+  // updates from a previous thread and thus has to use atomicAdd.
+  int exclusiveStartInputFeature =
+    startInputFeature == 0 ?
+    // no thread is before ourselves
+    0 :
+    // there is a thread before ourselves
+    startInputFeature + (Width - 1) * Stride;
+
+  // Similarly, exclusiveEndInputFeature is the last input feature
+  // that we can write into exclusively, since we might be overlapping
+  // with the following thread
+  int exclusiveEndInputFeature =
+    endOutputFeature == output.getSize(1) ?
+    // no thread is after ourselves
+    endInputFeature + (Width - 1) * Stride :
+    // there is a thread after ourselves
+    endInputFeature;
+
+  // As with updateOutput preload input elements, except no need to
+  // transform them
+  T in[Width];
+#pragma unroll
+  for (int i = 0; i < Width - Stride; ++i) {
+    in[i] = input[batch][startInputFeature + i][dim1Point][dim2Point];
+  }
+
+  for (int outputFeature = startOutputFeature;
+       outputFeature < endOutputFeature;
+       ++outputFeature) {
+    // As with updateOutput load the subsequent input elements that we
+    // need, except no need to transform them
+    //
+    // If Stride < Width, we're loading Stride new values starting at
+    // Width - Stride
+    // If Stride >= Width, we're loading Width new values starting at 0
+    if (Stride < Width) {
+      int nextInputFeature = outputFeature * Stride + Width - Stride;
+
+#pragma unroll
+      for (int i = 0; i < Stride; ++i) {
+        in[Width - Stride + i] =
+          input[batch][nextInputFeature + i][dim1Point][dim2Point];
+      }
+    } else {
+      int nextInputFeature = outputFeature * Stride;
+
+#pragma unroll
+      for (int i = 0; i < Width; ++i) {
+        in[i] = input[batch][nextInputFeature + i][dim1Point][dim2Point];
+      }
+    }
+
+    // A given output feature gradient contributes to `Width` input
+    // gradients
+    const T gradOut =
+      gradOutput[batch][outputFeature][dim1Point][dim2Point];
+
+    // Load output (f(x_is)). It is possible that this is zero, in
+    // which case we'll ignore this point.
+    T out = output[batch][outputFeature][dim1Point][dim2Point];
+    if (THCNumerics<T>::eq(out, ScalarConvert<int, T>::to(0))) {
+      continue;
+    }
+
+    int curStartInputFeature = outputFeature * Stride;
+    int curEndInputFeature = outputFeature * Stride + Width - 1;
+
+    if (curStartInputFeature >= exclusiveStartInputFeature &&
+        curEndInputFeature < exclusiveEndInputFeature) {
+      // This thread is exclusively responsible for updating these
+      // input points, so we need not make the addition atomic
+      for (int i = 0; i < Width; ++i) {
+        int inputFeature = outputFeature * Stride + i;
+
+        // Calculate grad * (x_i / f(x_is))^(p - 1)
+        const T val = THCNumerics<T>::mul(
+          gradOut,
+          PowerGradFunc(THCNumerics<T>::div(in[i], out), power));
+
+        gradInput[batch][inputFeature][dim1Point][dim2Point] =
+          THCNumerics<T>::add(
+            gradInput[batch][inputFeature][dim1Point][dim2Point], val);
+      }
+    } else {
+      // Handle start and end boundary cases: potential overlap with
+      // other threads
+      for (int i = 0; i < Width; ++i) {
+        int inputFeature = outputFeature * Stride + i;
+
+        // Calculate grad * (x_i / f(x_is))^(p - 1)
+        T val = THCNumerics<T>::mul(
+          gradOut,
+          PowerGradFunc(THCNumerics<T>::div(in[i], out), power));
+
+        // We don't overlap other threads for this range
+        if (inputFeature >= exclusiveStartInputFeature &&
+            inputFeature < exclusiveEndInputFeature) {
+          gradInput[batch][inputFeature][dim1Point][dim2Point]
+            = THCNumerics<T>::add(
+              gradInput[batch][inputFeature][dim1Point][dim2Point], val);
+        } else {
+          // We are potentially overlapping with threads handling
+          // features before ourselves, so these need to be added atomically
+          atomicAdd(&gradInput[batch][inputFeature][dim1Point][dim2Point],
+                    val);
+        }
+      }
+    }
+
+    if (Stride < Width) {
+      // Shift registers for calculating the next point
+      RegisterUtils<T, Width>::shiftLeft<Stride>(in);
+    }
+  }
+}
+
+} // namespace detail
+
+inline int lpPoolingOutputSize(int inputSize, int width, int stride) {
+  return ((inputSize - width) / stride) + 1;
+}
+
+template <typename T>
+bool
+runFeatureLPPoolingUpdateOutput(THCState* state,
+                                const THCDeviceTensor<T, 4>& input,
+                                THCDeviceTensor<T, 4>& output,
+                                float power, int width, int stride) {
+  cudaStream_t stream =
+    THCState_getCurrentStream(state);
+  const cudaDeviceProp* deviceProperties =
+    THCState_getCurrentDeviceProperties(state);
+
+  int outputFeatures = ((input.getSize(1) - width) / stride) + 1;
+
+  THAssert(input.getSize(0) == output.getSize(0));
+  THAssert(outputFeatures == output.getSize(1));
+  THAssert(input.getSize(1) >= width);
+
+  THAssert(input.getSize(2) == output.getSize(2));
+  THAssert(input.getSize(3) == output.getSize(3));
+  THAssert(power > 0.0f);
+  THAssert(width >= 1);
+  THAssert(stride >= 1);
+
+  // Split non-features among threads and grid x
+  int totalNonFeatureSize = input.getSize(2) * input.getSize(3);
+  int numWarps =
+    min(THCCeilDiv(totalNonFeatureSize, deviceProperties->warpSize),
+        MAX_WARPS_PER_RUN);
+  int blockSize = deviceProperties->warpSize * numWarps;
+
+  // Split non-features among grid x
+  int nonFeatureSizeBlocks = THCCeilDiv(totalNonFeatureSize, blockSize);
+
+  // Split features among grid y, up to a maximum number of features per thread
+  int featureBlocks = THCCeilDiv(outputFeatures, OUTPUT_FEATURES_PER_THREAD);
+
+  // Split batch among grid z.
+  dim3 grid(nonFeatureSizeBlocks, featureBlocks, input.getSize(0));
+  dim3 block(blockSize);
+
+#define L2_STRIDE_CASE(STRIDE, WIDTH)                                   \
+  case STRIDE:                                                          \
+    detail::                                                            \
+      featureLPPoolingUpdateOutput<T, WIDTH,                            \
+                                   STRIDE,                              \
+                                   detail::power2,                      \
+                                   detail::root2><<<grid, block, 0, stream>>>( \
+                                     input, output,                     \
+                                     ScalarConvert<float, T>::to(power)); \
+    return true;
+
+#define L2_WIDTH_CASE(WIDTH)                    \
+  case WIDTH:                                   \
+    switch (stride) {                           \
+      L2_STRIDE_CASE(1, WIDTH);                 \
+      L2_STRIDE_CASE(2, WIDTH);                 \
+      L2_STRIDE_CASE(3, WIDTH);                 \
+      L2_STRIDE_CASE(4, WIDTH);                 \
+    }
+
+#define LP_STRIDE_CASE(STRIDE, WIDTH)                                   \
+  case STRIDE:                                                          \
+    detail::                                                            \
+      featureLPPoolingUpdateOutput<T, WIDTH,                            \
+                                   STRIDE,                              \
+                                   detail::powerN,                      \
+                                   detail::rootN><<<grid, block, 0, stream>>>( \
+                                     input, output,                     \
+                                     ScalarConvert<float, T>::to(power)); \
+    return true;
+
+#define LP_WIDTH_CASE(WIDTH)                    \
+  case WIDTH:                                   \
+    switch (stride) {                           \
+      LP_STRIDE_CASE(1, WIDTH);                 \
+      LP_STRIDE_CASE(2, WIDTH);                 \
+      LP_STRIDE_CASE(3, WIDTH);                 \
+      LP_STRIDE_CASE(4, WIDTH);                 \
+    }
+
+  if (power == 2.0f) {
+    switch (width) {
+      L2_WIDTH_CASE(2);
+      L2_WIDTH_CASE(3);
+      L2_WIDTH_CASE(4);
+      L2_WIDTH_CASE(5);
+      L2_WIDTH_CASE(6);
+      L2_WIDTH_CASE(7);
+      L2_WIDTH_CASE(8);
+      L2_WIDTH_CASE(9);
+      L2_WIDTH_CASE(10);
+      L2_WIDTH_CASE(11);
+      L2_WIDTH_CASE(12);
+      L2_WIDTH_CASE(13);
+      L2_WIDTH_CASE(14);
+      L2_WIDTH_CASE(15);
+      L2_WIDTH_CASE(16);
+    }
+  } else {
+    switch (width) {
+      LP_WIDTH_CASE(2);
+      LP_WIDTH_CASE(3);
+      LP_WIDTH_CASE(4);
+      LP_WIDTH_CASE(5);
+      LP_WIDTH_CASE(6);
+      LP_WIDTH_CASE(7);
+      LP_WIDTH_CASE(8);
+      LP_WIDTH_CASE(9);
+      LP_WIDTH_CASE(10);
+      LP_WIDTH_CASE(11);
+      LP_WIDTH_CASE(12);
+      LP_WIDTH_CASE(13);
+      LP_WIDTH_CASE(14);
+      LP_WIDTH_CASE(15);
+      LP_WIDTH_CASE(16);
+    }
+  }
+
+  // Otherwise, we have an unhandled width and/or stride.
+  return false;
+
+#undef L2_STRIDE_CASE
+#undef L2_WIDTH_CASE
+#undef LP_STRIDE_CASE
+#undef LP_WIDTH_CASE
+}
+
+template <typename T>
+bool
+runFeatureLPPoolingUpdateGradInput(THCState* state,
+                                   const THCDeviceTensor<T, 4>& gradOutput,
+                                   const THCDeviceTensor<T, 4>& input,
+                                   const THCDeviceTensor<T, 4>& output,
+                                   THCDeviceTensor<T, 4>& gradInput,
+                                   float power, int width, int stride) {
+  cudaStream_t stream =
+    THCState_getCurrentStream(state);
+  const cudaDeviceProp* deviceProperties =
+    THCState_getCurrentDeviceProperties(state);
+
+  for (int i = 0; i < 4; ++i) {
+    THAssert(gradOutput.getSize(i) == output.getSize(i));
+    THAssert(gradInput.getSize(i) == input.getSize(i));
+  }
+
+  int outputFeatures = ((input.getSize(1) - width) / stride) + 1;
+
+  THAssert(gradInput.getSize(0) == gradOutput.getSize(0));
+  THAssert(outputFeatures == gradOutput.getSize(1));
+  THAssert(gradInput.getSize(1) >= width);
+
+  THAssert(gradInput.getSize(2) == gradOutput.getSize(2));
+  THAssert(gradInput.getSize(3) == gradOutput.getSize(3));
+  THAssert(power > 0.0f);
+  THAssert(width >= 1);
+  THAssert(stride >= 1);
+
+  // Different threads are potentially adding into overlapping input
+  // points, so we must clear out gradInput before continuing.
+  gradInput.zero(stream);
+
+  // Split non-features among threads and grid x
+  int totalNonFeatureSize = input.getSize(2) * input.getSize(3);
+  int numWarps =
+    min(THCCeilDiv(totalNonFeatureSize, deviceProperties->warpSize),
+        MAX_WARPS_PER_RUN);
+  int blockSize = deviceProperties->warpSize * numWarps;
+
+  // Split non-features among grid x
+  int nonFeatureSizeBlocks = THCCeilDiv(totalNonFeatureSize, blockSize);
+
+  // Split features among grid y, up to a maximum number of features per thread
+  int featureBlocks = THCCeilDiv(outputFeatures, OUTPUT_FEATURES_PER_THREAD);
+
+  // Split batch among grid z.
+  dim3 grid(nonFeatureSizeBlocks, featureBlocks, input.getSize(0));
+  dim3 block(blockSize);
+
+#define L2_STRIDE_CASE(STRIDE, WIDTH)                                   \
+  case STRIDE:                                                          \
+    detail::                                                            \
+      featureLPPoolingUpdateGradInput<                                  \
+        T, WIDTH, STRIDE, detail::powerGrad2><<<grid, block, 0, stream>>>( \
+          gradOutput, input, output, gradInput,                         \
+          ScalarConvert<float, T>::to(power));                          \
+    return true;
+
+#define L2_WIDTH_CASE(WIDTH)                    \
+  case WIDTH:                                   \
+    switch (stride) {                           \
+      L2_STRIDE_CASE(1, WIDTH);                 \
+      L2_STRIDE_CASE(2, WIDTH);                 \
+      L2_STRIDE_CASE(3, WIDTH);                 \
+      L2_STRIDE_CASE(4, WIDTH);                 \
+    }
+
+#define LP_STRIDE_CASE(STRIDE, WIDTH)                                   \
+  case STRIDE:                                                          \
+    detail::                                                            \
+      featureLPPoolingUpdateGradInput<                                  \
+        T, WIDTH, STRIDE, detail::powerGradN><<<grid, block, 0, stream>>>( \
+          gradOutput, input, output, gradInput,                         \
+          ScalarConvert<float, T>::to(power));                          \
+    return true;
+
+#define LP_WIDTH_CASE(WIDTH)                    \
+  case WIDTH:                                   \
+    switch (stride) {                           \
+      LP_STRIDE_CASE(1, WIDTH);                 \
+      LP_STRIDE_CASE(2, WIDTH);                 \
+      LP_STRIDE_CASE(3, WIDTH);                 \
+      LP_STRIDE_CASE(4, WIDTH);                 \
+    }
+
+  if (power == 2.0f) {
+    switch (width) {
+      L2_WIDTH_CASE(2);
+      L2_WIDTH_CASE(3);
+      L2_WIDTH_CASE(4);
+      L2_WIDTH_CASE(5);
+      L2_WIDTH_CASE(6);
+      L2_WIDTH_CASE(7);
+      L2_WIDTH_CASE(8);
+      L2_WIDTH_CASE(9);
+      L2_WIDTH_CASE(10);
+      L2_WIDTH_CASE(11);
+      L2_WIDTH_CASE(12);
+      L2_WIDTH_CASE(13);
+      L2_WIDTH_CASE(14);
+      L2_WIDTH_CASE(15);
+      L2_WIDTH_CASE(16);
+    }
+  } else {
+    switch (width) {
+      LP_WIDTH_CASE(2);
+      LP_WIDTH_CASE(3);
+      LP_WIDTH_CASE(4);
+      LP_WIDTH_CASE(5);
+      LP_WIDTH_CASE(6);
+      LP_WIDTH_CASE(7);
+      LP_WIDTH_CASE(8);
+      LP_WIDTH_CASE(9);
+      LP_WIDTH_CASE(10);
+      LP_WIDTH_CASE(11);
+      LP_WIDTH_CASE(12);
+      LP_WIDTH_CASE(13);
+      LP_WIDTH_CASE(14);
+      LP_WIDTH_CASE(15);
+      LP_WIDTH_CASE(16);
+    }
+  }
+
+  // Otherwise, we have an unhandled width and/or stride.
+  return false;
+
+#undef L2_STRIDE_CASE
+#undef L2_WIDTH_CASE
+#undef LP_STRIDE_CASE
+#undef LP_WIDTH_CASE
+}
+
+#include "generic/FeatureLPPooling.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/lib/THCUNN/generic/FeatureLPPooling.cu b/lib/THCUNN/generic/FeatureLPPooling.cu
new file mode 100644
index 0000000..9300450
--- /dev/null
+++ b/lib/THCUNN/generic/FeatureLPPooling.cu
@@ -0,0 +1,267 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/FeatureLPPooling.cu"
+#else
+
+#include "../common.h"
+
+// non-batch mode:
+// [feature dim]
+// [feature dim][opt dim 1]
+// [feature dim][opt dim 1][opt dim 2]
+//
+// batch mode:
+// [batch dim][feature dim]
+// [batch dim][feature dim][opt dim 1]
+// [batch dim][feature dim][opt dim 1][opt dim 2]
+THCDeviceTensor<real, 4>
+THNN_(FeatureLPPooling_upcast)(THCState* state, THCTensor* t, bool batchMode) {
+  int inputDim = THCTensor_(nDimension)(state, t);
+
+  if (inputDim == 1) {
+    // [feature dim]
+    return toDeviceTensor<real, 1>(state, t).
+      upcastOuter<2>().upcastInner<4>();
+  } else if (inputDim == 2) {
+    if (batchMode) {
+      // [batch dim][feature dim]
+      return toDeviceTensor<real, 2>(state, t).
+        upcastInner<4>();
+    } else {
+      // [feature dim][opt dim 1]
+      return toDeviceTensor<real, 2>(state, t).
+        upcastOuter<3>().upcastInner<4>();
+    }
+  } else if (inputDim == 3) {
+    if (batchMode) {
+      // [batch dim][feature dim][opt dim 1]
+      return toDeviceTensor<real, 3>(state, t).
+        upcastInner<4>();
+    } else {
+      // [feature dim][opt dim 1][opt dim 2]
+      return toDeviceTensor<real, 3>(state, t).
+        upcastOuter<4>();
+    }
+  } else {
+    // inputDim == 4
+    // [batch dim][feature dim][opt dim 1][opt dim 2]
+    THAssert(batchMode);
+    return toDeviceTensor<real, 4>(state, t);
+  }
+}
+
+// Resizes `toResize` based on the output size for `src` as an input
+// tensor
+void
+THNN_(FeatureLPPooling_resizeForOutput)(THCState* state,
+                                        THCTensor* toResize,
+                                        THCTensor* input,
+                                        bool batchMode,
+                                        int width,
+                                        int stride) {
+  int inputDim = THCTensor_(nDimension)(state, input);
+  THAssert(inputDim >= 1 && inputDim <= 4);
+
+  long outSize =
+    lpPoolingOutputSize(THCTensor_(size)(state, input, 0), width, stride);
+  if (batchMode) {
+    THAssert(inputDim > 1);
+    outSize =
+      lpPoolingOutputSize(THCTensor_(size)(state, input, 1), width, stride);
+  } else {
+    THAssert(inputDim < 4);
+  }
+
+  if (inputDim == 1) {
+    THCTensor_(resize1d)(state, toResize, outSize);
+  } else if (inputDim == 2) {
+    if (batchMode) {
+      THCTensor_(resize2d)(
+        state, toResize, THCTensor_(size)(state, input, 0), outSize);
+    } else {
+      THCTensor_(resize2d)(
+        state, toResize, outSize, THCTensor_(size)(state, input, 1));
+    }
+  } else if (inputDim == 3) {
+    if (batchMode) {
+      THCTensor_(resize3d)(
+        state,
+        toResize,
+        THCTensor_(size)(state, input, 0), outSize,
+        THCTensor_(size)(state, input, 2));
+    } else {
+      THCTensor_(resize3d)(
+        state,
+        toResize,
+        outSize, THCTensor_(size)(state, input, 1),
+        THCTensor_(size)(state, input, 2));
+    }
+  } else if (inputDim == 4) {
+    THCTensor_(resize4d)(
+      state,
+      toResize,
+      THCTensor_(size)(state, input, 0), outSize,
+      THCTensor_(size)(state, input, 2), THCTensor_(size)(state, input, 3));
+  }
+}
+
+// Makes `toResize` the same size/dimensionality as `src`
+void
+THNN_(FeatureLPPooling_resize)(THCState* state,
+                               THCTensor* toResize,
+                               THCTensor* src) {
+  int inputDim = THCTensor_(nDimension)(state, src);
+  THAssert(inputDim >= 1 && inputDim <= 4);
+
+  if (inputDim == 1) {
+    THCTensor_(resize1d)(state,
+                         toResize,
+                         THCTensor_(size)(state, src, 0));
+  } else if (inputDim == 2) {
+    THCTensor_(resize2d)(
+      state,
+      toResize,
+      THCTensor_(size)(state, src, 0),
+      THCTensor_(size)(state, src, 1));
+  } else if (inputDim == 3) {
+    THCTensor_(resize3d)(
+      state,
+      toResize,
+      THCTensor_(size)(state, src, 0),
+      THCTensor_(size)(state, src, 1),
+      THCTensor_(size)(state, src, 2));
+  } else if (inputDim == 4) {
+    THCTensor_(resize4d)(
+      state,
+      toResize,
+      THCTensor_(size)(state, src, 0),
+      THCTensor_(size)(state, src, 1),
+      THCTensor_(size)(state, src, 2),
+      THCTensor_(size)(state, src, 3));
+  }
+}
+
+void THNN_(FeatureLPPooling_updateOutput)(THCState* state,
+                                          THCTensor* inputTH,
+                                          THCTensor* outputTH,
+                                          accreal power,
+                                          int width,
+                                          int stride,
+                                          bool batchMode) {
+  THCUNN_assertSameGPU(state, 2, inputTH, outputTH);
+
+  int inputDim = THCTensor_(nDimension)(state, inputTH);
+
+  if (batchMode) {
+    THArgCheck(inputDim >= 2 && inputDim <= 4, 2,
+               "input must be 2-4 dimensions for batch mode");
+  } else {
+    THArgCheck(inputDim >= 1 && inputDim <= 3, 2,
+               "input must be 1-3 dimensions for non-batch mode");
+  }
+
+  THArgCheck(TensorUtils<THCTensor>::canUse32BitIndexMath(state, inputTH), 2,
+             "input tensor must fit into 32-bit index math");
+
+  THCDeviceTensor<TensorUtils<THCTensor>::DataType, 4> input;
+  THCDeviceTensor<TensorUtils<THCTensor>::DataType, 4> output;
+
+  input = THNN_(FeatureLPPooling_upcast)(state, inputTH, batchMode);
+
+  // Make sure the feature dimension is properly sized
+  THArgCheck(input.getSize(1) >= width, 2,
+             "input: feature dimension must be >= width");
+
+  // Make sure that width and stride are within range
+  THArgCheck(width >= 2 && width <= 16, 5,
+             "width must be between 2 - 16");
+
+  THArgCheck(stride >= 1 && stride <= 4, 6,
+             "stride must be between 1 - 4");
+
+  THNN_(FeatureLPPooling_resizeForOutput)(
+    state, outputTH, inputTH, batchMode, width, stride);
+
+  output = THNN_(FeatureLPPooling_upcast)(state, outputTH, batchMode);
+
+  bool found = runFeatureLPPoolingUpdateOutput(state,
+                                               input,
+                                               output,
+                                               power,
+                                               width,
+                                               stride);
+  THAssert(found);
+}
+
+void THNN_(FeatureLPPooling_updateGradInput)(THCState* state,
+                                             THCTensor* gradOutputTH,
+                                             THCTensor* inputTH,
+                                             THCTensor* outputTH,
+                                             THCTensor* gradInputTH,
+                                             accreal power,
+                                             int width,
+                                             int stride,
+                                             bool batchMode) {
+  THArgCheck(TensorUtils<THCTensor>::canUse32BitIndexMath(state, gradOutputTH), 2,
+             "output gradient tensor must fit into 32-bit index math");
+  THArgCheck(TensorUtils<THCTensor>::canUse32BitIndexMath(state, inputTH), 3,
+             "input tensor must fit into 32-bit index math");
+  THCUNN_assertSameGPU(state, 4, gradOutputTH, inputTH, outputTH, gradInputTH);
+
+  int inputDim = THCTensor_(nDimension)(state, inputTH);
+
+  if (batchMode) {
+    THArgCheck(inputDim >= 2 && inputDim <= 4, 2,
+               "input must be 2-4 dimensions for batch mode");
+  } else {
+    THArgCheck(inputDim >= 1 && inputDim <= 3, 2,
+               "input must be 1-3 dimensions for non-batch mode");
+  }
+
+  THCDeviceTensor<TensorUtils<THCTensor>::DataType, 4> gradOutput;
+  THCDeviceTensor<TensorUtils<THCTensor>::DataType, 4> input;
+  THCDeviceTensor<TensorUtils<THCTensor>::DataType, 4> output;
+  THCDeviceTensor<TensorUtils<THCTensor>::DataType, 4> gradInput;
+
+  input = THNN_(FeatureLPPooling_upcast)(state, inputTH, batchMode);
+
+  // Make sure the feature dimension is properly sized
+  THArgCheck(input.getSize(1) >= width, 3,
+             "input: feature dimension must be >= width");
+
+  // Make sure that width and stride are within range
+  THArgCheck(width >= 2 && width <= 16, 7,
+             "width must be between 2 - 16");
+
+  THArgCheck(stride >= 1 && stride <= 4, 8,
+             "stride must be between 1 - 4");
+
+  gradOutput = THNN_(FeatureLPPooling_upcast)(state, gradOutputTH, batchMode);
+  output = THNN_(FeatureLPPooling_upcast)(state, outputTH, batchMode);
+
+  for (int i = 0; i < 4; ++i) {
+    THAssertMsg(output.getSize(i) == gradOutput.getSize(i),
+                "output and gradOutput sizes do not match");
+  }
+
+  // Make sure that the input sizes produce the output sizes
+  THArgCheck(lpPoolingOutputSize(input.getSize(1), width, stride) ==
+             output.getSize(1), 3,
+             "input and output sizes do not match with respect to "
+             "width and stride");
+
+  // Resize `gradInput` based on `input`
+  THNN_(FeatureLPPooling_resize)(state, gradInputTH, inputTH);
+  gradInput = THNN_(FeatureLPPooling_upcast)(state, gradInputTH, batchMode);
+
+  bool found = runFeatureLPPoolingUpdateGradInput(state,
+                                                  gradOutput,
+                                                  input,
+                                                  output,
+                                                  gradInput,
+                                                  power,
+                                                  width,
+                                                  stride);
+  THAssert(found);
+}
+
+#endif
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index e770dff..9692094 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -123,6 +123,26 @@ TH_API void THNN_(ELU_updateGradInput)(
                   accreal alpha,
                   bool inplace);
 
+TH_API void THNN_(FeatureLPPooling_updateOutput)(
+                  THCState* state,
+                  THCTensor* inputTH,
+                  THCTensor* outputTH,
+                  accreal power,
+                  int width,
+                  int stride,
+                  bool batchMode);
+
+TH_API void THNN_(FeatureLPPooling_updateGradInput)(
+                  THCState* state,
+                  THCTensor* gradOutputTH,
+                  THCTensor* inputTH,
+                  THCTensor* outputTH,
+                  THCTensor* gradInputTH,
+                  accreal power,
+                  int width,
+                  int stride,
+                  bool batchMode);
+
 TH_API void THNN_(HardTanh_updateOutput)(
                   THCState *state,
                   THCTensor *input,
diff --git a/test.lua b/test.lua
@@ -5281,6 +5281,149 @@ function cunntest.VolumetricAveragePooling_backward()
       end
    end
 end
 
+function cunntest.FeatureLPPooling_forward()
+   for tries = 1, 5 do
+      local batch_mode = {true, false}
+      batch_mode = batch_mode[math.random(1, 2)]
+      local power = {2, 3}
+      power = power[math.random(1, 2)]
+
+      local dims = math.random(1, 3)
+
+      if batch_mode then
+         dims = dims + 1
+      end
+
+      local width = torch.random(2, 16)
+      local stride = torch.random(1, 4)
+
+      local output_size = torch.random(1, 100)
+      local input_size = (output_size - 1) * stride + width
+
+      local baseInput = nil
+      if dims == 1 then
+         baseInput = torch.Tensor(input_size):uniform()
+      elseif dims == 2 then
+         if batch_mode then
+            baseInput = torch.Tensor(math.random(1, 5), input_size):uniform()
+         else
+            baseInput = torch.Tensor(input_size, math.random(1, 5)):uniform()
+         end
+      elseif dims == 3 then
+         if batch_mode then
+            baseInput = torch.Tensor(math.random(1, 5), input_size,
+                                     math.random(1, 5)):uniform()
+         else
+            baseInput = torch.Tensor(input_size, math.random(1, 5),
+                                     math.random(1, 5)):uniform()
+         end
+      else
+         baseInput = torch.Tensor(math.random(1, 5), input_size,
+                                  math.random(1, 5), math.random(1, 5)):uniform()
+      end
+
+      for k, typename in ipairs(typenames) do
+         local input = baseInput:type(typename)
+
+         local ctype = t2cpu[typename]
+         input = makeNonContiguous(input:type(ctype))
+         local sconv = nn.FeatureLPPooling(width, stride, power, batch_mode):type(ctype)
+         local groundtruth = sconv:forward(input)
+
+         input = makeNonContiguous(input:type(typename))
+         local gconv = nn.FeatureLPPooling(width, stride, power, batch_mode):type(typename)
+         local rescuda = gconv:forward(input)
+
+         local error = rescuda:double() - groundtruth:double()
+         mytester:assertlt(error:abs():max(),
+                           precision_forward_type(precision_forward, typename),
+                           string.format('error on state (forward) with %s', typename))
+      end
+   end
+end
+
+function cunntest.FeatureLPPooling_backward()
+   for tries = 1, 5 do
+      local batch_mode = {true, false}
+      batch_mode = batch_mode[math.random(1, 2)]
+      local power = {2, 3}
+      power = power[math.random(1, 2)]
+
+      local dims = math.random(1, 3)
+
+      if batch_mode then
+         dims = dims + 1
+      end
+
+      local width = torch.random(2, 16)
+      local stride = torch.random(1, 4)
+
+      local output_size = torch.random(1, 100)
+      local input_size = (output_size - 1) * stride + width
+
+      local baseInput = nil
+      local baseGradOutput = nil
+
+      if dims == 1 then
+         baseInput = torch.Tensor(input_size):uniform()
+         baseGradOutput = torch.Tensor(output_size):uniform()
+      elseif dims == 2 then
+         local a = math.random(1, 5)
+         if batch_mode then
+            baseInput = torch.Tensor(a, input_size):uniform()
+            baseGradOutput = torch.Tensor(a, output_size):uniform()
+         else
+            baseInput = torch.Tensor(input_size, a):uniform()
+            baseGradOutput = torch.Tensor(output_size, a):uniform()
+         end
+      elseif dims == 3 then
+         local a = math.random(1, 5)
+         local b = math.random(1, 5)
+         if batch_mode then
+            baseInput = torch.Tensor(a, input_size, b):uniform()
+            baseGradOutput = torch.Tensor(a, output_size, b):uniform()
+         else
+            baseInput = torch.Tensor(input_size, a, b):uniform()
+            baseGradOutput = torch.Tensor(output_size, a, b):uniform()
+         end
+      else
+         local a = math.random(1, 5)
+         local b = math.random(1, 5)
+         local c = math.random(1, 5)
+         baseInput = torch.Tensor(a, input_size, b, c):uniform()
+         baseGradOutput = torch.Tensor(a, output_size, b, c):uniform()
+      end
+
+      for k, typename in ipairs(typenames) do
+         local input = baseInput:type(typename)
+         local gradOutput = baseGradOutput:type(typename)
+         local ctype = t2cpu[typename]
+         input = makeNonContiguous(input:type(ctype))
+         gradOutput = makeNonContiguous(gradOutput:type(ctype))
+
+         local sconv = nn.FeatureLPPooling(width, stride, power, batch_mode):type(ctype)
+         sconv:forward(input)
+         sconv:zeroGradParameters()
+         local groundgrad = sconv:backward(input, gradOutput)
+
+         input = makeNonContiguous(input:type(typename))
+         gradOutput = makeNonContiguous(gradOutput:type(typename))
+         local gconv = nn.FeatureLPPooling(width, stride, power, batch_mode):type(typename)
+
+         gconv:forward(input)
+         gconv:zeroGradParameters()
+         local rescuda = gconv:backward(input, gradOutput)
+
+         local error = rescuda:double() - groundgrad:double()
+
+         mytester:assertlt(error:abs():max(),
+                           precision_backward_type(precision_backward, typename),
+                           string.format('error on state (backward) with %s', typename))
+      end
+   end
+end
+
 function cunntest.CMul_forward_batch()
    local bs = math.random(8,32)
    local nini = math.random(1,100)
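Aside (editorial, not part of the commit): the backward formula derived in the kernel comments can also be checked against a simple host-side reference over a single flattened feature axis. The helper name below is hypothetical:

#include <cmath>
#include <cstddef>
#include <vector>

// Reference gradient for feature L_p pooling on one flattened feature axis:
// accumulates dL/dx_i += dL/dy_o * (x_i / y_o)^(p - 1) over every pooling
// window containing x_i, mirroring the kernel's guard that skips windows
// whose output is zero.
std::vector<float> featureLpPoolBackwardRef(const std::vector<float>& x,
                                            const std::vector<float>& y,
                                            const std::vector<float>& gradY,
                                            int width, int stride, float p) {
  std::vector<float> gradX(x.size(), 0.f);
  for (std::size_t o = 0; o < y.size(); ++o) {
    if (y[o] == 0.f) continue;  // same zero-output guard as the kernel
    for (int i = 0; i < width; ++i) {
      std::size_t idx = o * stride + i;
      gradX[idx] += gradY[o] * std::pow(x[idx] / y[o], p - 1.f);
    }
  }
  return gradX;
}

Unlike the CUDA kernel, this reference needs no atomics: the outer loop is serial, so overlapping windows simply accumulate in order.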