github.com/torch/nn.git
author     Soumith Chintala <soumith@gmail.com>  2017-07-26 00:01:50 +0300
committer  GitHub <noreply@github.com>           2017-07-26 00:01:50 +0300
commit     f6134120040624e7b387b81488d0208f881a1904 (patch)
tree       c8e2cb6081dc9a0f2151f2b1a1e02627333dac25
parent     14cedef2d03dcbdd95e49be935c7368ed3d475c5 (diff)
parent     9c5ddccde8fbffc8b181653169418dca561e05e1 (diff)
Merge pull request #1259 from wickedfoo/feature_lp_pooling
CPU implementation of L_p feature pooling
-rw-r--r--  FeatureLPPooling.lua                  74
-rwxr-xr-x  init.lua                               2
-rw-r--r--  lib/THNN/generic/FeatureLPPooling.c  348
-rw-r--r--  lib/THNN/generic/THNN.h               20
-rw-r--r--  lib/THNN/init.c                        4
-rwxr-xr-x  test.lua                             312
6 files changed, 759 insertions(+), 1 deletion(-)
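
For context (this note is mine, not part of the patch): feature L_p pooling slides a window of `width` consecutive features, with step `stride`, along the feature dimension and emits the L_p norm of each window. In LaTeX, with input features x, window w, stride s, and n input features:

```
o_k = \left( \sum_{i=0}^{w-1} x_{k s + i}^{\,p} \right)^{1/p},
\qquad k = 0, 1, \ldots, \left\lfloor \frac{n - w}{s} \right\rfloor
```

For p = 1 this is a plain sum over the window; for non-negative inputs, larger p moves the result toward the window maximum.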
diff --git a/FeatureLPPooling.lua b/FeatureLPPooling.lua
new file mode 100644
index 0000000..5de4656
--- /dev/null
+++ b/FeatureLPPooling.lua
@@ -0,0 +1,74 @@
+
+local FeatureLPPooling, parent =
+ torch.class('nn.FeatureLPPooling', 'nn.Module')
+
+--[[
+ Possible inputs that we handle:
+
+ #### `batch_mode = false`
+ The dimensionality of the input chooses between the following modes:
+
+ ```
+ [feature dim]
+ [feature dim][opt dim 1]
+ [feature dim][opt dim 1][opt dim 2]
+ ```
+
+ #### `batch_mode = true`
+ The dimensionality of the input chooses between the following modes:
+ ```
+ [batch dim][feature dim]
+ [batch dim][feature dim][opt dim 1]
+ [batch dim][feature dim][opt dim 1][opt dim 2]
+ ```
+
+ The output has the same number of dimensions as the input, except that the
+ feature dimension size is reduced to ((input feature size - `width`) / `stride`) + 1 (integer division).
+]]
+function FeatureLPPooling:__init(width, stride, power, batch_mode)
+ parent.__init(self)
+
+ if (width < 2 or width > 16) then
+ error('width must be between 2 and 16')
+ end
+
+ if (stride < 1 or stride > 4) then
+ error('stride must be between 1 and 4')
+ end
+
+ self.width = width
+ self.stride = stride
+ self.power = power
+ self.batch_mode = batch_mode
+
+ self.output = torch.Tensor()
+ self.gradInput = torch.Tensor()
+end
+
+function FeatureLPPooling:updateOutput(input)
+ input.THNN.FeatureLPPooling_updateOutput(input:cdata(),
+ self.output:cdata(),
+ self.power,
+ self.width,
+ self.stride,
+ self.batch_mode)
+ return self.output
+end
+
+function FeatureLPPooling:updateGradInput(input, gradOutput)
+ input.THNN.FeatureLPPooling_updateGradInput(gradOutput:cdata(),
+ input:cdata(),
+ self.output:cdata(),
+ self.gradInput:cdata(),
+ self.power,
+ self.width,
+ self.stride,
+ self.batch_mode)
+ return self.gradInput
+end
+
+function FeatureLPPooling:__tostring__()
+ return string.format('%s(w%d s%d power %s batch %s)',
+ torch.type(self),
+ self.width, self.stride, tostring(self.power), tostring(self.batch_mode))
+end
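
A minimal usage sketch of the new module (my example, not part of the patch; the sizes are arbitrary):

```lua
require 'nn'

-- Pool windows of 4 consecutive features with stride 2 and p = 2,
-- in batch mode: input is [batch dim][feature dim].
local pool = nn.FeatureLPPooling(4, 2, 2.0, true)

local input = torch.Tensor(8, 16):uniform()
local output = pool:forward(input)   -- 8 x ((16 - 4) / 2 + 1) = 8 x 7
local gradInput = pool:backward(input, output:clone():fill(1))
print(output:size(), gradInput:size())
```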
diff --git a/init.lua b/init.lua
index 009504c..21ac789 100755
--- a/init.lua
+++ b/init.lua
@@ -160,6 +160,8 @@ require('nn.VolumetricAveragePooling')
require('nn.VolumetricBatchNormalization')
require('nn.VolumetricReplicationPadding')
+require('nn.FeatureLPPooling')
+
require('nn.GPU')
require('nn.ParallelTable')
diff --git a/lib/THNN/generic/FeatureLPPooling.c b/lib/THNN/generic/FeatureLPPooling.c
new file mode 100644
index 0000000..25a58db
--- /dev/null
+++ b/lib/THNN/generic/FeatureLPPooling.c
@@ -0,0 +1,348 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/FeatureLPPooling.c"
+#else
+
+#ifndef FEATURE_LP_DEFS
+#define FEATURE_LP_DEFS
+
+typedef struct {
+ size_t size[4];
+ size_t stride[4];
+} FeatureLPPoolingSizes;
+
+inline size_t flpGetOffset(FeatureLPPoolingSizes* s,
+ size_t batch,
+ size_t feature,
+ size_t opt1,
+ size_t opt2) {
+ return s->stride[0] * batch +
+ s->stride[1] * feature +
+ s->stride[2] * opt1 +
+ s->stride[3] * opt2;
+}
+
+inline size_t flpOutputSize(size_t inputSize,
+ size_t width,
+ size_t stride) {
+ return ((inputSize - width) / stride) + 1;
+}
+
+#endif // FEATURE_LP_DEFS
+
+FeatureLPPoolingSizes
+THNN_(FeatureLPPooling_upcastCPU)(THTensor* t, bool batchMode) {
+ int dim = THTensor_(nDimension)(t);
+
+ // Upcast to [batch dim][feature dim][opt dim 1][opt dim 2]
+ FeatureLPPoolingSizes s;
+ for (int i = 0; i < 4; ++i) {
+ s.size[i] = 1;
+ s.stride[i] = 1;
+ }
+
+ if (dim == 1) {
+ THAssert(!batchMode);
+ // [feature dim]
+ s.size[1] = THTensor_(size)(t, 0);
+ s.stride[1] = THTensor_(stride)(t, 0);
+ } else if (dim == 2) {
+ if (batchMode) {
+ // [batch dim][feature dim]
+ for (int i = 0; i < 2; ++i) {
+ s.size[i] = THTensor_(size)(t, i);
+ s.stride[i] = THTensor_(stride)(t, i);
+ }
+ } else {
+ // [feature dim][opt dim 1]
+ s.size[1] = THTensor_(size)(t, 0);
+ s.stride[1] = THTensor_(stride)(t, 0);
+ s.size[2] = THTensor_(size)(t, 1);
+ s.stride[2] = THTensor_(stride)(t, 1);
+ }
+ } else if (dim == 3) {
+ if (batchMode) {
+ // [batch dim][feature dim][opt dim 1]
+ for (int i = 0; i < 3; ++i) {
+ s.size[i] = THTensor_(size)(t, i);
+ s.stride[i] = THTensor_(stride)(t, i);
+ }
+ } else {
+ // [feature dim][opt dim 1][opt dim 2]
+ for (int i = 1; i < 4; ++i) {
+ s.size[i] = THTensor_(size)(t, i - 1);
+ s.stride[i] = THTensor_(stride)(t, i - 1);
+ }
+ }
+ } else if (dim == 4) {
+ // [batch dim][feature dim][opt dim 1][opt dim 2]
+ THAssert(batchMode);
+ for (int i = 0; i < 4; ++i) {
+ s.size[i] = THTensor_(size)(t, i);
+ s.stride[i] = THTensor_(stride)(t, i);
+ }
+ }
+
+ return s;
+}
+
+void
+THNN_(FeatureLPPooling_resizeForOutputCPU)(THTensor* toResize,
+ THTensor* input,
+ bool batchMode,
+ int width,
+ int stride) {
+ int inputDim = THTensor_(nDimension)(input);
+ THAssert(inputDim >= 1 && inputDim <= 4);
+
+ long outSize =
+ flpOutputSize(THTensor_(size)(input, 0), width, stride);
+ if (batchMode) {
+ THAssert(inputDim > 1);
+ outSize =
+ flpOutputSize(THTensor_(size)(input, 1), width, stride);
+ } else {
+ THAssert(inputDim < 4);
+ }
+
+ if (inputDim == 1) {
+ THTensor_(resize1d)(toResize, outSize);
+ } else if (inputDim == 2) {
+ if (batchMode) {
+ THTensor_(resize2d)(toResize,
+ THTensor_(size)(input, 0),
+ outSize);
+ } else {
+ THTensor_(resize2d)(toResize,
+ outSize,
+ THTensor_(size)(input, 1));
+ }
+ } else if (inputDim == 3) {
+ if (batchMode) {
+ THTensor_(resize3d)(toResize,
+ THTensor_(size)(input, 0), outSize,
+ THTensor_(size)(input, 2));
+ } else {
+ THTensor_(resize3d)(toResize,
+ outSize, THTensor_(size)(input, 1),
+ THTensor_(size)(input, 2));
+ }
+ } else if (inputDim == 4) {
+ THTensor_(resize4d)(toResize,
+ THTensor_(size)(input, 0),
+ outSize,
+ THTensor_(size)(input, 2),
+ THTensor_(size)(input, 3));
+ }
+}
+
+// Makes `toResize` the same size/dimensionality as `src`
+void
+THNN_(FeatureLPPooling_resizeCPU)(THTensor* toResize,
+ THTensor* src) {
+ int inputDim = THTensor_(nDimension)(src);
+ THAssert(inputDim >= 1 && inputDim <= 4);
+
+ if (inputDim == 1) {
+ THTensor_(resize1d)(toResize,
+ THTensor_(size)(src, 0));
+ } else if (inputDim == 2) {
+ THTensor_(resize2d)(
+ toResize,
+ THTensor_(size)(src, 0),
+ THTensor_(size)(src, 1));
+ } else if (inputDim == 3) {
+ THTensor_(resize3d)(
+ toResize,
+ THTensor_(size)(src, 0),
+ THTensor_(size)(src, 1),
+ THTensor_(size)(src, 2));
+ } else if (inputDim == 4) {
+ THTensor_(resize4d)(
+ toResize,
+ THTensor_(size)(src, 0),
+ THTensor_(size)(src, 1),
+ THTensor_(size)(src, 2),
+ THTensor_(size)(src, 3));
+ }
+}
+
+void
+THNN_(FeatureLPPooling_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ accreal power,
+ int width,
+ int stride,
+ bool batchMode) {
+ int inputDim = THTensor_(nDimension)(input);
+
+ if (batchMode) {
+ THArgCheck(inputDim >= 2 && inputDim <= 4, 2,
+ "input must be 2-4 dimensions for batch mode");
+ } else {
+ THArgCheck(inputDim >= 1 && inputDim <= 3, 2,
+ "input must be 1-3 dimensions for non-batch mode");
+ }
+
+ FeatureLPPoolingSizes inputDesc =
+ THNN_(FeatureLPPooling_upcastCPU)(input, batchMode);
+
+ // Make sure the feature dimension is properly sized
+ THArgCheck(inputDesc.size[1] >= width, 3,
+ "input: feature dimension must be >= width");
+
+ // Make sure that width and stride are within range
+ THArgCheck(width >= 2 && width <= 16, 5,
+ "width must be between 2 and 16");
+
+ THArgCheck(stride >= 1 && stride <= 4, 6,
+ "stride must be between 1 and 4");
+
+ // Resize output
+
+ THNN_(FeatureLPPooling_resizeForOutputCPU)(
+ output, input, batchMode, width, stride);
+
+ FeatureLPPoolingSizes outputDesc =
+ THNN_(FeatureLPPooling_upcastCPU)(output, batchMode);
+
+ real* inputP = THTensor_(data)(input);
+ real* outputP = THTensor_(data)(output);
+
+#pragma omp parallel for
+ for (size_t batch = 0; batch < inputDesc.size[0]; ++batch) {
+ for (size_t opt1 = 0; opt1 < inputDesc.size[2]; ++opt1) {
+ for (size_t opt2 = 0; opt2 < inputDesc.size[3]; ++opt2) {
+ for (size_t outputFeature = 0;
+ outputFeature < outputDesc.size[1]; ++outputFeature) {
+
+ accreal v = (accreal) 0;
+ for (size_t i = 0; i < width; ++i) {
+ size_t inputFeature = outputFeature * stride + i;
+ if (inputFeature >= inputDesc.size[1]) {
+ break;
+ }
+
+ v +=
+ pow(inputP[flpGetOffset(&inputDesc,
+ batch,
+ inputFeature,
+ opt1,
+ opt2)], power);
+ }
+
+ outputP[flpGetOffset(&outputDesc, batch, outputFeature, opt1, opt2)] =
+ pow(v, (accreal) 1 / power);
+ }
+ }
+ }
+ }
+}
+
+void
+THNN_(FeatureLPPooling_updateGradInput)(
+ THNNState *state,
+ THTensor* gradOutput,
+ THTensor* input,
+ THTensor* output,
+ THTensor* gradInput,
+ accreal power,
+ int width,
+ int stride,
+ bool batchMode) {
+ int inputDim = THTensor_(nDimension)(input);
+
+ if (batchMode) {
+ THArgCheck(inputDim >= 2 && inputDim <= 4, 3,
+ "input must be 2-4 dimensions for batch mode");
+ } else {
+ THArgCheck(inputDim >= 1 && inputDim <= 3, 3,
+ "input must be 1-3 dimensions for non-batch mode");
+ }
+
+ FeatureLPPoolingSizes inputDesc =
+ THNN_(FeatureLPPooling_upcastCPU)(input, batchMode);
+ FeatureLPPoolingSizes gradOutputDesc =
+ THNN_(FeatureLPPooling_upcastCPU)(gradOutput, batchMode);
+ FeatureLPPoolingSizes outputDesc =
+ THNN_(FeatureLPPooling_upcastCPU)(output, batchMode);
+
+ // Make sure the feature dimension is properly sized
+ THArgCheck(inputDesc.size[1] >= width, 3,
+ "input: feature dimension must be >= width");
+
+ // Make sure that width and stride are within range
+ THArgCheck(width >= 2 && width <= 16, 7,
+ "width must be between 2 and 16");
+
+ THArgCheck(stride >= 1 && stride <= 4, 8,
+ "stride must be between 1 and 4");
+
+ for (int i = 0; i < 4; ++i) {
+ THAssertMsg(outputDesc.size[i] == gradOutputDesc.size[i],
+ "output and gradOutput sizes do not match");
+ }
+
+ // Make sure that the input sizes produce the output sizes
+ THArgCheck(flpOutputSize(inputDesc.size[1], width, stride) ==
+ outputDesc.size[1], 3,
+ "input and output sizes do not match with respect to "
+ "width and stride");
+
+ // Resize `gradInput` based on `input`
+ THNN_(FeatureLPPooling_resizeCPU)(gradInput, input);
+
+ // Zero gradInput for accumulation
+ THTensor_(zero)(gradInput);
+
+ FeatureLPPoolingSizes gradInputDesc =
+ THNN_(FeatureLPPooling_upcastCPU)(gradInput, batchMode);
+
+ real* gradOutputP = THTensor_(data)(gradOutput);
+ real* gradInputP = THTensor_(data)(gradInput);
+ real* outputP = THTensor_(data)(output);
+ real* inputP = THTensor_(data)(input);
+
+#pragma omp parallel for
+ for (size_t batch = 0; batch < inputDesc.size[0]; ++batch) {
+ for (size_t opt1 = 0; opt1 < inputDesc.size[2]; ++opt1) {
+ for (size_t opt2 = 0; opt2 < inputDesc.size[3]; ++opt2) {
+ for (size_t outputFeature = 0;
+ outputFeature < outputDesc.size[1]; ++outputFeature) {
+
+ // Load output (f(x_is)). It is possible that this is zero, in
+ // which case we'll ignore this point.
+ real outputV =
+ outputP[
+ flpGetOffset(&outputDesc, batch, outputFeature, opt1, opt2)];
+
+ if (outputV == (real) 0) {
+ continue;
+ }
+
+ for (size_t i = 0; i < width; ++i) {
+ size_t inputFeature = outputFeature * stride + i;
+ THAssert(inputFeature < inputDesc.size[1]);
+
+ real gradOutputV =
+ gradOutputP[
+ flpGetOffset(&gradOutputDesc, batch, outputFeature, opt1, opt2)];
+ real inputV =
+ inputP[
+ flpGetOffset(&inputDesc, batch, inputFeature, opt1, opt2)];
+
+ // Calculate grad * (x_i / f(x_is))^(p - 1)
+ real v = gradOutputV * pow(inputV / outputV, power - (accreal) 1);
+
+ gradInputP[
+ flpGetOffset(&gradInputDesc, batch, inputFeature, opt1, opt2)]
+ += v;
+ }
+ }
+ }
+ }
+ }
+}
+
+#endif
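
The `(x_i / f(x_is))^(p - 1)` term in the backward kernel above is the derivative of the forward expression; a sketch of the chain rule, in my notation (not text from the patch):

```
f(x) = \Big( \sum_j x_j^{p} \Big)^{1/p}
\;\Rightarrow\;
\frac{\partial f}{\partial x_i}
  = \frac{1}{p} \Big( \sum_j x_j^{p} \Big)^{\frac{1}{p} - 1} \cdot p\, x_i^{p-1}
  = \Big( \frac{x_i}{f(x)} \Big)^{p-1}
```

Each input feature therefore accumulates, from every window that covers it, gradOutput times (x_i / f)^(p - 1). Windows whose output is exactly zero are skipped, which avoids the division by zero and matches a zero subgradient at that point.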
diff --git a/lib/THNN/generic/THNN.h b/lib/THNN/generic/THNN.h
index 31b0795..ad4ea51 100644
--- a/lib/THNN/generic/THNN.h
+++ b/lib/THNN/generic/THNN.h
@@ -1461,6 +1461,26 @@ TH_API void THNN_(SpatialReplicationPadding_updateGradInput)(
int pad_l, int pad_r,
int pad_t, int pad_b);
+TH_API void THNN_(FeatureLPPooling_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ accreal power,
+ int width,
+ int stride,
+ bool batchMode);
+
+TH_API void THNN_(FeatureLPPooling_updateGradInput)(
+ THNNState *state,
+ THTensor* gradOutput,
+ THTensor* input,
+ THTensor* output,
+ THTensor* gradInput,
+ accreal power,
+ int width,
+ int stride,
+ bool batchMode);
+
TH_API void THNN_(VolumetricReplicationPadding_updateOutput)(
THNNState *state,
THTensor *input,
diff --git a/lib/THNN/init.c b/lib/THNN/init.c
index 12f3222..acb88c0 100644
--- a/lib/THNN/init.c
+++ b/lib/THNN/init.c
@@ -179,6 +179,9 @@
#include "generic/TemporalRowConvolution.c"
#include "THGenerateFloatTypes.h"
+#include "generic/FeatureLPPooling.c"
+#include "THGenerateFloatTypes.h"
+
#include "generic/BatchNormalization.c"
#include "THGenerateFloatTypes.h"
@@ -280,4 +283,3 @@
#include "generic/VolumetricUpSamplingTrilinear.c"
#include "THGenerateFloatTypes.h"
-
diff --git a/test.lua b/test.lua
index 77e8cb8..35852fa 100755
--- a/test.lua
+++ b/test.lua
@@ -8788,6 +8788,318 @@ function nntest.Kmeans()
end
end
+function nntest.FeatureLPPooling()
+ local verbose = false
+
+ local num_tries = 2
+ local jacobian = nn.Jacobian
+ local precision = 4e-3
+
+ local batch_max = 3
+ local feature_max = 100
+ local dim1_max = 3
+ local dim2_max = 3
+
+ local function pickPow()
+ local num = torch.random(4)
+ if num == 1 then
+ return 1
+ else
+ return (num - 1) * 2.0
+ end
+ end
+
+ local function runFPropTest(dims, width, stride, pow, batch_mode)
+ local pool = nn.FeatureLPPooling(width, stride, pow, batch_mode):float()
+
+ local num_batch = torch.random(batch_max)
+ local num_features = (torch.random(feature_max) - 1) * stride + width
+ local num_dim1 = torch.random(dim1_max)
+ local num_dim2 = torch.random(dim2_max)
+
+ if verbose then
+ print('test on dim ' .. dims ..
+ ' features ' .. num_features ..
+ ' width ' .. width .. ' stride ' .. stride ..
+ ' p ' .. pow .. ' bm ' .. (batch_mode and 1 or 0))
+ end
+
+ local input = nil
+ if dims == 1 then
+ if batch_mode then
+ input = torch.FloatTensor(num_batch, num_features)
+
+ for i = 1, num_batch do
+ for f = 1, num_features do
+ input[i][f] = f - 1
+ end
+ end
+
+ else
+ input = torch.FloatTensor(num_features)
+
+ for f = 1, num_features do
+ input[f] = f - 1
+ end
+
+ end
+ elseif dims == 2 then
+ if batch_mode then
+ input = torch.FloatTensor(num_batch, num_features, num_dim1)
+
+ for i = 1, num_batch do
+ for f = 1, num_features do
+ for j = 1, num_dim1 do
+ input[i][f][j] = f - 1
+ end
+ end
+ end
+
+ else
+ input = torch.FloatTensor(num_features, num_dim1)
+
+ for f = 1, num_features do
+ for j = 1, num_dim1 do
+ input[f][j] = f - 1
+ end
+ end
+
+ end
+ elseif dims == 3 then
+ if batch_mode then
+ input = torch.FloatTensor(num_batch, num_features, num_dim1, num_dim2)
+
+ for i = 1, num_batch do
+ for f = 1, num_features do
+ for j = 1, num_dim1 do
+ for k = 1, num_dim2 do
+ input[i][f][j][k] = f - 1
+ end
+ end
+ end
+ end
+
+ else
+ input = torch.FloatTensor(num_features, num_dim1, num_dim2)
+
+ for f = 1, num_features do
+ for j = 1, num_dim1 do
+ for k = 1, num_dim2 do
+ input[f][j][k] = f - 1
+ end
+ end
+ end
+
+ end
+ end
+
+ local output = pool:forward(input)
+
+ -- Each output feature o(k) (k one-based, following Lua indexing) for L1 is:
+ -- sum(i((k - 1) * s), i((k - 1) * s + 1), ..., i((k - 1) * s + w - 1))
+ -- If i(x) = x, then o(k) = w * (k - 1) * s + w * (w - 1) / 2.
+ -- For Lp (p ~= 1), just evaluate ourselves and compare.
+
+ local function verifyFeature(val, k, width, stride, pow)
+ local sum_input = 0
+ if pow == 1 then
+ sum_input = width * (k - 1) * stride + width * (width - 1) / 2
+ else
+ for w = 0, width - 1 do
+ sum_input = sum_input + math.pow((k - 1) * stride + w, pow)
+ end
+ sum_input = math.pow(sum_input, 1 / pow)
+ end
+
+ local diff = math.abs(val - sum_input)
+ if (diff >= 1e-3) then
+ if verbose then
+ print('failed on ' .. val .. ' ' .. sum_input)
+ end
+ mytester:assertlt(math.abs(val - sum_input), 1e-3)
+ end
+ end
+
+ if dims == 1 then
+ if batch_mode then
+ for i = 1, output:size(1) do
+ for f = 1, output:size(2) do
+ verifyFeature(output[i][f], f, width, stride, pow)
+ end
+ end
+
+ else
+ for f = 1, output:size(1) do
+ verifyFeature(output[f], f, width, stride, pow)
+ end
+
+ end
+ elseif dims == 2 then
+ if batch_mode then
+ for i = 1, output:size(1) do
+ for f = 1, output:size(2) do
+ for j = 1, output:size(3) do
+ verifyFeature(output[i][f][j], f, width, stride, pow)
+ end
+ end
+ end
+
+ else
+ for f = 1, output:size(1) do
+ for j = 1, output:size(2) do
+ verifyFeature(output[f][j], f, width, stride, pow)
+ end
+ end
+
+ end
+ elseif dims == 3 then
+ if batch_mode then
+ for i = 1, output:size(1) do
+ for f = 1, output:size(2) do
+ for j = 1, output:size(3) do
+ for k = 1, output:size(4) do
+ verifyFeature(output[i][f][j][k], f, width, stride, pow)
+ end
+ end
+ end
+ end
+
+ else
+ for f = 1, output:size(1) do
+ for j = 1, output:size(2) do
+ for k = 1, output:size(3) do
+ verifyFeature(output[f][j][k], f, width, stride, pow)
+ end
+ end
+ end
+
+ end
+ end
+ end
+
+ local function runBPropTest(dims, width, stride, pow, batch_mode)
+ local pool = nn.FeatureLPPooling(width, stride, pow, batch_mode):float()
+
+ local num_batch = torch.random(batch_max)
+ local num_features = (torch.random(feature_max) - 1) * stride + width
+ local num_dim1 = torch.random(dim1_max)
+ local num_dim2 = torch.random(dim2_max)
+
+ local input = nil
+ if dims == 1 then
+ if batch_mode then
+ input = torch.FloatTensor(num_batch, num_features)
+ else
+ input = torch.FloatTensor(num_features)
+ end
+ elseif dims == 2 then
+ if batch_mode then
+ input = torch.FloatTensor(num_batch, num_features, num_dim1)
+ else
+ input = torch.FloatTensor(num_features, num_dim1)
+ end
+ elseif dims == 3 then
+ if batch_mode then
+ input = torch.FloatTensor(num_batch, num_features, num_dim1, num_dim2)
+ else
+ input = torch.FloatTensor(num_features, num_dim1, num_dim2)
+ end
+ end
+
+ local err = jacobian.testJacobian(pool, input, -2, -2, 5e-4)
+ if verbose then
+ print('test on dim ' .. dims ..
+ ' features ' .. num_features ..
+ ' width ' .. width .. ' stride ' .. stride ..
+ ' p ' .. pow .. ' err ' .. err)
+ end
+ mytester:assertlt(err, precision)
+ end
+
+ local function testForwardLp()
+ for i = 1, num_tries do
+ for stride = 1, 4 do
+ for idx, batch_mode in ipairs({true, false}) do
+ for dims = 1, 3 do
+ runFPropTest(dims, 1 + torch.random(15),
+ stride, pickPow(), batch_mode)
+ end
+ end
+ end
+ end
+ end
+
+ local function testZeroBProp()
+ local pool = nn.FeatureLPPooling(3, 1, 2.0, false):float()
+
+ local input = torch.FloatTensor(100):zero()
+ pool:forward(input)
+
+ local gradOutput = torch.FloatTensor(98):zero()
+ local gradInput = pool:backward(input, gradOutput, 1.0)
+
+ for i = 1, gradInput:size(1) do
+ mytester:asserteq(gradInput[i], 0)
+ end
+ end
+
+ local function testJacobian1dNoBatch()
+ for i = 1, num_tries do
+ for stride = 1, 4 do
+ runBPropTest(1, 1 + torch.random(15), stride, pickPow(), false)
+ end
+ end
+ end
+
+ local function testJacobian1dBatch()
+ for i = 1, num_tries do
+ for stride = 1, 4 do
+ runBPropTest(1, 1 + torch.random(15), stride, pickPow(), true)
+ end
+ end
+ end
+
+ local function testJacobian2dNoBatch()
+ for i = 1, num_tries do
+ for stride = 1, 4 do
+ runBPropTest(2, 1 + torch.random(15), stride, pickPow(), false)
+ end
+ end
+ end
+
+ local function testJacobian2dBatch()
+ for i = 1, num_tries do
+ for stride = 1, 4 do
+ runBPropTest(2, 1 + torch.random(15), stride, pickPow(), true)
+ end
+ end
+ end
+
+ local function testJacobian3dNoBatch()
+ for i = 1, num_tries do
+ for stride = 1, 4 do
+ runBPropTest(3, 1 + torch.random(15), stride, pickPow(), false)
+ end
+ end
+ end
+
+ local function testJacobian3dBatch()
+ for i = 1, num_tries do
+ for stride = 1, 4 do
+ runBPropTest(3, 1 + torch.random(15), stride, pickPow(), true)
+ end
+ end
+ end
+
+ testForwardLp()
+ testZeroBProp()
+ testJacobian1dNoBatch()
+ testJacobian1dBatch()
+ testJacobian2dNoBatch()
+ testJacobian2dBatch()
+ testJacobian3dNoBatch()
+ testJacobian3dBatch()
+end
+
mytester:add(nntest)
jac = nn.Jacobian
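
With the patch applied, the new test can be run on its own through the stock nn test harness (assuming the usual `nn.test` entry point, which takes an optional list of test names):

```lua
require 'nn'
nn.test{'FeatureLPPooling'}
```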