diff options
author | Nicholas Leonard <nick@nikopia.org> | 2014-10-08 19:22:25 +0400 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2014-10-08 19:22:25 +0400 |
commit | 8a4e5ec9f9b63d8fb76fce1c50f1c439a15ddbd6 (patch) | |
tree | c0641ac8197bdc441b9c390fffbea0d89c0322f9 /SpatialReSampling.lua | |
parent | 845196f88f4316ee391504b16693fa0eb27c2a36 (diff) |
SpatialReSampling rwidth/rheight batch fixes
Diffstat (limited to 'SpatialReSampling.lua')
-rw-r--r-- | SpatialReSampling.lua | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/SpatialReSampling.lua b/SpatialReSampling.lua index b738eab..7324098 100644 --- a/SpatialReSampling.lua +++ b/SpatialReSampling.lua @@ -3,7 +3,7 @@ local SpatialReSampling, parent = torch.class('nn.SpatialReSampling', 'nn.Module local help_desc = [[Applies a 2D re-sampling over an input image composed of several input planes. The input tensor in forward(input) is -expected to be a 3D tensor (width x height x nInputPlane). +expected to be a 3D or 4D tensor ([batchSize x nInputPlane x width x height). The number of output planes will be the same as the nb of input planes. @@ -31,8 +31,12 @@ function SpatialReSampling:__init(...) end function SpatialReSampling:updateOutput(input) - self.oheight = self.oheight or self.rheight*input:size(2) - self.owidth = self.owidth or self.rwidth*input:size(3) + local hDim, wDim = 2, 3 + if input:dim() == 4 then + hDim, wDim = 3, 4 + end + self.oheight = self.oheight or self.rheight*input:size(hDim) + self.owidth = self.owidth or self.rwidth*input:size(wDim) input.nn.SpatialReSampling_updateOutput(self, input) return self.output end |