Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2014-10-08 19:22:25 +0400
committerNicholas Leonard <nick@nikopia.org>2014-10-08 19:22:25 +0400
commit8a4e5ec9f9b63d8fb76fce1c50f1c439a15ddbd6 (patch)
treec0641ac8197bdc441b9c390fffbea0d89c0322f9 /SpatialReSampling.lua
parent845196f88f4316ee391504b16693fa0eb27c2a36 (diff)
SpatialReSampling rwidth/rheight batch fixes
Diffstat (limited to 'SpatialReSampling.lua')
-rw-r--r--SpatialReSampling.lua10
1 files changed, 7 insertions, 3 deletions
diff --git a/SpatialReSampling.lua b/SpatialReSampling.lua
index b738eab..7324098 100644
--- a/SpatialReSampling.lua
+++ b/SpatialReSampling.lua
@@ -3,7 +3,7 @@ local SpatialReSampling, parent = torch.class('nn.SpatialReSampling', 'nn.Module
local help_desc =
[[Applies a 2D re-sampling over an input image composed of
several input planes. The input tensor in forward(input) is
-expected to be a 3D tensor (width x height x nInputPlane).
+expected to be a 3D or 4D tensor ([batchSize x nInputPlane x width x height).
The number of output planes will be the same as the nb of input
planes.
@@ -31,8 +31,12 @@ function SpatialReSampling:__init(...)
end
function SpatialReSampling:updateOutput(input)
- self.oheight = self.oheight or self.rheight*input:size(2)
- self.owidth = self.owidth or self.rwidth*input:size(3)
+ local hDim, wDim = 2, 3
+ if input:dim() == 4 then
+ hDim, wDim = 3, 4
+ end
+ self.oheight = self.oheight or self.rheight*input:size(hDim)
+ self.owidth = self.owidth or self.rwidth*input:size(wDim)
input.nn.SpatialReSampling_updateOutput(self, input)
return self.output
end