Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrancisco Massa <fvsmassa@gmail.com>2015-06-18 02:17:11 +0300
committerFrancisco Massa <fvsmassa@gmail.com>2015-06-18 02:26:41 +0300
commit793b6bfef8f4fa51e7ddee9cfe34d0732de067a2 (patch)
tree1efc990a7f41e99b20a5363943928546083d82e4 /SpatialConvolution.lua
parent25d46e69823298c9584858f5a7facf5fc0b17e90 (diff)
SpatialConvolutionMM supports padW/padH != 1
Fix indentation
Diffstat (limited to 'SpatialConvolution.lua')
-rw-r--r--SpatialConvolution.lua18
1 files changed, 13 insertions, 5 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua
index 3507030..84412b0 100644
--- a/SpatialConvolution.lua
+++ b/SpatialConvolution.lua
@@ -1,6 +1,6 @@
local SpatialConvolution, parent = torch.class('nn.SpatialConvolution', 'nn.Module')
-function SpatialConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, padding)
+function SpatialConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
parent.__init(self)
dW = dW or 1
@@ -13,7 +13,8 @@ function SpatialConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, pa
self.dW = dW
self.dH = dH
- self.padding = padding or 0
+ self.padW = padW or 0
+ self.padH = padH or self.padW
self.weight = torch.Tensor(nOutputPlane, nInputPlane, kH, kW)
self.bias = torch.Tensor(nOutputPlane)
@@ -45,7 +46,14 @@ end
local function backCompatibility(self)
self.finput = self.finput or self.weight.new()
self.fgradInput = self.fgradInput or self.weight.new()
- self.padding = self.padding or 0
+ if self.padding then
+ self.padW = self.padding
+ self.padH = self.padding
+ self.padding = nil
+ else
+ self.padW = self.padW or 0
+ self.padH = self.padH or 0
+ end
if self.weight:dim() == 2 then
self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW)
end
@@ -128,8 +136,8 @@ function SpatialConvolution:__tostring__()
end
if self.padding and self.padding ~= 0 then
s = s .. ', ' .. self.padding .. ',' .. self.padding
- elseif self.pad_w or self.pad_h then
- s = s .. ', ' .. self.pad_w .. ',' .. self.pad_h
+ elseif (self.padW or self.padH) and (self.padW ~= 0 or self.padH ~= 0) then
+ s = s .. ', ' .. self.padW .. ',' .. self.padW
end
return s .. ')'
end