Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael 'myrhev' Mathieu <michael.mathieu@ens.fr>2012-04-27 00:32:15 +0400
committerMichael 'myrhev' Mathieu <michael.mathieu@ens.fr>2012-04-27 00:50:18 +0400
commit300bd31b0ec5bc108ce78b4f3730f8c67fc5a8bc (patch)
tree1e740a1d333966c20315ced3ccf25bca1d1cbacb /SpatialUpSampling.lua
parentc551b07dabd2beebdfe7e534e2a8dd247077c61d (diff)
Add SpatialReSamplingEx (grouping all the Spatial*Sampling).
SpatialPadding can now pad on dimensions different than (2,3). Add Tic/Toc modules, to time a network
Diffstat (limited to 'SpatialUpSampling.lua')
-rw-r--r--SpatialUpSampling.lua37
1 files changed, 32 insertions, 5 deletions
diff --git a/SpatialUpSampling.lua b/SpatialUpSampling.lua
index 43b9de6..779e979 100644
--- a/SpatialUpSampling.lua
+++ b/SpatialUpSampling.lua
@@ -22,17 +22,44 @@ function SpatialUpSampling:__init(...)
-- get args
xlua.unpack_class(self, {...}, 'nn.SpatialUpSampling', help_desc,
{arg='dW', type='number', help='stride width', req=true},
- {arg='dH', type='number', help='stride height', req=true})
+ {arg='dH', type='number', help='stride height', req=true},
+ {arg='yDim', type='number', help='image y dimension', default=2},
+ {arg='xDim', type='number', help='image x dimension', default=3}
+ )
+ if self.yDim+1 ~= self.xDim then
+ error('nn.SpatialUpSampling: yDim must be equals to xDim-1')
+ end
+ self.outputSize = torch.LongStorage(4)
+ self.inputSize = torch.LongStorage(4)
end
function SpatialUpSampling:updateOutput(input)
- self.output:resize(input:size(1), input:size(2) * self.dH, input:size(3) * self.dW)
- input.nn.SpatialUpSampling_updateOutput(self, input)
+ self.inputSize:fill(1)
+ for i = 1,self.yDim-1 do
+ self.inputSize[1] = self.inputSize[1] * input:size(i)
+ end
+ self.inputSize[2] = input:size(self.yDim)
+ self.inputSize[3] = input:size(self.xDim)
+ for i = self.xDim+1,input:nDimension() do
+ self.inputSize[4] = self.inputSize[4] * input:size(i)
+ end
+ self.outputSize[1] = self.inputSize[1]
+ self.outputSize[2] = self.inputSize[2] * self.dH
+ self.outputSize[3] = self.inputSize[3] * self.dW
+ self.outputSize[4] = self.inputSize[4]
+ self.output:resize(self.outputSize)
+ input.nn.SpatialUpSampling_updateOutput(self, input:reshape(self.inputSize))
+ local outputSize2 = input:size()
+ outputSize2[self.yDim] = outputSize2[self.yDim] * self.dH
+ outputSize2[self.xDim] = outputSize2[self.xDim] * self.dW
+ self.output = self.output:reshape(outputSize2)
return self.output
end
function SpatialUpSampling:updateGradInput(input, gradOutput)
- self.gradInput:resizeAs(input)
- input.nn.SpatialUpSampling_updateGradInput(self, input, gradOutput)
+ self.gradInput:resize(self.inputSize)
+ input.nn.SpatialUpSampling_updateGradInput(self, input,
+ gradOutput:reshape(self.outputSize))
+ self.gradInput = self.gradInput:reshape(input:size())
return self.gradInput
end