Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2015-10-24 02:32:49 +0300
committersoumith <soumith@gmail.com>2015-10-24 02:49:39 +0300
commit12438b9a51fd9a2a217384d45fff1eb55b12f725 (patch)
treedb1468f18a001f2261d0dde74350c766fc3d7270 /SpatialFractionalMaxPooling.lua
parent96e3c3070446d77c413e58f0613f71f599b6594e (diff)
Adding Fractional Max Pooling
Diffstat (limited to 'SpatialFractionalMaxPooling.lua')
-rw-r--r--SpatialFractionalMaxPooling.lua156
1 files changed, 156 insertions, 0 deletions
diff --git a/SpatialFractionalMaxPooling.lua b/SpatialFractionalMaxPooling.lua
new file mode 100644
index 0000000..f0bfc73
--- /dev/null
+++ b/SpatialFractionalMaxPooling.lua
@@ -0,0 +1,156 @@
+local SpatialFractionalMaxPooling, parent =
+ torch.class('nn.SpatialFractionalMaxPooling', 'nn.Module')
+
+-- Usage:
+-- nn.SpatialFractionalMaxPooling(poolSizeW, poolSizeH, outW, outH)
+-- the output should be the exact size (outH x outW)
+-- nn.SpatialFractionalMaxPooling(poolSizeW, poolSizeH, ratioW, ratioH)
+-- the output should be the size (floor(inH x ratioH) x floor(inW x ratioW))
+-- ratios are numbers between (0, 1) exclusive
+function SpatialFractionalMaxPooling:__init(poolSizeW, poolSizeH, arg1, arg2)
+ parent.__init(self)
+ assert(poolSizeW >= 2)
+ assert(poolSizeH >= 2)
+
+ -- Pool size (how wide the pooling for each output unit is)
+ self.poolSizeW = poolSizeW
+ self.poolSizeH = poolSizeH
+ self.indices = torch.Tensor()
+
+ -- Random samples are drawn for all
+ -- batch * plane * (height, width; i.e., 2) points. This determines
+ -- the 2d "pseudorandom" overlapping pooling regions for each
+ -- (batch element x input plane). A new set of random samples is
+ -- drawn every updateOutput call, unless we disable it via
+ -- :fixPoolingRegions().
+ self.randomSamples = nil
+
+ -- Flag to disable re-generation of random samples for producing
+ -- a new pooling. For testing purposes
+ self.newRandomPool = false
+
+ if arg1 >= 1 and arg2 >= 1 then
+ -- Desired output size: the input tensor will determine the reduction
+ -- ratio
+ self.outW = arg1
+ self.outH = arg2
+ else
+ -- Reduction ratio specified per each input
+ -- This is the reduction ratio that we use
+ self.ratioW = arg1
+ self.ratioH = arg2
+
+ -- The reduction ratio must be between 0 and 1
+ assert(self.ratioW > 0 and self.ratioW < 1)
+ assert(self.ratioH > 0 and self.ratioH < 1)
+ end
+end
+
+function SpatialFractionalMaxPooling:getBufferSize_(input)
+ local batchSize = 0
+ local planeSize = 0
+
+ if input:nDimension() == 3 then
+ batchSize = 1
+ planeSize = input:size(1)
+ elseif input:nDimension() == 4 then
+ batchSize = input:size(1)
+ planeSize = input:size(2)
+ else
+ error('input must be dim 3 or 4')
+ end
+
+ return torch.LongStorage({batchSize, planeSize, 2})
+end
+
+function SpatialFractionalMaxPooling:initSampleBuffer_(input)
+ local sampleBufferSize = self:getBufferSize_(input)
+
+ if self.randomSamples == nil then
+ self.randomSamples = input.new():resize(sampleBufferSize):uniform()
+ elseif (self.randomSamples:size(1) ~= sampleBufferSize[1] or
+ self.randomSamples:size(2) ~= sampleBufferSize[2]) then
+ self.randomSamples:resize(sampleBufferSize):uniform()
+ else
+ if not self.newRandomPool then
+ -- Create new pooling windows, since this is a subsequent call
+ self.randomSamples:uniform()
+ end
+ end
+end
+
+function SpatialFractionalMaxPooling:getOutputSizes_(input)
+ local outW = self.outW
+ local outH = self.outH
+ if self.ratioW ~= nil and self.ratioH ~= nil then
+ if input:nDimension() == 4 then
+ outW = math.floor(input:size(4) * self.ratioW)
+ outH = math.floor(input:size(3) * self.ratioH)
+ elseif input:nDimension() == 3 then
+ outW = math.floor(input:size(3) * self.ratioW)
+ outH = math.floor(input:size(2) * self.ratioH)
+ else
+ error('input must be dim 3 or 4')
+ end
+
+ -- Neither can be smaller than 1
+ assert(outW > 0, 'reduction ratio or input width too small')
+ assert(outH > 0, 'reduction ratio or input height too small')
+ else
+ assert(outW ~= nil and outH ~= nil)
+ end
+
+ return outW, outH
+end
+
+-- Call this to turn off regeneration of random pooling regions each
+-- updateOutput call.
+function SpatialFractionalMaxPooling:fixPoolingRegions(val)
+ if val == nil then
+ val = true
+ end
+
+ self.newRandomPool = val
+ return self
+end
+
+function SpatialFractionalMaxPooling:updateOutput(input)
+ self:initSampleBuffer_(input)
+ local outW, outH = self:getOutputSizes_(input)
+
+ input.nn.SpatialFractionalMaxPooling_updateOutput(
+ self.output, input,
+ outW, outH, self.poolSizeW, self.poolSizeH,
+ self.indices, self.randomSamples)
+ return self.output
+end
+
+function SpatialFractionalMaxPooling:updateGradInput(input, gradOutput)
+ assert(self.randomSamples ~= nil,
+ 'must call updateOutput/forward first')
+
+ local outW, outH = self:getOutputSizes_(input)
+
+ input.nn.SpatialFractionalMaxPooling_updateGradInput(
+ self.gradInput, input, gradOutput,
+ outW, outH, self.poolSizeW, self.poolSizeH,
+ self.indices)
+ return self.gradInput
+end
+
+function SpatialFractionalMaxPooling:empty()
+ self.gradInput:resize()
+ self.gradInput:storage():resize(0)
+ self.output:resize()
+ self.output:storage():resize(0)
+ self.indices:resize()
+ self.indices:storage():resize(0)
+ self.randomSamples = nil
+end
+
+function SpatialFractionalMaxPooling:__tostring__()
+ return string.format('%s(%d,%d,%d,%d)', torch.type(self),
+ self.outW and self.outW or self.ratioW,
+ self.outH and self.outH or self.ratioH,
+ self.poolSizeW, self.poolSizeH)
+end