diff options
author | Michael 'myrhev' Mathieu <michael.mathieu@ens.fr> | 2012-04-27 00:32:15 +0400 |
---|---|---|
committer | Michael 'myrhev' Mathieu <michael.mathieu@ens.fr> | 2012-04-27 00:50:18 +0400 |
commit | 300bd31b0ec5bc108ce78b4f3730f8c67fc5a8bc (patch) | |
tree | 1e740a1d333966c20315ced3ccf25bca1d1cbacb /SpatialReSamplingEx.lua | |
parent | c551b07dabd2beebdfe7e534e2a8dd247077c61d (diff) |
Add SpatialReSamplingEx (grouping all the Spatial*Sampling).
SpatialPadding can now pad on dimensions different than (2,3).
Add Tic/Toc modules, to time a network
Diffstat (limited to 'SpatialReSamplingEx.lua')
-rw-r--r-- | SpatialReSamplingEx.lua | 82 |
1 files changed, 82 insertions, 0 deletions
diff --git a/SpatialReSamplingEx.lua b/SpatialReSamplingEx.lua new file mode 100644 index 0000000..6b9f562 --- /dev/null +++ b/SpatialReSamplingEx.lua @@ -0,0 +1,82 @@ +local SpatialReSamplingEx, parent = torch.class('nn.SpatialReSamplingEx', 'nn.Module') + +local help_desc = [[ + Extended spatial resampling. +]] +function SpatialReSamplingEx:__init(...) + parent.__init(self) + + -- get args + xlua.unpack_class( + self, {...}, 'nn.SpatialReSampling', help_desc, + {arg='rwidth', type='number', help='ratio: owidth/iwidth'}, + {arg='rheight', type='number', help='ratio: oheight/iheight'}, + {arg='owidth', type='number', help='output width'}, + {arg='oheight', type='number', help='output height'}, + {arg='mode', type='string', help='Mode : simple | average (only for downsampling) | bilinear', default = 'simple'}, + {arg='yDim', type='number', help='image y dimension', default=2}, + {arg='xDim', type='number', help='image x dimension', default=3} + ) + if self.yDim+1 ~= self.xDim then + error('nn.SpatialReSamplingEx: yDim must be equals to xDim-1') + end + self.outputSize = torch.LongStorage(4) + self.inputSize = torch.LongStorage(4) + if self.mode == 'simple' then self.mode_c = 0 end + if self.mode == 'average' then self.mode_c = 1 end + if self.mode == 'bilinear' then self.mode_c = 2 end + if not self.mode_c then + error('SpatialReSampling: mode must be simple | average | bilinear') + end +end + +local function round(a) + return math.floor(a+0.5) +end + +function SpatialReSamplingEx:updateOutput(input) + -- compute iheight, iwidth, oheight and owidth + self.iheight = input:size(self.yDim) + self.iwidth = input:size(self.xDim) + self.oheight = self.oheight or round(self.rheight*self.iheight) + self.owidth = self.owidth or round(self.rwidth*self.iwidth) + if not ((self.oheight>=self.iheight) == (self.owidth>=self.iwidth)) then + error('SpatialReSamplingEx: Cannot upsample one dimension while downsampling the other') + end + + -- resize input into K1 x iheight x iwidth x K2 tensor + self.inputSize:fill(1) + for i = 1,self.yDim-1 do + self.inputSize[1] = self.inputSize[1] * input:size(i) + end + self.inputSize[2] = self.iheight + self.inputSize[3] = self.iwidth + for i = self.xDim+1,input:nDimension() do + self.inputSize[4] = self.inputSize[4] * input:size(i) + end + local reshapedInput = input:reshape(self.inputSize) + + -- prepare output of size K1 x oheight x owidth x K2 + self.outputSize[1] = self.inputSize[1] + self.outputSize[2] = self.oheight + self.outputSize[3] = self.owidth + self.outputSize[4] = self.inputSize[4] + self.output:resize(self.outputSize) + + -- resample over dims 2 and 3 + input.nn.SpatialReSamplingEx_updateOutput(self, input:reshape(self.inputSize)) + + --resize output into the same shape as input + local outputSize2 = input:size() + outputSize2[self.yDim] = self.oheight + outputSize2[self.xDim] = self.owidth + self.output = self.output:reshape(outputSize2) + return self.output +end + +function SpatialReSamplingEx:updateGradInput(input, gradOutput) + self.gradInput:resize(self.inputSize) + input.nn.SpatialReSamplingEx_updateGradInput(self, gradOutput:reshape(self.outputSize)) + self.gradInput = self.gradInput:reshape(input:size()) + return self.gradInput +end |