diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-07-08 08:43:27 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-07-08 08:43:27 +0400 |
commit | 3adbcafc2d64524ad57431102bcb1e082cc450be (patch) | |
tree | 3b3703ad2bb90c113899865839e2a849094fd0e7 /SpatialReSampling.lua | |
parent | 6ec50929b2964ae9423feb0f5434b0f9dfdce610 (diff) |
added two generic upsamplers/resamplers modules
Diffstat (limited to 'SpatialReSampling.lua')
-rw-r--r-- | SpatialReSampling.lua | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/SpatialReSampling.lua b/SpatialReSampling.lua new file mode 100644 index 0000000..20c9985 --- /dev/null +++ b/SpatialReSampling.lua @@ -0,0 +1,55 @@ +local SpatialReSampling, parent = torch.class('nn.SpatialReSampling', 'nn.Module') + +local help_desc = +[[Applies a 2D re-sampling over an input image composed of +several input planes. The input tensor in forward(input) is +expected to be a 3D tensor (width x height x nInputPlane). +The number of output planes will be the same as the nb of input +planes. + +The re-sampling is done using bilinear interpolation. For a +simple nearest-neihbor upsampling, use nn.SpatialUpSampling(), +and for a simple average-based down-sampling, use +nn.SpatialDownSampling(). + +If the input image is a 3D tensor nInputPlane x height x width, +the output image size will be nInputPlane x oheight x owidth where +owidth and oheight are given to the constructor. + +Instead of owidth & oheight, one can provide rwidth & rheight, +such that owidth = iwidth*rwidth & oheight = iheight*rheight. ]] + +function SpatialReSampling:__init(...) + parent.__init(self) + xlua.unpack_class( + self, {...}, 'nn.SpatialReSampling', help_desc, + {arg='owidth', type='number', help='output width'}, + {arg='oheight', type='number', help='output height'}, + {arg='rwidth', type='number', help='ratio: owidth/iwidth'}, + {arg='rheight', type='number', help='ratio: oheight/iheight'} + ) +end + +function SpatialReSampling:forward(input) + self.oheight = self.oheight or self.rheight*input:size(2) + self.owidth = self.owidth or self.rwidth*input:size(3) + input.nn.SpatialReSampling_forward(self, input) + return self.output +end + +function SpatialReSampling:backward(input, gradOutput) + input.nn.SpatialReSampling_backward(self, input, gradOutput) + return self.gradInput +end + +function SpatialReSampling:write(file) + parent.write(self, file) + file:writeInt(self.owidth) + file:writeInt(self.oheight) +end + +function SpatialReSampling:read(file) + parent.read(self, file) + self.owidth = file:readInt() + self.oheight = file:readInt() +end |