Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJonathan Tompson <tompson@cims.nyu.edu>2014-06-26 19:32:57 +0400
committerJonathan Tompson <tompson@cims.nyu.edu>2014-06-26 19:32:57 +0400
commit4725c6b639f8dfc5d0440557c65e5dbc6fec1873 (patch)
treecbce1545c19a1c674c95a09ee2d9201f1adfb178 /SpatialUpSamplingNearest.lua
parent1310a045ebc69a9f9e8c57d07af587a6535d5ae9 (diff)
Added SpatialUpSamplingNearest module.
Diffstat (limited to 'SpatialUpSamplingNearest.lua')
-rw-r--r--SpatialUpSamplingNearest.lua58
1 files changed, 58 insertions, 0 deletions
diff --git a/SpatialUpSamplingNearest.lua b/SpatialUpSamplingNearest.lua
new file mode 100644
index 0000000..8288250
--- /dev/null
+++ b/SpatialUpSamplingNearest.lua
@@ -0,0 +1,58 @@
+local SpatialUpSamplingNearest, parent = torch.class('nn.SpatialUpSamplingNearest', 'nn.Module')
+
+--[[
+Applies a 2D up-sampling over an input image composed of several input planes.
+
+The upsampling is done using the simple nearest neighbor technique.
+
+The Y and X dimensions are assumed to be the last 2 tensor dimensions. For
+instance, if the tensor is 4D, then dim 3 is the y dimension and dim 4 is the x.
+
+owidth = width*scale_factor
+oheight = height*scale_factor
+--]]
+
+function SpatialUpSamplingNearest:__init(scale)
+ parent.__init(self)
+
+ self.scale_factor = scale
+ if self.scale_factor < 1 then
+ error('scale_factor must be greater than 1')
+ end
+ if math.floor(self.scale_factor) ~= self.scale_factor then
+ error('scale_factor must be integer')
+ end
+ self.inputSize = torch.LongStorage(4)
+ self.outputSize = torch.LongStorage(4)
+ self.usage = nil
+end
+
+function SpatialUpSamplingNearest:updateOutput(input)
+ if input:dim() ~= 4 and input:dim() ~= 3 then
+ error('SpatialUpSamplingNearest only support 3D or 4D tensors')
+ end
+ -- Copy the input size
+ local xdim = input:dim()
+ local ydim = input:dim() - 1
+ for i = 1, input:dim() do
+ self.inputSize[i] = input:size(i)
+ self.outputSize[i] = input:size(i)
+ end
+ self.outputSize[ydim] = self.outputSize[ydim] * self.scale_factor
+ self.outputSize[xdim] = self.outputSize[xdim] * self.scale_factor
+ -- Resize the output if needed
+ if input:dim() == 3 then
+ self.output:resize(self.outputSize[1], self.outputSize[2],
+ self.outputSize[3])
+ else
+ self.output:resize(self.outputSize)
+ end
+ input.nn.SpatialUpSamplingNearest_updateOutput(self, input)
+ return self.output
+end
+
+function SpatialUpSamplingNearest:updateGradInput(input, gradOutput)
+ self.gradInput:resizeAs(input)
+ input.nn.SpatialUpSamplingNearest_updateGradInput(self, input, gradOutput)
+ return self.gradInput
+end