diff options
author | Francisco Massa <fvsmassa@gmail.com> | 2015-06-18 02:17:11 +0300 |
---|---|---|
committer | Francisco Massa <fvsmassa@gmail.com> | 2015-06-18 02:26:41 +0300 |
commit | 793b6bfef8f4fa51e7ddee9cfe34d0732de067a2 (patch) | |
tree | 1efc990a7f41e99b20a5363943928546083d82e4 /SpatialConvolution.lua | |
parent | 25d46e69823298c9584858f5a7facf5fc0b17e90 (diff) |
SpatialConvolutionMM supports padW/padH != 1
Fix indentation
Diffstat (limited to 'SpatialConvolution.lua')
-rw-r--r-- | SpatialConvolution.lua | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/SpatialConvolution.lua b/SpatialConvolution.lua index 3507030..84412b0 100644 --- a/SpatialConvolution.lua +++ b/SpatialConvolution.lua @@ -1,6 +1,6 @@ local SpatialConvolution, parent = torch.class('nn.SpatialConvolution', 'nn.Module') -function SpatialConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, padding) +function SpatialConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH) parent.__init(self) dW = dW or 1 @@ -13,7 +13,8 @@ function SpatialConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH, pa self.dW = dW self.dH = dH - self.padding = padding or 0 + self.padW = padW or 0 + self.padH = padH or self.padW self.weight = torch.Tensor(nOutputPlane, nInputPlane, kH, kW) self.bias = torch.Tensor(nOutputPlane) @@ -45,7 +46,14 @@ end local function backCompatibility(self) self.finput = self.finput or self.weight.new() self.fgradInput = self.fgradInput or self.weight.new() - self.padding = self.padding or 0 + if self.padding then + self.padW = self.padding + self.padH = self.padding + self.padding = nil + else + self.padW = self.padW or 0 + self.padH = self.padH or 0 + end if self.weight:dim() == 2 then self.weight = self.weight:view(self.nOutputPlane, self.nInputPlane, self.kH, self.kW) end @@ -128,8 +136,8 @@ function SpatialConvolution:__tostring__() end if self.padding and self.padding ~= 0 then s = s .. ', ' .. self.padding .. ',' .. self.padding - elseif self.pad_w or self.pad_h then - s = s .. ', ' .. self.pad_w .. ',' .. self.pad_h + elseif (self.padW or self.padH) and (self.padW ~= 0 or self.padH ~= 0) then + s = s .. ', ' .. self.padW .. ',' .. self.padW end return s .. ')' end |