diff options
-rw-r--r-- | Pooling.lua | 21 |
1 files changed, 19 insertions, 2 deletions
diff --git a/Pooling.lua b/Pooling.lua index 97eb617..0ab7cb4 100644 --- a/Pooling.lua +++ b/Pooling.lua @@ -9,6 +9,17 @@ function Pooling:__init(kW, kH, dW, dH) self.dW = dW or kW self.dH = dH or kW self.iSize = torch.LongStorage(4):fill(0) + self.ceil_mode = false +end + +function Pooling:ceil() + self.ceil_mode = true + return self +end + +function Pooling:floor() + self.ceil_mode = false + return self end function Pooling:resetPoolDescriptors() @@ -38,8 +49,14 @@ function Pooling:createIODescriptors(input) -- resize gradInput self.gradInput:resizeAs(input) -- resize output - local oW = math.floor((input:size(4) - self.kW)/self.dW + 1) - local oH = math.floor((input:size(3) - self.kH)/self.dH + 1) + local oW, oH + if self.ceil_mode then + oW = math.ceil((input:size(4) - self.kW)/self.dW + 1) + oH = math.ceil((input:size(3) - self.kH)/self.dH + 1) + else + oW = math.floor((input:size(4) - self.kW)/self.dW + 1) + oH = math.floor((input:size(3) - self.kH)/self.dH + 1) + end self.output:resize(input:size(1), input:size(2), oH, oW) -- create input/output descriptor |