author | Soumith Chintala <soumith@gmail.com> | 2017-07-26 00:01:50 +0300
---|---|---
committer | GitHub <noreply@github.com> | 2017-07-26 00:01:50 +0300
commit | f6134120040624e7b387b81488d0208f881a1904 (patch) |
tree | c8e2cb6081dc9a0f2151f2b1a1e02627333dac25 |
parent | 14cedef2d03dcbdd95e49be935c7368ed3d475c5 (diff) |
parent | 9c5ddccde8fbffc8b181653169418dca561e05e1 (diff) |
Merge pull request #1259 from wickedfoo/feature_lp_pooling
CPU implementation of L_p feature pooling
-rw-r--r-- | FeatureLPPooling.lua | 74
-rwxr-xr-x | init.lua | 2
-rw-r--r-- | lib/THNN/generic/FeatureLPPooling.c | 348
-rw-r--r-- | lib/THNN/generic/THNN.h | 20
-rw-r--r-- | lib/THNN/init.c | 4
-rwxr-xr-x | test.lua | 312
6 files changed, 759 insertions, 1 deletion
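In brief, the new module pools groups of `width` consecutive values along the feature dimension under an Lp norm: output feature k is (sum over i = 0 .. width-1 of input[k * stride + i]^power)^(1/power). A minimal usage sketch of the new module (sizes are illustrative; assumes a build that includes this commit):

    require 'nn'

    -- L2 pooling over windows of 4 consecutive features with stride 2, batch mode
    local pool = nn.FeatureLPPooling(4, 2, 2.0, true):float()

    local input = torch.FloatTensor(8, 16):uniform() -- [batch dim][feature dim]
    local output = pool:forward(input)
    print(output:size()) -- 8 x 7, since ((16 - 4) / 2) + 1 = 7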
diff --git a/FeatureLPPooling.lua b/FeatureLPPooling.lua
new file mode 100644
index 0000000..5de4656
--- /dev/null
+++ b/FeatureLPPooling.lua
@@ -0,0 +1,74 @@
+
+local FeatureLPPooling, parent =
+   torch.class('nn.FeatureLPPooling', 'nn.Module')
+
+--[[
+   Possible inputs that we handle:
+
+   #### `batch_mode = false`
+   The dimensionality of the input chooses between the following modes:
+
+   ```
+   [feature dim]
+   [feature dim][opt dim 1]
+   [feature dim][opt dim 1][opt dim 2]
+   ```
+
+   #### `batch_mode = true`
+   The dimensionality of the input chooses between the following modes:
+   ```
+   [batch dim][feature dim]
+   [batch dim][feature dim][opt dim 1]
+   [batch dim][feature dim][opt dim 1][opt dim 2]
+   ```
+
+   The output has the same number of dimensions as the input, except the feature
+   dimension size is reduced to ((`input` - `width`) / `stride`) + 1
+]]
+function FeatureLPPooling:__init(width, stride, power, batch_mode)
+   parent.__init(self)
+
+   if (width < 2 or width > 16) then
+      error('width must be within 2 to 16')
+   end
+
+   if (stride < 1 or stride > 4) then
+      error('stride must be within 1 to 4')
+   end
+
+   self.width = width
+   self.stride = stride
+   self.power = power
+   self.batch_mode = batch_mode
+
+   self.output = torch.Tensor()
+   self.gradInput = torch.Tensor()
+end
+
+function FeatureLPPooling:updateOutput(input)
+   input.THNN.FeatureLPPooling_updateOutput(input:cdata(),
+                                            self.output:cdata(),
+                                            self.power,
+                                            self.width,
+                                            self.stride,
+                                            self.batch_mode)
+   return self.output
+end
+
+function FeatureLPPooling:updateGradInput(input, gradOutput)
+   input.THNN.FeatureLPPooling_updateGradInput(gradOutput:cdata(),
+                                               input:cdata(),
+                                               self.output:cdata(),
+                                               self.gradInput:cdata(),
+                                               self.power,
+                                               self.width,
+                                               self.stride,
+                                               self.batch_mode)
+   return self.gradInput
+end
+
+function FeatureLPPooling:__tostring__()
+   return string.format('%s(w%d s%d power %d batch %s)', torch.type(self),
+                        self.width, self.stride, self.power,
+                        tostring(self.batch_mode))
+end
diff --git a/init.lua b/init.lua
--- a/init.lua
+++ b/init.lua
@@ -160,6 +160,8 @@
 require('nn.VolumetricAveragePooling')
 require('nn.VolumetricBatchNormalization')
 require('nn.VolumetricReplicationPadding')
+require('nn.FeatureLPPooling')
+
 require('nn.GPU')
 require('nn.ParallelTable')
diff --git a/lib/THNN/generic/FeatureLPPooling.c b/lib/THNN/generic/FeatureLPPooling.c
new file mode 100644
index 0000000..25a58db
--- /dev/null
+++ b/lib/THNN/generic/FeatureLPPooling.c
@@ -0,0 +1,348 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/FeatureLPPooling.c"
+#else
+
+#ifndef FEATURE_LP_DEFS
+#define FEATURE_LP_DEFS
+
+typedef struct {
+  size_t size[4];
+  size_t stride[4];
+} FeatureLPPoolingSizes;
+
+inline size_t flpGetOffset(FeatureLPPoolingSizes* s,
+                           size_t batch,
+                           size_t feature,
+                           size_t opt1,
+                           size_t opt2) {
+  return s->stride[0] * batch +
+    s->stride[1] * feature +
+    s->stride[2] * opt1 +
+    s->stride[3] * opt2;
+}
+
+inline size_t flpOutputSize(size_t inputSize,
+                            size_t width,
+                            size_t stride) {
+  return ((inputSize - width) / stride) + 1;
+}
+
+#endif // FEATURE_LP_DEFS
+
+FeatureLPPoolingSizes
+THNN_(FeatureLPPooling_upcastCPU)(THTensor* t, bool batchMode) {
+  int dim = THTensor_(nDimension)(t);
+
+  // Upcast to [batch dim][feature dim][opt dim 1][opt dim 2]
+  FeatureLPPoolingSizes s;
+  for (int i = 0; i < 4; ++i) {
+    s.size[i] = 1;
+    s.stride[i] = 1;
+  }
+
+  if (dim == 1) {
+    THAssert(!batchMode);
+    // [feature dim]
+    s.size[1] = THTensor_(size)(t, 0);
+    s.stride[1] = THTensor_(stride)(t, 0);
+  } else if (dim == 2) {
+    if (batchMode) {
+      // [batch dim][feature dim]
+      for (int i = 0; i < 2; ++i) {
+        s.size[i] = THTensor_(size)(t, i);
+        s.stride[i] = THTensor_(stride)(t, i);
+      }
+    } else {
+      // [feature dim][opt dim 1]
+      s.size[1] = THTensor_(size)(t, 0);
+      s.stride[1] = THTensor_(stride)(t, 0);
+      s.size[2] = THTensor_(size)(t, 1);
+      s.stride[2] = THTensor_(stride)(t, 1);
+    }
+  } else if (dim == 3) {
+    if (batchMode) {
+      // [batch dim][feature dim][opt dim 1]
+      for (int i = 0; i < 3; ++i) {
+        s.size[i] = THTensor_(size)(t, i);
+        s.stride[i] = THTensor_(stride)(t, i);
+      }
+    } else {
+      // [feature dim][opt dim 1][opt dim 2]
+      for (int i = 1; i < 4; ++i) {
+        s.size[i] = THTensor_(size)(t, i - 1);
+        s.stride[i] = THTensor_(stride)(t, i - 1);
+      }
+    }
+  } else if (dim == 4) {
+    // [batch dim][feature dim][opt dim 1][opt dim 2]
+    THAssert(batchMode);
+    for (int i = 0; i < 4; ++i) {
+      s.size[i] = THTensor_(size)(t, i);
+      s.stride[i] = THTensor_(stride)(t, i);
+    }
+  }
+
+  return s;
+}
+
+void
+THNN_(FeatureLPPooling_resizeForOutputCPU)(THTensor* toResize,
+                                           THTensor* input,
+                                           bool batchMode,
+                                           int width,
+                                           int stride) {
+  int inputDim = THTensor_(nDimension)(input);
+  THAssert(inputDim >= 1 && inputDim <= 4);
+
+  long outSize =
+    flpOutputSize(THTensor_(size)(input, 0), width, stride);
+  if (batchMode) {
+    THAssert(inputDim > 1);
+    outSize =
+      flpOutputSize(THTensor_(size)(input, 1), width, stride);
+  } else {
+    THAssert(inputDim < 4);
+  }
+
+  if (inputDim == 1) {
+    THTensor_(resize1d)(toResize, outSize);
+  } else if (inputDim == 2) {
+    if (batchMode) {
+      THTensor_(resize2d)(toResize,
+                          THTensor_(size)(input, 0),
+                          outSize);
+    } else {
+      THTensor_(resize2d)(toResize,
+                          outSize,
+                          THTensor_(size)(input, 1));
+    }
+  } else if (inputDim == 3) {
+    if (batchMode) {
+      THTensor_(resize3d)(toResize,
+                          THTensor_(size)(input, 0), outSize,
+                          THTensor_(size)(input, 2));
+    } else {
+      THTensor_(resize3d)(toResize,
+                          outSize, THTensor_(size)(input, 1),
+                          THTensor_(size)(input, 2));
+    }
+  } else if (inputDim == 4) {
+    THTensor_(resize4d)(toResize,
+                        THTensor_(size)(input, 0),
+                        outSize,
+                        THTensor_(size)(input, 2),
+                        THTensor_(size)(input, 3));
+  }
+}
+
+// Makes `toResize` the same size/dimensionality as `src`
+void
+THNN_(FeatureLPPooling_resizeCPU)(THTensor* toResize,
+                                  THTensor* src) {
+  int inputDim = THTensor_(nDimension)(src);
+  THAssert(inputDim >= 1 && inputDim <= 4);
+
+  if (inputDim == 1) {
+    THTensor_(resize1d)(toResize,
+                        THTensor_(size)(src, 0));
+  } else if (inputDim == 2) {
+    THTensor_(resize2d)(
+      toResize,
+      THTensor_(size)(src, 0),
+      THTensor_(size)(src, 1));
+  } else if (inputDim == 3) {
+    THTensor_(resize3d)(
+      toResize,
+      THTensor_(size)(src, 0),
+      THTensor_(size)(src, 1),
+      THTensor_(size)(src, 2));
+  } else if (inputDim == 4) {
+    THTensor_(resize4d)(
+      toResize,
+      THTensor_(size)(src, 0),
+      THTensor_(size)(src, 1),
+      THTensor_(size)(src, 2),
+      THTensor_(size)(src, 3));
+  }
+}
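/* A quick check of flpOutputSize above (illustrative numbers, not from the
   commit): width 3, stride 2, 11 input features gives
   ((11 - 3) / 2) + 1 = 5 output features; the last window starts at
   feature 4 * 2 = 8 and covers features 8..10, so every window fits. */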
+void
+THNN_(FeatureLPPooling_updateOutput)(
+  THNNState *state,
+  THTensor *input,
+  THTensor *output,
+  accreal power,
+  int width,
+  int stride,
+  bool batchMode) {
+  int inputDim = THTensor_(nDimension)(input);
+
+  if (batchMode) {
+    THArgCheck(inputDim >= 2 && inputDim <= 4, 2,
+               "input must be 2-4 dimensions for batch mode");
+  } else {
+    THArgCheck(inputDim >= 1 && inputDim <= 3, 2,
+               "input must be 1-3 dimensions for non-batch mode");
+  }
+
+  FeatureLPPoolingSizes inputDesc =
+    THNN_(FeatureLPPooling_upcastCPU)(input, batchMode);
+
+  // Make sure the feature dimension is properly sized
+  THArgCheck(inputDesc.size[1] >= width, 3,
+             "input: feature dimension must be >= width");
+
+  // Make sure that width and stride are within range
+  THArgCheck(width >= 2 && width <= 16, 5,
+             "width must be between 2 - 16");
+
+  THArgCheck(stride >= 1 && stride <= 4, 6,
+             "stride must be between 1 - 4");
+
+  // Resize output
+
+  THNN_(FeatureLPPooling_resizeForOutputCPU)(
+    output, input, batchMode, width, stride);
+
+  FeatureLPPoolingSizes outputDesc =
+    THNN_(FeatureLPPooling_upcastCPU)(output, batchMode);
+
+  real* inputP = THTensor_(data)(input);
+  real* outputP = THTensor_(data)(output);
+
+#pragma omp parallel for
+  for (size_t batch = 0; batch < inputDesc.size[0]; ++batch) {
+    for (size_t opt1 = 0; opt1 < inputDesc.size[2]; ++opt1) {
+      for (size_t opt2 = 0; opt2 < inputDesc.size[3]; ++opt2) {
+        for (size_t outputFeature = 0;
+             outputFeature < outputDesc.size[1]; ++outputFeature) {
+
+          accreal v = (accreal) 0;
+          for (size_t i = 0; i < width; ++i) {
+            size_t inputFeature = outputFeature * stride + i;
+            if (inputFeature >= inputDesc.size[1]) {
+              break;
+            }
+
+            v +=
+              pow(inputP[flpGetOffset(&inputDesc,
+                                      batch,
+                                      inputFeature,
+                                      opt1,
+                                      opt2)], power);
+          }
+
+          outputP[flpGetOffset(&outputDesc, batch, outputFeature, opt1, opt2)] =
+            pow(v, (accreal) 1 / power);
+        }
+      }
+    }
+  }
+}
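/* A sketch of where the update rule in updateGradInput below comes from
   (the notation is mine, not part of the commit): with
     o_k = (sum_{i=0..w-1} x_{k*s+i}^p)^(1/p),
   differentiating gives
     d o_k / d x_j = x_j^(p-1) * o_k^(1-p) = (x_j / o_k)^(p-1),
   so each input feature accumulates gradOutput_k * (x_j / o_k)^(p-1) over
   every window k that contains it. The early exit on a zero output below
   avoids the division by zero this expression would otherwise imply. */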
+void
+THNN_(FeatureLPPooling_updateGradInput)(
+  THNNState *state,
+  THTensor* gradOutput,
+  THTensor* input,
+  THTensor* output,
+  THTensor* gradInput,
+  accreal power,
+  int width,
+  int stride,
+  bool batchMode) {
+  int inputDim = THTensor_(nDimension)(input);
+
+  if (batchMode) {
+    THArgCheck(inputDim >= 2 && inputDim <= 4, 3,
+               "input must be 2-4 dimensions for batch mode");
+  } else {
+    THArgCheck(inputDim >= 1 && inputDim <= 3, 3,
+               "input must be 1-3 dimensions for non-batch mode");
+  }
+
+  FeatureLPPoolingSizes inputDesc =
+    THNN_(FeatureLPPooling_upcastCPU)(input, batchMode);
+  FeatureLPPoolingSizes gradOutputDesc =
+    THNN_(FeatureLPPooling_upcastCPU)(gradOutput, batchMode);
+  FeatureLPPoolingSizes outputDesc =
+    THNN_(FeatureLPPooling_upcastCPU)(output, batchMode);
+
+  // Make sure the feature dimension is properly sized
+  THArgCheck(inputDesc.size[1] >= width, 3,
+             "input: feature dimension must be >= width");
+
+  // Make sure that width and stride are within range
+  THArgCheck(width >= 2 && width <= 16, 7,
+             "width must be between 2 - 16");
+
+  THArgCheck(stride >= 1 && stride <= 4, 8,
+             "stride must be between 1 - 4");
+
+  for (int i = 0; i < 4; ++i) {
+    THAssertMsg(outputDesc.size[i] == gradOutputDesc.size[i],
+                "output and gradOutput sizes do not match");
+  }
+
+  // Make sure that the input sizes produce the output sizes
+  THArgCheck(flpOutputSize(inputDesc.size[1], width, stride) ==
+             outputDesc.size[1], 3,
+             "input and output sizes do not match with respect to "
+             "width and stride");
+
+  // Resize `gradInput` based on `input`
+  THNN_(FeatureLPPooling_resizeCPU)(gradInput, input);
+
+  // Zero gradInput for accumulation
+  THTensor_(zero)(gradInput);
+
+  FeatureLPPoolingSizes gradInputDesc =
+    THNN_(FeatureLPPooling_upcastCPU)(gradInput, batchMode);
+
+  real* gradOutputP = THTensor_(data)(gradOutput);
+  real* gradInputP = THTensor_(data)(gradInput);
+  real* outputP = THTensor_(data)(output);
+  real* inputP = THTensor_(data)(input);
+
+#pragma omp parallel for
+  for (size_t batch = 0; batch < inputDesc.size[0]; ++batch) {
+    for (size_t opt1 = 0; opt1 < inputDesc.size[2]; ++opt1) {
+      for (size_t opt2 = 0; opt2 < inputDesc.size[3]; ++opt2) {
+        for (size_t outputFeature = 0;
+             outputFeature < outputDesc.size[1]; ++outputFeature) {
+
+          // Load output (f(x_is)). It is possible that this is zero, in
+          // which case we'll ignore this point.
+          real outputV =
+            outputP[
+              flpGetOffset(&outputDesc, batch, outputFeature, opt1, opt2)];
+
+          if (outputV == (real) 0) {
+            continue;
+          }
+
+          for (size_t i = 0; i < width; ++i) {
+            size_t inputFeature = outputFeature * stride + i;
+            THAssert(inputFeature < inputDesc.size[1]);
+
+            real gradOutputV =
+              gradOutputP[
+                flpGetOffset(&gradOutputDesc, batch, outputFeature, opt1, opt2)];
+            real inputV =
+              inputP[
+                flpGetOffset(&inputDesc, batch, inputFeature, opt1, opt2)];
+
+            // Calculate grad * (x_i / f(x_is))^(p - 1)
+            real v = gradOutputV * pow(inputV / outputV, power - (accreal) 1);
+
+            gradInputP[
+              flpGetOffset(&gradInputDesc, batch, inputFeature, opt1, opt2)]
+              += v;
+          }
+        }
+      }
+    }
+  }
+}
+
+#endif
diff --git a/lib/THNN/generic/THNN.h b/lib/THNN/generic/THNN.h
index 31b0795..ad4ea51 100644
--- a/lib/THNN/generic/THNN.h
+++ b/lib/THNN/generic/THNN.h
@@ -1461,6 +1461,26 @@ TH_API void THNN_(SpatialReplicationPadding_updateGradInput)(
           int pad_l, int pad_r,
           int pad_t, int pad_b);
 
+TH_API void THNN_(FeatureLPPooling_updateOutput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *output,
+          accreal power,
+          int width,
+          int stride,
+          bool batchMode);
+
+TH_API void THNN_(FeatureLPPooling_updateGradInput)(
+          THNNState *state,
+          THTensor* gradOutput,
+          THTensor* input,
+          THTensor* output,
+          THTensor* gradInput,
+          accreal power,
+          int width,
+          int stride,
+          bool batchMode);
+
 TH_API void THNN_(VolumetricReplicationPadding_updateOutput)(
           THNNState *state,
           THTensor *input,
diff --git a/lib/THNN/init.c b/lib/THNN/init.c
index 12f3222..acb88c0 100644
--- a/lib/THNN/init.c
+++ b/lib/THNN/init.c
@@ -179,6 +179,9 @@
 #include "generic/TemporalRowConvolution.c"
 #include "THGenerateFloatTypes.h"
 
+#include "generic/FeatureLPPooling.c"
+#include "THGenerateFloatTypes.h"
+
 #include "generic/BatchNormalization.c"
 #include "THGenerateFloatTypes.h"
@@ -280,4 +283,3 @@
 #include "generic/VolumetricUpSamplingTrilinear.c"
 #include "THGenerateFloatTypes.h"
-
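The include-under-THGenerateFloatTypes.h pattern above is how THNN stamps out one copy of the generic file per floating-point type: `real`/`accreal` are bound to `float`/`double` on the first pass and `double`/`double` on the second, and `THNN_(FeatureLPPooling_updateOutput)` expands to `THNN_FloatFeatureLPPooling_updateOutput` and `THNN_DoubleFeatureLPPooling_updateOutput`, which the Lua wrapper reaches through `input.THNN`.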
diff --git a/test.lua b/test.lua
--- a/test.lua
+++ b/test.lua
@@ -8788,6 +8788,318 @@ function nntest.Kmeans()
    end
 end
 
+function nntest.FeatureLPPooling()
+   local verbose = false
+
+   local num_tries = 2
+   local jacobian = nn.Jacobian
+   local precision = 4e-3
+
+   local batch_max = 3
+   local feature_max = 100
+   local dim1_max = 3
+   local dim2_max = 3
+
+   local function pickPow()
+      local num = torch.random(4)
+      if num == 1 then
+         return 1
+      else
+         return (num - 1) * 2.0
+      end
+   end
+
+   local function runFPropTest(dims, width, stride, pow, batch_mode)
+      local pool = nn.FeatureLPPooling(width, stride, pow, batch_mode):float()
+
+      local num_batch = torch.random(batch_max)
+      local num_features = (torch.random(feature_max) - 1) * stride + width
+      local num_dim1 = torch.random(dim1_max)
+      local num_dim2 = torch.random(dim2_max)
+
+      if verbose then
+         print('test on dim ' .. dims ..
+                  ' features ' .. num_features ..
+                  ' width ' .. width .. ' stride ' .. stride ..
+                  ' p ' .. pow .. ' bm ' .. (batch_mode and 1 or 0))
+      end
+
+      local input = nil
+      if dims == 1 then
+         if batch_mode then
+            input = torch.FloatTensor(num_batch, num_features)
+
+            for i = 1, num_batch do
+               for f = 1, num_features do
+                  input[i][f] = f - 1
+               end
+            end
+         else
+            input = torch.FloatTensor(num_features)
+
+            for f = 1, num_features do
+               input[f] = f - 1
+            end
+         end
+      elseif dims == 2 then
+         if batch_mode then
+            input = torch.FloatTensor(num_batch, num_features, num_dim1)
+
+            for i = 1, num_batch do
+               for f = 1, num_features do
+                  for j = 1, num_dim1 do
+                     input[i][f][j] = f - 1
+                  end
+               end
+            end
+         else
+            input = torch.FloatTensor(num_features, num_dim1)
+
+            for f = 1, num_features do
+               for j = 1, num_dim1 do
+                  input[f][j] = f - 1
+               end
+            end
+         end
+      elseif dims == 3 then
+         if batch_mode then
+            input = torch.FloatTensor(num_batch, num_features,
+                                      num_dim1, num_dim2)
+
+            for i = 1, num_batch do
+               for f = 1, num_features do
+                  for j = 1, num_dim1 do
+                     for k = 1, num_dim2 do
+                        input[i][f][j][k] = f - 1
+                     end
+                  end
+               end
+            end
+         else
+            input = torch.FloatTensor(num_features, num_dim1, num_dim2)
+
+            for f = 1, num_features do
+               for j = 1, num_dim1 do
+                  for k = 1, num_dim2 do
+                     input[f][j][k] = f - 1
+                  end
+               end
+            end
+         end
+      end
+
+      local output = pool:forward(input)
+
+      -- Each output feature o(k) (k zero based) for L1 is:
+      -- sum(i((k - 1) * s), i((k - 1) * s + 1), ..., i((k - 1) * s + w - 1))
+      -- if i(x) = x, then: o(k) = w * (k - 1) * s + w * (w - 1) / 2
+      -- For Lp (p != 1), just evaluate ourselves and compare
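-- (Worked example of the closed form above, with made-up numbers: w = 3,
-- s = 2, p = 1, k = 2 gives o(2) = 3 * (2 - 1) * 2 + 3 * (3 - 1) / 2 = 9,
-- matching i(2) + i(3) + i(4) = 2 + 3 + 4, since i(x) = x here.)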
+      local function verifyFeature(val, k, width, stride, pow)
+         local sum_input = 0
+         if pow == 1 then
+            sum_input = width * (k - 1) * stride + width * (width - 1) / 2
+         else
+            for w = 0, width - 1 do
+               sum_input = sum_input + math.pow((k - 1) * stride + w, pow)
+            end
+            sum_input = math.pow(sum_input, 1 / pow)
+         end
+
+         local diff = math.abs(val - sum_input)
+         if (diff >= 1e-3) then
+            if verbose then
+               print('failed on ' .. val .. ' ' .. sum_input)
+            end
+            mytester:assertlt(math.abs(val - sum_input), 1e-3)
+         end
+      end
+
+      if dims == 1 then
+         if batch_mode then
+            for i = 1, output:size(1) do
+               for f = 1, output:size(2) do
+                  verifyFeature(output[i][f], f, width, stride, pow)
+               end
+            end
+         else
+            for f = 1, output:size(1) do
+               verifyFeature(output[f], f, width, stride, pow)
+            end
+         end
+      elseif dims == 2 then
+         if batch_mode then
+            for i = 1, output:size(1) do
+               for f = 1, output:size(2) do
+                  for j = 1, output:size(3) do
+                     verifyFeature(output[i][f][j], f, width, stride, pow)
+                  end
+               end
+            end
+         else
+            for f = 1, output:size(1) do
+               for j = 1, output:size(2) do
+                  verifyFeature(output[f][j], f, width, stride, pow)
+               end
+            end
+         end
+      elseif dims == 3 then
+         if batch_mode then
+            for i = 1, output:size(1) do
+               for f = 1, output:size(2) do
+                  for j = 1, output:size(3) do
+                     for k = 1, output:size(4) do
+                        verifyFeature(output[i][f][j][k], f, width, stride, pow)
+                     end
+                  end
+               end
+            end
+         else
+            for f = 1, output:size(1) do
+               for j = 1, output:size(2) do
+                  for k = 1, output:size(3) do
+                     verifyFeature(output[f][j][k], f, width, stride, pow)
+                  end
+               end
+            end
+         end
+      end
+   end
+
+   local function runBPropTest(dims, width, stride, pow, batch_mode)
+      local pool = nn.FeatureLPPooling(width, stride, pow, batch_mode):float()
+
+      local num_batch = torch.random(batch_max)
+      local num_features = (torch.random(feature_max) - 1) * stride + width
+      local num_dim1 = torch.random(dim1_max)
+      local num_dim2 = torch.random(dim2_max)
+
+      local input = nil
+      if dims == 1 then
+         if batch_mode then
+            input = torch.FloatTensor(num_batch, num_features)
+         else
+            input = torch.FloatTensor(num_features)
+         end
+      elseif dims == 2 then
+         if batch_mode then
+            input = torch.FloatTensor(num_batch, num_features, num_dim1)
+         else
+            input = torch.FloatTensor(num_features, num_dim1)
+         end
+      elseif dims == 3 then
+         if batch_mode then
+            input = torch.FloatTensor(num_batch, num_features,
+                                      num_dim1, num_dim2)
+         else
+            input = torch.FloatTensor(num_features, num_dim1, num_dim2)
+         end
+      end
+
+      local err = jacobian.testJacobian(pool, input, -2, -2, 5e-4)
+      if verbose then
+         print('test on dim ' .. dims ..
+                  ' features ' .. num_features ..
+                  ' width ' .. width .. ' stride ' .. stride ..
+                  ' p ' .. pow .. ' err ' .. err)
+      end
+      mytester:assertlt(err, precision)
+   end
+
+   function testForwardLp()
+      for i = 1, num_tries do
+         for stride = 1, 4 do
+            for idx, batch_mode in ipairs({true, false}) do
+               for dims = 1, 3 do
+                  runFPropTest(dims, 1 + torch.random(15),
+                               stride, pickPow(), batch_mode)
+               end
+            end
+         end
+      end
+   end
+
+   function testZeroBProp()
+      local pool = nn.FeatureLPPooling(3, 1, 2.0, false):float()
+
+      local input = torch.FloatTensor(100):zero()
+      pool:forward(input)
+
+      local gradOutput = torch.FloatTensor(98):zero()
+      local gradInput = pool:backward(input, gradOutput, 1.0)
+
+      for i = 1, gradInput:size(1) do
+         mytester:asserteq(gradInput[i], 0)
+      end
+   end
+
+   function testJacobian1dNoBatch()
+      for i = 1, num_tries do
+         for stride = 1, 4 do
+            runBPropTest(1, 1 + torch.random(15), stride, pickPow(), false)
+         end
+      end
+   end
+
+   function testJacobian1dBatch()
+      for i = 1, num_tries do
+         for stride = 1, 4 do
+            runBPropTest(1, 1 + torch.random(15), stride, pickPow(), true)
+         end
+      end
+   end
+
+   function testJacobian2dNoBatch()
+      for i = 1, num_tries do
+         for stride = 1, 4 do
+            runBPropTest(2, 1 + torch.random(15), stride, pickPow(), false)
+         end
+      end
+   end
+
+   function testJacobian2dBatch()
+      for i = 1, num_tries do
+         for stride = 1, 4 do
+            runBPropTest(2, 1 + torch.random(15), stride, pickPow(), true)
+         end
+      end
+   end
+
+   function testJacobian3dNoBatch()
+      for i = 1, num_tries do
+         for stride = 1, 4 do
+            runBPropTest(3, 1 + torch.random(15), stride, pickPow(), false)
+         end
+      end
+   end
+
+   function testJacobian3dBatch()
+      for i = 1, num_tries do
+         for stride = 1, 4 do
+            runBPropTest(3, 1 + torch.random(15), stride, pickPow(), true)
+         end
+      end
+   end
+
+   testForwardLp()
+   testZeroBProp()
+   testJacobian1dNoBatch()
+   testJacobian1dBatch()
+   testJacobian2dNoBatch()
+   testJacobian2dBatch()
+   testJacobian3dNoBatch()
+   testJacobian3dBatch()
+end
+
 mytester:add(nntest)
 
 jac = nn.Jacobian
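To exercise just this suite once the commit is built (assuming the stock torch/nn test harness, where nn.test takes a list of nntest function names):

    require 'nn'
    nn.test{'FeatureLPPooling'}  -- runs only nntest.FeatureLPPooling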