github.com/torch/cunn.git
author    Soumith Chintala <soumith@gmail.com>  2017-07-26 00:01:59 +0300
committer GitHub <noreply@github.com>  2017-07-26 00:01:59 +0300
commit    8d9e9562beb792e17c6614ec2b515094f9663776 (patch)
tree      7e528a1e2674254b1f22cba2939c5b5d4818af80
parent    b336dc940c513d0b42d2ef2940bec9199b4377cf (diff)
parent    2c84b988883c98b5617dd24d06b633951afa0de6 (diff)

Merge pull request #477 from wickedfoo/feature_lp_pooling

GPU implementation of L_p feature pooling
-rw-r--r--  lib/THCUNN/FeatureLPPooling.cu          653
-rw-r--r--  lib/THCUNN/generic/FeatureLPPooling.cu  267
-rw-r--r--  lib/THCUNN/generic/THCUNN.h              20
-rw-r--r--  test.lua                                 143
4 files changed, 1083 insertions(+), 0 deletions(-)
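
For context, a minimal usage sketch of the module these kernels back. This is an illustration under stated assumptions: the nn.FeatureLPPooling Lua module itself lives in the companion nn package; its constructor signature nn.FeatureLPPooling(width, stride, power, batchMode) is the one exercised by test.lua below.

    require 'nn'
    require 'cunn'
    -- width = 3, stride = 2, power = 2 (L2 pooling), non-batch mode
    local pool = nn.FeatureLPPooling(3, 2, 2, false):cuda()
    -- valid-only pooling: 11 input features -> ((11 - 3) / 2) + 1 = 5 outputs
    local input = torch.CudaTensor(11):uniform()
    local output = pool:forward(input)
    -- gradient w.r.t. input, using an all-ones output gradient
    local gradInput = pool:backward(input, output:clone():fill(1))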
diff --git a/lib/THCUNN/FeatureLPPooling.cu b/lib/THCUNN/FeatureLPPooling.cu
new file mode 100644
index 0000000..4ad190f
--- /dev/null
+++ b/lib/THCUNN/FeatureLPPooling.cu
@@ -0,0 +1,653 @@
+#include "THCUNN.h"
+#include "THCAtomics.cuh"
+#include "THCDeviceTensor.cuh"
+#include "THCDeviceTensorUtils.cuh"
+#include "THCDeviceUtils.cuh"
+#include "THCNumerics.cuh"
+#include "THCTensorTypeUtils.cuh"
+
+#define OUTPUT_FEATURES_PER_THREAD 32
+#define MAX_WARPS_PER_RUN 4
+
+namespace detail {
+
+/// Various utilities for dealing with arrays of values which are
+/// maintained in thread-local registers. All accesses are done in such
+/// a way that the index is statically known, which preserves the
+/// compiler's ability to allocate the values to registers, as opposed
+/// to local memory.
+template <typename T, int N>
+struct RegisterUtils {
+ /// Register shifting: move elements towards the beginning of the
+ /// array (towards 0) by `Shift` places:
+ /// arr[i] = arr[i + Shift]
+ /// The `Shift` elements at the end are left unchanged.
+ template <int Shift>
+ __device__ __forceinline__ static void shiftLeft(T arr[N]) {
+ // e.g., N = 5, Shift = 2:
+ // 0 1 2 3 4 becomes =>
+    // 2 3 4 3 4 (last two are unchanged)
+#pragma unroll
+ for (int i = 0; i < N - Shift; ++i) {
+ arr[i] = arr[i + Shift];
+ }
+ }
+};
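+// e.g., with T in[5], RegisterUtils<T, 5>::shiftLeft<2>(in) advances the
+// window by two elements; the kernels below use this to reuse the
+// Width - Stride already-loaded inputs between adjacent output features.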
+
+template <typename T>
+__device__ __forceinline__
+int getDim1Point(const THCDeviceTensor<T, 4>& input) {
+ int threadPoint = blockIdx.x * blockDim.x + threadIdx.x;
+ return threadPoint / input.getSize(3);
+}
+
+template <typename T>
+__device__ __forceinline__
+int getDim2Point(const THCDeviceTensor<T, 4>& input) {
+ int threadPoint = blockIdx.x * blockDim.x + threadIdx.x;
+ return threadPoint % input.getSize(3);
+}
+
+__device__ __forceinline__
+int getStartOutputFeature() {
+ return blockIdx.y * OUTPUT_FEATURES_PER_THREAD;
+}
+
+template <typename T>
+__device__ __forceinline__
+int getEndOutputFeature(const THCDeviceTensor<T, 4>& output) {
+ return min((blockIdx.y + 1) * OUTPUT_FEATURES_PER_THREAD, output.getSize(1));
+}
+
+__device__ __forceinline__
+int getBatch() {
+ return blockIdx.z;
+}
+
+// The functions that follow are MathOps; they are passed as template
+// parameters so that the L2 case can be implemented more efficiently.
+// Conceptually, each MathOp has the signature:
+// template <typename T>
+// typedef T (*MathOp)(const T in, const T arg);
+
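+// In the 2-specialized ops below, `power` is unused; the parameter
+// exists only so that every MathOp shares the (in, power) signature.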
+template <typename T>
+__device__ __forceinline__ T power2(const T in, const T power) {
+ return THCNumerics<T>::mul(in, in);
+}
+
+template <typename T>
+__device__ __forceinline__ T root2(const T in, const T power) {
+ return THCNumerics<T>::sqrt(in);
+}
+
+template <typename T>
+__device__ __forceinline__ T powerGrad2(const T in, const T power) {
+ return in;
+}
+
+template <typename T>
+__device__ __forceinline__ T powerN(const T in, const T power) {
+ return THCNumerics<T>::pow(in, power);
+}
+
+template <typename T>
+__device__ __forceinline__ T rootN(const T in, const T power) {
+ const T invPower = THCNumerics<T>::cinv(power);
+ return THCNumerics<T>::pow(in, invPower);
+}
+
+template <typename T>
+__device__ __forceinline__ T powerGradN(const T in, const T power) {
+ return THCNumerics<T>::pow(in,
+ THCNumerics<T>::sub(power,
+ ScalarConvert<int, T>::to(1)));
+}
+
+// Input is of the form:
+// [batch][feature dim][optional dim 1][optional dim 2]
+template <typename T,
+ int Width,
+ int Stride,
+ T (*PowerFunc)(T in, T power),
+ T (*RootFunc)(T in, T power)>
+__global__ void
+featureLPPoolingUpdateOutput(const THCDeviceTensor<T, 4> input,
+ THCDeviceTensor<T, 4> output,
+ T power) {
+  // What non-feature point is this thread handling?
+ int dim1Point = getDim1Point(input);
+ int dim2Point = getDim2Point(input);
+
+ if (dim1Point >= input.getSize(2) || dim2Point >= input.getSize(3)) {
+ // This thread in the warp is out of bounds
+ return;
+ }
+
+  // What range of feature points is this thread handling?
+ int startOutputFeature = getStartOutputFeature();
+ int endOutputFeature = getEndOutputFeature(output);
+ int startInputFeature = startOutputFeature * Stride;
+
+  // What batch point is this thread handling?
+ int batch = getBatch();
+
+ // If stride >= width, then there is no loaded data reuse.
+ // If stride > 1 and stride < width, then shift by stride, since we
+ // can reuse Width - Stride elements from the previous round.
+  // e.g., width = 5, stride = 2,
+  // output 0 uses input 0 1 2 3 4
+  // output 1 uses input 2 3 4 5 6 (inputs 2 - 4 are reused, i.e.,
+  // 5 - 2 = 3 elements are reused, and we have to shift the array by 2)
+  //
+  // e.g., width = 5, stride = 3,
+  // output 0 uses input 0 1 2 3 4
+  // output 1 uses input 3 4 5 6 7 (inputs 3 - 4 are reused, i.e.,
+  // 5 - 3 = 2 elements are reused, and we have to shift the array by 3)
+
+  // Valid-only pooling: each output needs Width input elements. Width -
+  // Stride of them are preloaded here; the remaining Stride elements are
+  // loaded at the top of each loop iteration. We already verified that
+  // the input is at least as large as the width.
+ // `in` will contain the input values ^ power.
+ T in[Width];
+
+#pragma unroll
+ for (int i = 0; i < Width - Stride; ++i) {
+ const T data =
+ input[batch][startInputFeature + i][dim1Point][dim2Point];
+ in[i] = PowerFunc(data, power);
+ }
+
+ for (int outputFeature = startOutputFeature;
+ outputFeature < endOutputFeature;
+ ++outputFeature) {
+ // If Stride < Width, we're loading Stride new values starting at
+ // Width - Stride
+ // If Stride >= Width, we're loading Width new values starting at 0
+ if (Stride < Width) {
+ int nextInputFeature = outputFeature * Stride + Width - Stride;
+
+#pragma unroll
+ for (int i = 0; i < Stride; ++i) {
+ const T data =
+ input[batch][nextInputFeature + i][dim1Point][dim2Point];
+ in[Width - Stride + i] = PowerFunc(data, power);
+ }
+ } else {
+ int nextInputFeature = outputFeature * Stride;
+
+#pragma unroll
+ for (int i = 0; i < Width; ++i) {
+ T data = input[batch][nextInputFeature + i][dim1Point][dim2Point];
+ in[i] = PowerFunc(data, power);
+ }
+ }
+
+ // Calculate the new output feature
+ T val = ScalarConvert<int, T>::to(0);
+ for (int i = 0; i < Width; ++i) {
+ val = THCNumerics<T>::add(val, in[i]);
+ }
+
+ val = RootFunc(val, power);
+ output[batch][outputFeature][dim1Point][dim2Point] = val;
+
+ if (Stride < Width) {
+ // Shift registers for calculating the next point
+ RegisterUtils<T, Width>::shiftLeft<Stride>(in);
+ }
+ }
+}
+
+// forward pass: f(a, ..., z) = (a^p + ... + z^p)^(1 / p)
+// for bprop:
+// partial df(a, ... z)/da = a^(p - 1) * (a^p + ... + z^p)^((1 / p) - 1) =
+// a^(p - 1) * 1/(f(a, ..., z)^(p - 1)) = (a / f(a, ..., z))^(p - 1)
+//
+// example: for p = 2, df(a, ..., z)/da = a / f(a, ..., z)
+// example: for p = 3, df(a, ..., z)/da = (a / f(a, ..., z))^2
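+// numeric check, p = 2 with two inputs a = 3, z = 4:
+//   f = sqrt(3^2 + 4^2) = 5, so df/da = a / f = 3 / 5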
+//
+// PowerGradFunc implements x^(p - 1)
+template <typename T,
+ int Width,
+ int Stride,
+ T (*PowerGradFunc)(T in, T arg)>
+__global__ void
+featureLPPoolingUpdateGradInput(const THCDeviceTensor<T, 4> gradOutput,
+ const THCDeviceTensor<T, 4> input,
+ const THCDeviceTensor<T, 4> output,
+ THCDeviceTensor<T, 4> gradInput,
+ T power) {
+  // What non-feature point is this thread handling?
+ int dim1Point = getDim1Point(input);
+ int dim2Point = getDim2Point(input);
+
+ if (dim1Point >= input.getSize(2) || dim2Point >= input.getSize(3)) {
+ // This thread in the warp is out of bounds
+ return;
+ }
+
+  // What range of feature points is this thread handling? [start, end)
+ int startOutputFeature = getStartOutputFeature();
+ int endOutputFeature = getEndOutputFeature(output);
+
+ // What is the first input point that the output features depend
+ // upon? [start, end)
+ int startInputFeature = startOutputFeature * Stride;
+ int endInputFeature = endOutputFeature * Stride;
+
+  // What batch point is this thread handling?
+ int batch = getBatch();
+
+  // atomicAdd into gradInput is slow, so we avoid it where possible.
+  // We can do this because there is a range of gradInput elements
+  // that we are updating exclusively. This is how we find that range:
+ //
+ // width = 3 stride = 1 example:
+ // ------------------------------
+ // startOutputFeature for this thread
+ // |
+ // |
+ // previous thread's output feature
+ // | |
+ // | | gradOutput
+ // __v____v___________________
+ // | | | | | |
+ // ---------------------------
+ // |\ \_____
+ // | \__ \ gradInput
+ // __v____v____v_____________
+ // | | | | | |
+ // ---------------------------
+ // A A
+ // | |
+ // startInputFeature
+ // |
+ // exclusiveStartInputFeature
+ //
+ // exclusiveStartInputFeature is the first input feature that we can
+ // write into exclusively; the one right before it overlaps with
+ // updates from a previous thread and thus has to use atomicAdd.
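+  // For the width = 3, stride = 1 picture above, that is
+  // startInputFeature + (3 - 1) * 1, i.e., two features past
+  // startInputFeature.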
+ int exclusiveStartInputFeature =
+ startInputFeature == 0 ?
+ // no thread is before ourselves
+ 0 :
+ // there is a thread before ourselves
+ startInputFeature + (Width - 1) * Stride;
+
+  // Similarly, exclusiveEndInputFeature is the end (exclusive) of the
+  // range of input features we can write into exclusively, since we
+  // might be overlapping with the following thread
+ int exclusiveEndInputFeature =
+ endOutputFeature == output.getSize(1) ?
+ // no thread is after ourselves
+ endInputFeature + (Width - 1) * Stride :
+ // there is a thread after ourselves
+ endInputFeature;
+
+  // As with updateOutput, preload the input elements, except there is
+  // no need to transform them
+ T in[Width];
+#pragma unroll
+ for (int i = 0; i < Width - Stride; ++i) {
+ in[i] = input[batch][startInputFeature + i][dim1Point][dim2Point];
+ }
+
+ for (int outputFeature = startOutputFeature;
+ outputFeature < endOutputFeature;
+ ++outputFeature) {
+    // As with updateOutput, load the subsequent input elements that we
+    // need, except there is no need to transform them
+ //
+ // If Stride < Width, we're loading Stride new values starting at
+ // Width - Stride
+ // If Stride >= Width, we're loading Width new values starting at 0
+ if (Stride < Width) {
+ int nextInputFeature = outputFeature * Stride + Width - Stride;
+
+#pragma unroll
+ for (int i = 0; i < Stride; ++i) {
+ in[Width - Stride + i] =
+ input[batch][nextInputFeature + i][dim1Point][dim2Point];
+ }
+ } else {
+ int nextInputFeature = outputFeature * Stride;
+
+#pragma unroll
+ for (int i = 0; i < Width; ++i) {
+ in[i] = input[batch][nextInputFeature + i][dim1Point][dim2Point];
+ }
+ }
+
+ // A given output feature gradient contributes to `Width` input
+ // gradients
+ const T gradOut =
+ gradOutput[batch][outputFeature][dim1Point][dim2Point];
+
+ // Load output (f(x_is)). It is possible that this is zero, in
+ // which case we'll ignore this point.
+ T out = output[batch][outputFeature][dim1Point][dim2Point];
+ if (THCNumerics<T>::eq(out, ScalarConvert<int, T>::to(0))) {
+ continue;
+ }
+
+ int curStartInputFeature = outputFeature * Stride;
+ int curEndInputFeature = outputFeature * Stride + Width - 1;
+
+ if (curStartInputFeature >= exclusiveStartInputFeature &&
+ curEndInputFeature < exclusiveEndInputFeature) {
+ // This thread is exclusively responsible for updating these
+ // input points, so we need not make the addition atomic
+ for (int i = 0; i < Width; ++i) {
+ int inputFeature = outputFeature * Stride + i;
+
+ // Calculate grad * (x_i / f(x_is))^(p - 1)
+ const T val = THCNumerics<T>::mul(
+ gradOut,
+ PowerGradFunc(THCNumerics<T>::div(in[i], out), power));
+
+ gradInput[batch][inputFeature][dim1Point][dim2Point] =
+ THCNumerics<T>::add(
+ gradInput[batch][inputFeature][dim1Point][dim2Point], val);
+ }
+ } else {
+ // Handle start and end boundary cases: potential overlap with
+ // other threads
+ for (int i = 0; i < Width; ++i) {
+ int inputFeature = outputFeature * Stride + i;
+
+ // Calculate grad * (x_i / f(x_is))^(p - 1)
+ T val = THCNumerics<T>::mul(
+ gradOut,
+ PowerGradFunc(THCNumerics<T>::div(in[i], out), power));
+
+ // We don't overlap other threads for this range
+ if (inputFeature >= exclusiveStartInputFeature &&
+ inputFeature < exclusiveEndInputFeature) {
+ gradInput[batch][inputFeature][dim1Point][dim2Point]
+ = THCNumerics<T>::add(
+ gradInput[batch][inputFeature][dim1Point][dim2Point], val);
+ } else {
+ // We are potentially overlapping with threads handling
+ // features before ourselves, so these need to be added atomically
+ atomicAdd(&gradInput[batch][inputFeature][dim1Point][dim2Point],
+ val);
+ }
+ }
+ }
+
+ if (Stride < Width) {
+ // Shift registers for calculating the next point
+ RegisterUtils<T, Width>::shiftLeft<Stride>(in);
+ }
+ }
+}
+
+} // namespace detail
+
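+// Valid-only pooling output size; e.g., inputSize = 10, width = 4,
+// stride = 2 gives ((10 - 4) / 2) + 1 = 4 output features.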
+inline int lpPoolingOutputSize(int inputSize, int width, int stride) {
+ return ((inputSize - width) / stride) + 1;
+}
+
+template <typename T>
+bool
+runFeatureLPPoolingUpdateOutput(THCState* state,
+ const THCDeviceTensor<T, 4>& input,
+ THCDeviceTensor<T, 4>& output,
+ float power, int width, int stride) {
+ cudaStream_t stream =
+ THCState_getCurrentStream(state);
+ const cudaDeviceProp* deviceProperties =
+ THCState_getCurrentDeviceProperties(state);
+
+ int outputFeatures = ((input.getSize(1) - width) / stride) + 1;
+
+ THAssert(input.getSize(0) == output.getSize(0));
+ THAssert(outputFeatures == output.getSize(1));
+ THAssert(input.getSize(1) >= width);
+
+ THAssert(input.getSize(2) == output.getSize(2));
+ THAssert(input.getSize(3) == output.getSize(3));
+ THAssert(power > 0.0f);
+ THAssert(width >= 1);
+ THAssert(stride >= 1);
+
+ // Split non-features among threads and grid x
+ int totalNonFeatureSize = input.getSize(2) * input.getSize(3);
+ int numWarps =
+ min(THCCeilDiv(totalNonFeatureSize, deviceProperties->warpSize),
+ MAX_WARPS_PER_RUN);
+ int blockSize = deviceProperties->warpSize * numWarps;
+
+ // Split non-features among grid x
+ int nonFeatureSizeBlocks = THCCeilDiv(totalNonFeatureSize, blockSize);
+
+ // Split features among grid y, up to a maximum number of features per thread
+ int featureBlocks = THCCeilDiv(outputFeatures, OUTPUT_FEATURES_PER_THREAD);
+
+ // Split batch among grid z.
+ dim3 grid(nonFeatureSizeBlocks, featureBlocks, input.getSize(0));
+ dim3 block(blockSize);
+
+#define L2_STRIDE_CASE(STRIDE, WIDTH) \
+ case STRIDE: \
+ detail:: \
+ featureLPPoolingUpdateOutput<T, WIDTH, \
+ STRIDE, \
+ detail::power2, \
+ detail::root2><<<grid, block, 0, stream>>>( \
+ input, output, \
+ ScalarConvert<float, T>::to(power)); \
+ return true;
+
+#define L2_WIDTH_CASE(WIDTH) \
+ case WIDTH: \
+ switch (stride) { \
+ L2_STRIDE_CASE(1, WIDTH); \
+ L2_STRIDE_CASE(2, WIDTH); \
+ L2_STRIDE_CASE(3, WIDTH); \
+ L2_STRIDE_CASE(4, WIDTH); \
+ }
+
+#define LP_STRIDE_CASE(STRIDE, WIDTH) \
+ case STRIDE: \
+ detail:: \
+ featureLPPoolingUpdateOutput<T, WIDTH, \
+ STRIDE, \
+ detail::powerN, \
+ detail::rootN><<<grid, block, 0, stream>>>( \
+ input, output, \
+ ScalarConvert<float, T>::to(power)); \
+ return true;
+
+#define LP_WIDTH_CASE(WIDTH) \
+ case WIDTH: \
+ switch (stride) { \
+ LP_STRIDE_CASE(1, WIDTH); \
+ LP_STRIDE_CASE(2, WIDTH); \
+ LP_STRIDE_CASE(3, WIDTH); \
+ LP_STRIDE_CASE(4, WIDTH); \
+ }
+
+ if (power == 2.0f) {
+ switch (width) {
+ L2_WIDTH_CASE(2);
+ L2_WIDTH_CASE(3);
+ L2_WIDTH_CASE(4);
+ L2_WIDTH_CASE(5);
+ L2_WIDTH_CASE(6);
+ L2_WIDTH_CASE(7);
+ L2_WIDTH_CASE(8);
+ L2_WIDTH_CASE(9);
+ L2_WIDTH_CASE(10);
+ L2_WIDTH_CASE(11);
+ L2_WIDTH_CASE(12);
+ L2_WIDTH_CASE(13);
+ L2_WIDTH_CASE(14);
+ L2_WIDTH_CASE(15);
+ L2_WIDTH_CASE(16);
+ }
+ } else {
+ switch (width) {
+ LP_WIDTH_CASE(2);
+ LP_WIDTH_CASE(3);
+ LP_WIDTH_CASE(4);
+ LP_WIDTH_CASE(5);
+ LP_WIDTH_CASE(6);
+ LP_WIDTH_CASE(7);
+ LP_WIDTH_CASE(8);
+ LP_WIDTH_CASE(9);
+ LP_WIDTH_CASE(10);
+ LP_WIDTH_CASE(11);
+ LP_WIDTH_CASE(12);
+ LP_WIDTH_CASE(13);
+ LP_WIDTH_CASE(14);
+ LP_WIDTH_CASE(15);
+ LP_WIDTH_CASE(16);
+ }
+ }
+
+ // Otherwise, we have an unhandled width and/or stride.
+ return false;
+
+#undef L2_STRIDE_CASE
+#undef L2_WIDTH_CASE
+#undef LP_STRIDE_CASE
+#undef LP_WIDTH_CASE
+}
+
+template <typename T>
+bool
+runFeatureLPPoolingUpdateGradInput(THCState* state,
+ const THCDeviceTensor<T, 4>& gradOutput,
+ const THCDeviceTensor<T, 4>& input,
+ const THCDeviceTensor<T, 4>& output,
+ THCDeviceTensor<T, 4>& gradInput,
+ float power, int width, int stride) {
+ cudaStream_t stream =
+ THCState_getCurrentStream(state);
+ const cudaDeviceProp* deviceProperties =
+ THCState_getCurrentDeviceProperties(state);
+
+ for (int i = 0; i < 4; ++i) {
+ THAssert(gradOutput.getSize(i) == output.getSize(i));
+ THAssert(gradInput.getSize(i) == input.getSize(i));
+ }
+
+ int outputFeatures = ((input.getSize(1) - width) / stride) + 1;
+
+ THAssert(gradInput.getSize(0) == gradOutput.getSize(0));
+ THAssert(outputFeatures == gradOutput.getSize(1));
+ THAssert(gradInput.getSize(1) >= width);
+
+ THAssert(gradInput.getSize(2) == gradOutput.getSize(2));
+ THAssert(gradInput.getSize(3) == gradOutput.getSize(3));
+ THAssert(power > 0.0f);
+ THAssert(width >= 1);
+ THAssert(stride >= 1);
+
+ // Different threads are potentially adding into overlapping input
+ // points, so we must clear out gradInput before continuing.
+ gradInput.zero(stream);
+
+ // Split non-features among threads and grid x
+ int totalNonFeatureSize = input.getSize(2) * input.getSize(3);
+ int numWarps =
+ min(THCCeilDiv(totalNonFeatureSize, deviceProperties->warpSize),
+ MAX_WARPS_PER_RUN);
+ int blockSize = deviceProperties->warpSize * numWarps;
+
+ // Split non-features among grid x
+ int nonFeatureSizeBlocks = THCCeilDiv(totalNonFeatureSize, blockSize);
+
+ // Split features among grid y, up to a maximum number of features per thread
+ int featureBlocks = THCCeilDiv(outputFeatures, OUTPUT_FEATURES_PER_THREAD);
+
+ // Split batch among grid z.
+ dim3 grid(nonFeatureSizeBlocks, featureBlocks, input.getSize(0));
+ dim3 block(blockSize);
+
+#define L2_STRIDE_CASE(STRIDE, WIDTH) \
+ case STRIDE: \
+ detail:: \
+ featureLPPoolingUpdateGradInput< \
+ T, WIDTH, STRIDE, detail::powerGrad2><<<grid, block, 0, stream>>>( \
+ gradOutput, input, output, gradInput, \
+ ScalarConvert<float, T>::to(power)); \
+ return true;
+
+#define L2_WIDTH_CASE(WIDTH) \
+ case WIDTH: \
+ switch (stride) { \
+ L2_STRIDE_CASE(1, WIDTH); \
+ L2_STRIDE_CASE(2, WIDTH); \
+ L2_STRIDE_CASE(3, WIDTH); \
+ L2_STRIDE_CASE(4, WIDTH); \
+ }
+
+#define LP_STRIDE_CASE(STRIDE, WIDTH) \
+ case STRIDE: \
+ detail:: \
+ featureLPPoolingUpdateGradInput< \
+ T, WIDTH, STRIDE, detail::powerGradN><<<grid, block, 0, stream>>>( \
+ gradOutput, input, output, gradInput, \
+ ScalarConvert<float, T>::to(power)); \
+ return true;
+
+#define LP_WIDTH_CASE(WIDTH) \
+ case WIDTH: \
+ switch (stride) { \
+ LP_STRIDE_CASE(1, WIDTH); \
+ LP_STRIDE_CASE(2, WIDTH); \
+ LP_STRIDE_CASE(3, WIDTH); \
+ LP_STRIDE_CASE(4, WIDTH); \
+ }
+
+ if (power == 2.0f) {
+ switch (width) {
+ L2_WIDTH_CASE(2);
+ L2_WIDTH_CASE(3);
+ L2_WIDTH_CASE(4);
+ L2_WIDTH_CASE(5);
+ L2_WIDTH_CASE(6);
+ L2_WIDTH_CASE(7);
+ L2_WIDTH_CASE(8);
+ L2_WIDTH_CASE(9);
+ L2_WIDTH_CASE(10);
+ L2_WIDTH_CASE(11);
+ L2_WIDTH_CASE(12);
+ L2_WIDTH_CASE(13);
+ L2_WIDTH_CASE(14);
+ L2_WIDTH_CASE(15);
+ L2_WIDTH_CASE(16);
+ }
+ } else {
+ switch (width) {
+ LP_WIDTH_CASE(2);
+ LP_WIDTH_CASE(3);
+ LP_WIDTH_CASE(4);
+ LP_WIDTH_CASE(5);
+ LP_WIDTH_CASE(6);
+ LP_WIDTH_CASE(7);
+ LP_WIDTH_CASE(8);
+ LP_WIDTH_CASE(9);
+ LP_WIDTH_CASE(10);
+ LP_WIDTH_CASE(11);
+ LP_WIDTH_CASE(12);
+ LP_WIDTH_CASE(13);
+ LP_WIDTH_CASE(14);
+ LP_WIDTH_CASE(15);
+ LP_WIDTH_CASE(16);
+ }
+ }
+
+ // Otherwise, we have an unhandled width and/or stride.
+ return false;
+
+#undef L2_STRIDE_CASE
+#undef L2_WIDTH_CASE
+#undef LP_STRIDE_CASE
+#undef LP_WIDTH_CASE
+}
+
+#include "generic/FeatureLPPooling.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/lib/THCUNN/generic/FeatureLPPooling.cu b/lib/THCUNN/generic/FeatureLPPooling.cu
new file mode 100644
index 0000000..9300450
--- /dev/null
+++ b/lib/THCUNN/generic/FeatureLPPooling.cu
@@ -0,0 +1,267 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/FeatureLPPooling.cu"
+#else
+
+#include "../common.h"
+
+// non-batch mode:
+// [feature dim]
+// [feature dim][opt dim 1]
+// [feature dim][opt dim 1][opt dim 2]
+//
+// batch mode:
+// [batch dim][feature dim]
+// [batch dim][feature dim][opt dim 1]
+// [batch dim][feature dim][opt dim 1][opt dim 2]
+THCDeviceTensor<real, 4>
+THNN_(FeatureLPPooling_upcast)(THCState* state, THCTensor* t, bool batchMode) {
+ int inputDim = THCTensor_(nDimension)(state, t);
+
+ if (inputDim == 1) {
+ // [feature dim]
+ return toDeviceTensor<real, 1>(state, t).
+ upcastOuter<2>().upcastInner<4>();
+ } else if (inputDim == 2) {
+ if (batchMode) {
+ // [batch dim][feature dim]
+ return toDeviceTensor<real, 2>(state, t).
+ upcastInner<4>();
+ } else {
+ // [feature dim][opt dim 1]
+ return toDeviceTensor<real, 2>(state, t).
+ upcastOuter<3>().upcastInner<4>();
+ }
+ } else if (inputDim == 3) {
+ if (batchMode) {
+ // [batch dim][feature dim][opt dim 1]
+ return toDeviceTensor<real, 3>(state, t).
+ upcastInner<4>();
+ } else {
+ // [feature dim][opt dim 1][opt dim 2]
+ return toDeviceTensor<real, 3>(state, t).
+ upcastOuter<4>();
+ }
+ } else {
+ // inputDim == 4
+ // [batch dim][feature dim][opt dim 1][opt dim 2]
+ THAssert(batchMode);
+ return toDeviceTensor<real, 4>(state, t);
+ }
+}
+
+// Resizes `toResize` based on the output size for `input` as an input
+// tensor
+void
+THNN_(FeatureLPPooling_resizeForOutput)(THCState* state,
+ THCTensor* toResize,
+ THCTensor* input,
+ bool batchMode,
+ int width,
+ int stride) {
+ int inputDim = THCTensor_(nDimension)(state, input);
+ THAssert(inputDim >= 1 && inputDim <= 4);
+
+ long outSize =
+ lpPoolingOutputSize(THCTensor_(size)(state, input, 0), width, stride);
+ if (batchMode) {
+ THAssert(inputDim > 1);
+ outSize =
+ lpPoolingOutputSize(THCTensor_(size)(state, input, 1), width, stride);
+ } else {
+ THAssert(inputDim < 4);
+ }
+
+ if (inputDim == 1) {
+ THCTensor_(resize1d)(state, toResize, outSize);
+ } else if (inputDim == 2) {
+ if (batchMode) {
+ THCTensor_(resize2d)(
+ state, toResize, THCTensor_(size)(state, input, 0), outSize);
+ } else {
+ THCTensor_(resize2d)(
+ state, toResize, outSize, THCTensor_(size)(state, input, 1));
+ }
+ } else if (inputDim == 3) {
+ if (batchMode) {
+ THCTensor_(resize3d)(
+ state,
+ toResize,
+ THCTensor_(size)(state, input, 0), outSize,
+ THCTensor_(size)(state, input, 2));
+ } else {
+ THCTensor_(resize3d)(
+ state,
+ toResize,
+ outSize, THCTensor_(size)(state, input, 1),
+ THCTensor_(size)(state, input, 2));
+ }
+ } else if (inputDim == 4) {
+ THCTensor_(resize4d)(
+ state,
+ toResize,
+ THCTensor_(size)(state, input, 0), outSize,
+ THCTensor_(size)(state, input, 2), THCTensor_(size)(state, input, 3));
+ }
+}
+
+// Makes `toResize` the same size/dimensionality as `src`
+void
+THNN_(FeatureLPPooling_resize)(THCState* state,
+ THCTensor* toResize,
+ THCTensor* src) {
+ int inputDim = THCTensor_(nDimension)(state, src);
+ THAssert(inputDim >= 1 && inputDim <= 4);
+
+ if (inputDim == 1) {
+ THCTensor_(resize1d)(state,
+ toResize,
+ THCTensor_(size)(state, src, 0));
+ } else if (inputDim == 2) {
+ THCTensor_(resize2d)(
+ state,
+ toResize,
+ THCTensor_(size)(state, src, 0),
+ THCTensor_(size)(state, src, 1));
+ } else if (inputDim == 3) {
+ THCTensor_(resize3d)(
+ state,
+ toResize,
+ THCTensor_(size)(state, src, 0),
+ THCTensor_(size)(state, src, 1),
+ THCTensor_(size)(state, src, 2));
+ } else if (inputDim == 4) {
+ THCTensor_(resize4d)(
+ state,
+ toResize,
+ THCTensor_(size)(state, src, 0),
+ THCTensor_(size)(state, src, 1),
+ THCTensor_(size)(state, src, 2),
+ THCTensor_(size)(state, src, 3));
+ }
+}
+
+void THNN_(FeatureLPPooling_updateOutput)(THCState* state,
+ THCTensor* inputTH,
+ THCTensor* outputTH,
+ accreal power,
+ int width,
+ int stride,
+ bool batchMode) {
+ THCUNN_assertSameGPU(state, 2, inputTH, outputTH);
+
+ int inputDim = THCTensor_(nDimension)(state, inputTH);
+
+ if (batchMode) {
+ THArgCheck(inputDim >= 2 && inputDim <= 4, 2,
+ "input must be 2-4 dimensions for batch mode");
+ } else {
+ THArgCheck(inputDim >= 1 && inputDim <= 3, 2,
+ "input must be 1-3 dimensions for non-batch mode");
+ }
+
+ THArgCheck(TensorUtils<THCTensor>::canUse32BitIndexMath(state, inputTH), 2,
+ "input tensor must fit into 32-bit index math");
+
+ THCDeviceTensor<TensorUtils<THCTensor>::DataType, 4> input;
+ THCDeviceTensor<TensorUtils<THCTensor>::DataType, 4> output;
+
+ input = THNN_(FeatureLPPooling_upcast)(state, inputTH, batchMode);
+
+ // Make sure the feature dimension is properly sized
+ THArgCheck(input.getSize(1) >= width, 2,
+ "input: feature dimension must be >= width");
+
+ // Make sure that width and stride are within range
+  THArgCheck(width >= 2 && width <= 16, 5,
+             "width must be between 2 and 16");
+
+  THArgCheck(stride >= 1 && stride <= 4, 6,
+             "stride must be between 1 and 4");
+
+ THNN_(FeatureLPPooling_resizeForOutput)(
+ state, outputTH, inputTH, batchMode, width, stride);
+
+ output = THNN_(FeatureLPPooling_upcast)(state, outputTH, batchMode);
+
+ bool found = runFeatureLPPoolingUpdateOutput(state,
+ input,
+ output,
+ power,
+ width,
+ stride);
+ THAssert(found);
+}
+
+void THNN_(FeatureLPPooling_updateGradInput)(THCState* state,
+ THCTensor* gradOutputTH,
+ THCTensor* inputTH,
+ THCTensor* outputTH,
+ THCTensor* gradInputTH,
+ accreal power,
+ int width,
+ int stride,
+ bool batchMode) {
+ THArgCheck(TensorUtils<THCTensor>::canUse32BitIndexMath(state, gradOutputTH), 2,
+ "output gradient tensor must fit into 32-bit index math");
+ THArgCheck(TensorUtils<THCTensor>::canUse32BitIndexMath(state, inputTH), 3,
+ "input tensor must fit into 32-bit index math");
+ THCUNN_assertSameGPU(state, 4, gradOutputTH, inputTH, outputTH, gradInputTH);
+
+ int inputDim = THCTensor_(nDimension)(state, inputTH);
+
+ if (batchMode) {
+ THArgCheck(inputDim >= 2 && inputDim <= 4, 2,
+ "input must be 2-4 dimensions for batch mode");
+ } else {
+ THArgCheck(inputDim >= 1 && inputDim <= 3, 2,
+ "input must be 1-3 dimensions for non-batch mode");
+ }
+
+ THCDeviceTensor<TensorUtils<THCTensor>::DataType, 4> gradOutput;
+ THCDeviceTensor<TensorUtils<THCTensor>::DataType, 4> input;
+ THCDeviceTensor<TensorUtils<THCTensor>::DataType, 4> output;
+ THCDeviceTensor<TensorUtils<THCTensor>::DataType, 4> gradInput;
+
+ input = THNN_(FeatureLPPooling_upcast)(state, inputTH, batchMode);
+
+ // Make sure the feature dimension is properly sized
+ THArgCheck(input.getSize(1) >= width, 3,
+ "input: feature dimension must be >= width");
+
+ // Make sure that width and stride are within range
+  THArgCheck(width >= 2 && width <= 16, 7,
+             "width must be between 2 and 16");
+
+  THArgCheck(stride >= 1 && stride <= 4, 8,
+             "stride must be between 1 and 4");
+
+ gradOutput = THNN_(FeatureLPPooling_upcast)(state, gradOutputTH, batchMode);
+ output = THNN_(FeatureLPPooling_upcast)(state, outputTH, batchMode);
+
+ for (int i = 0; i < 4; ++i) {
+ THAssertMsg(output.getSize(i) == gradOutput.getSize(i),
+ "output and gradOutput sizes do not match");
+ }
+
+ // Make sure that the input sizes produce the output sizes
+ THArgCheck(lpPoolingOutputSize(input.getSize(1), width, stride) ==
+ output.getSize(1), 3,
+ "input and output sizes do not match with respect to "
+ "width and stride");
+
+ // Resize `gradInput` based on `input`
+ THNN_(FeatureLPPooling_resize)(state, gradInputTH, inputTH);
+ gradInput = THNN_(FeatureLPPooling_upcast)(state, gradInputTH, batchMode);
+
+ bool found = runFeatureLPPoolingUpdateGradInput(state,
+ gradOutput,
+ input,
+ output,
+ gradInput,
+ power,
+ width,
+ stride);
+ THAssert(found);
+}
+
+#endif
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index e770dff..9692094 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -123,6 +123,26 @@ TH_API void THNN_(ELU_updateGradInput)(
accreal alpha,
bool inplace);
+TH_API void THNN_(FeatureLPPooling_updateOutput)(
+ THCState* state,
+ THCTensor* inputTH,
+ THCTensor* outputTH,
+ accreal power,
+ int width,
+ int stride,
+ bool batchMode);
+
+TH_API void THNN_(FeatureLPPooling_updateGradInput)(
+ THCState* state,
+ THCTensor* gradOutputTH,
+ THCTensor* inputTH,
+ THCTensor* outputTH,
+ THCTensor* gradInputTH,
+ accreal power,
+ int width,
+ int stride,
+ bool batchMode);
+
TH_API void THNN_(HardTanh_updateOutput)(
THCState *state,
THCTensor *input,
diff --git a/test.lua b/test.lua
index 670c70e..fb65bd9 100644
--- a/test.lua
+++ b/test.lua
@@ -5281,6 +5281,149 @@ function cunntest.VolumetricAveragePooling_backward()
end
end
+function cunntest.FeatureLPPooling_forward()
+ for tries = 1, 5 do
+ local batch_mode = {true, false}
+ batch_mode = batch_mode[math.random(1, 2)]
+ local power = {2, 3}
+ power = power[math.random(1, 2)]
+
+ local dims = math.random(1, 3)
+
+ if batch_mode then
+ dims = dims + 1
+ end
+
+ local width = torch.random(2, 16)
+ local stride = torch.random(1, 4)
+
+ local output_size = torch.random(1, 100)
+ local input_size = (output_size - 1) * stride + width
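+    -- inverts the valid-only output size relation:
+    -- output_size = ((input_size - width) / stride) + 1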
+
+ local baseInput = nil
+ if dims == 1 then
+ baseInput = torch.Tensor(input_size):uniform()
+ elseif dims == 2 then
+ if batch_mode then
+ baseInput = torch.Tensor(math.random(1, 5), input_size):uniform()
+ else
+ baseInput = torch.Tensor(input_size, math.random(1, 5)):uniform()
+ end
+ elseif dims == 3 then
+ if batch_mode then
+ baseInput = torch.Tensor(math.random(1, 5), input_size,
+ math.random(1, 5)):uniform()
+ else
+ baseInput = torch.Tensor(input_size, math.random(1, 5),
+ math.random(1, 5)):uniform()
+ end
+ else
+ baseInput = torch.Tensor(math.random(1, 5), input_size,
+ math.random(1, 5), math.random(1, 5)):uniform()
+ end
+
+ for k, typename in ipairs(typenames) do
+ local input = baseInput:type(typename)
+
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ local sconv = nn.FeatureLPPooling(width, stride, power, batch_mode):type(ctype)
+ local groundtruth = sconv:forward(input)
+
+ input = makeNonContiguous(input:type(typename))
+ local gconv = nn.FeatureLPPooling(width, stride, power, batch_mode):type(typename)
+ local rescuda = gconv:forward(input)
+
+ local error = rescuda:double() - groundtruth:double()
+ mytester:assertlt(error:abs():max(),
+ precision_forward_type(precision_forward, typename),
+ string.format('error on state (forward) with %s', typename))
+ end
+ end
+end
+
+function cunntest.FeatureLPPooling_backward()
+ for tries = 1, 5 do
+ local batch_mode = {true, false}
+ batch_mode = batch_mode[math.random(1, 2)]
+ local power = {2, 3}
+ power = power[math.random(1, 2)]
+
+ local dims = math.random(1, 3)
+
+ if batch_mode then
+ dims = dims + 1
+ end
+
+ local width = torch.random(2, 16)
+ local stride = torch.random(1, 4)
+
+ local output_size = torch.random(1, 100)
+ local input_size = (output_size - 1) * stride + width
+
+ local baseInput = nil
+ local baseGradOutput = nil
+
+ if dims == 1 then
+ baseInput = torch.Tensor(input_size):uniform()
+ baseGradOutput = torch.Tensor(output_size):uniform()
+ elseif dims == 2 then
+ local a = math.random(1, 5)
+ if batch_mode then
+ baseInput = torch.Tensor(a, input_size):uniform()
+ baseGradOutput = torch.Tensor(a, output_size):uniform()
+ else
+ baseInput = torch.Tensor(input_size, a):uniform()
+ baseGradOutput = torch.Tensor(output_size, a):uniform()
+ end
+ elseif dims == 3 then
+ local a = math.random(1, 5)
+ local b = math.random(1, 5)
+ if batch_mode then
+ baseInput = torch.Tensor(a, input_size, b):uniform()
+ baseGradOutput = torch.Tensor(a, output_size, b):uniform()
+ else
+ baseInput = torch.Tensor(input_size, a, b):uniform()
+ baseGradOutput = torch.Tensor(output_size, a, b):uniform()
+ end
+ else
+ local a = math.random(1, 5)
+ local b = math.random(1, 5)
+ local c = math.random(1, 5)
+ baseInput = torch.Tensor(a, input_size, b, c):uniform()
+ baseGradOutput = torch.Tensor(a, output_size, b, c):uniform()
+ end
+
+ for k, typename in ipairs(typenames) do
+ local input = baseInput:type(typename)
+ local gradOutput = baseGradOutput:type(typename)
+ local ctype = t2cpu[typename]
+ input = makeNonContiguous(input:type(ctype))
+ gradOutput = makeNonContiguous(gradOutput:type(ctype))
+
+ local sconv = nn.FeatureLPPooling(width, stride, power, batch_mode):type(ctype)
+ sconv:forward(input)
+ sconv:zeroGradParameters()
+ local groundgrad = sconv:backward(input, gradOutput)
+
+ input = makeNonContiguous(input:type(typename))
+ gradOutput = makeNonContiguous(gradOutput:type(typename))
+ local gconv = nn.FeatureLPPooling(width, stride, power, batch_mode):type(typename)
+
+ gconv:forward(input)
+ gconv:zeroGradParameters()
+ local rescuda = gconv:backward(input, gradOutput)
+
+ local error = rescuda:double() - groundgrad:double()
+
+ mytester:assertlt(error:abs():max(), precision_backward_type(precision_backward, typename),
+ string.format('error on state (backward) with %s', typename))
+ end
+ end
+end
+
function cunntest.CMul_forward_batch()
local bs = math.random(8,32)
local nini = math.random(1,100)