
github.com/clementfarabet/lua---nnx.git
author    Michael 'myrhev' Mathieu <michael.mathieu@ens.fr>  2012-04-27 00:32:15 +0400
committer Michael 'myrhev' Mathieu <michael.mathieu@ens.fr>  2012-04-27 00:50:18 +0400
commit    300bd31b0ec5bc108ce78b4f3730f8c67fc5a8bc (patch)
tree      1e740a1d333966c20315ced3ccf25bca1d1cbacb /SpatialReSamplingEx.lua
parent    c551b07dabd2beebdfe7e534e2a8dd247077c61d (diff)
Add SpatialReSamplingEx (grouping all the Spatial*Sampling modules).
SpatialPadding can now pad on dimensions other than (2,3). Add Tic/Toc modules to time a network.
Diffstat (limited to 'SpatialReSamplingEx.lua')
 -rw-r--r--  SpatialReSamplingEx.lua  |  82 +
 1 file changed, 82 insertions(+), 0 deletions(-)
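For context, here is a minimal usage sketch of the new module. This is not part of the commit; it assumes the nnx package is installed and uses only the constructor arguments defined in the diff below.

  require 'nnx'

  -- Downsample a 3x8x8 image to half resolution; 'average' pools the
  -- source pixels behind each output pixel (downsampling only).
  local resample = nn.SpatialReSamplingEx{rwidth = 0.5, rheight = 0.5, mode = 'average'}
  local output = resample:forward(torch.rand(3, 8, 8))
  print(output:size())  -- 3 x 4 x 4 (default yDim=2, xDim=3)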
diff --git a/SpatialReSamplingEx.lua b/SpatialReSamplingEx.lua
new file mode 100644
index 0000000..6b9f562
--- /dev/null
+++ b/SpatialReSamplingEx.lua
@@ -0,0 +1,82 @@
+local SpatialReSamplingEx, parent = torch.class('nn.SpatialReSamplingEx', 'nn.Module')
+
+local help_desc = [[
+ Extended spatial resampling.
+]]
+function SpatialReSamplingEx:__init(...)
+ parent.__init(self)
+
+ -- get args
+ xlua.unpack_class(
+   self, {...}, 'nn.SpatialReSamplingEx', help_desc,
+ {arg='rwidth', type='number', help='ratio: owidth/iwidth'},
+ {arg='rheight', type='number', help='ratio: oheight/iheight'},
+ {arg='owidth', type='number', help='output width'},
+ {arg='oheight', type='number', help='output height'},
+   {arg='mode', type='string', help='mode: simple | average (only for downsampling) | bilinear', default = 'simple'},
+ {arg='yDim', type='number', help='image y dimension', default=2},
+ {arg='xDim', type='number', help='image x dimension', default=3}
+ )
+ if self.yDim+1 ~= self.xDim then
+    error('nn.SpatialReSamplingEx: yDim must be equal to xDim-1')
+ end
+ self.outputSize = torch.LongStorage(4)
+ self.inputSize = torch.LongStorage(4)
+ if self.mode == 'simple' then self.mode_c = 0 end
+ if self.mode == 'average' then self.mode_c = 1 end
+ if self.mode == 'bilinear' then self.mode_c = 2 end
+ if not self.mode_c then
+    error('nn.SpatialReSamplingEx: mode must be simple | average | bilinear')
+ end
+end
+
+local function round(a)
+ return math.floor(a+0.5)
+end
+
+function SpatialReSamplingEx:updateOutput(input)
+ -- compute iheight, iwidth, oheight and owidth
+ self.iheight = input:size(self.yDim)
+ self.iwidth = input:size(self.xDim)
+ self.oheight = self.oheight or round(self.rheight*self.iheight)
+ self.owidth = self.owidth or round(self.rwidth*self.iwidth)
+ if not ((self.oheight>=self.iheight) == (self.owidth>=self.iwidth)) then
+ error('SpatialReSamplingEx: Cannot upsample one dimension while downsampling the other')
+ end
+
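+  -- The underlying resampling routine works on a fixed 4D layout: all
+  -- dimensions before yDim are flattened into K1 and all dimensions after
+  -- xDim into K2, so an arbitrary input is viewed as K1 x iheight x iwidth x K2.
+  -- E.g. a 4x3x8x8 batch with yDim=3, xDim=4 is viewed as 12 x 8 x 8 x 1.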
+  -- reshape input into a K1 x iheight x iwidth x K2 tensor
+ self.inputSize:fill(1)
+ for i = 1,self.yDim-1 do
+ self.inputSize[1] = self.inputSize[1] * input:size(i)
+ end
+ self.inputSize[2] = self.iheight
+ self.inputSize[3] = self.iwidth
+ for i = self.xDim+1,input:nDimension() do
+ self.inputSize[4] = self.inputSize[4] * input:size(i)
+ end
+ local reshapedInput = input:reshape(self.inputSize)
+
+ -- prepare output of size K1 x oheight x owidth x K2
+ self.outputSize[1] = self.inputSize[1]
+ self.outputSize[2] = self.oheight
+ self.outputSize[3] = self.owidth
+ self.outputSize[4] = self.inputSize[4]
+ self.output:resize(self.outputSize)
+
+ -- resample over dims 2 and 3
+  input.nn.SpatialReSamplingEx_updateOutput(self, reshapedInput)
+
+  -- reshape output back to the same dimension layout as the input
+ local outputSize2 = input:size()
+ outputSize2[self.yDim] = self.oheight
+ outputSize2[self.xDim] = self.owidth
+ self.output = self.output:reshape(outputSize2)
+ return self.output
+end
+
+function SpatialReSamplingEx:updateGradInput(input, gradOutput)
+ self.gradInput:resize(self.inputSize)
+ input.nn.SpatialReSamplingEx_updateGradInput(self, gradOutput:reshape(self.outputSize))
+ self.gradInput = self.gradInput:reshape(input:size())
+ return self.gradInput
+end
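
Again for context (not part of the commit), a sketch of the forward/backward shape round-trip with non-default yDim/xDim, as permitted by the reshape logic above; the sizes are illustrative:

  require 'nnx'

  -- Batch layout (batch, feature, y, x): resample over dims 3 and 4.
  local up = nn.SpatialReSamplingEx{oheight = 16, owidth = 16, mode = 'bilinear', yDim = 3, xDim = 4}
  local input  = torch.rand(4, 3, 8, 8)
  local output = up:forward(input)              -- size 4 x 3 x 16 x 16
  local gradInput = up:backward(input, output)  -- reshaped back to 4 x 3 x 8 x 8
  assert(gradInput:isSameSizeAs(input))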