diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-12-01 01:37:40 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-12-01 01:40:01 +0300 |
commit | f58f3453c968d842efc1c4cadb019b4d8fe3e655 (patch) | |
tree | 31bb2c2b09589269726a767b135cea5f0fa1041d /Pooling.lua | |
parent | 875067f4eb6f5c8eab77ee1acd030fd5e2225fc5 (diff) |
ceil mode for Pooling (for compatibility with ccn2 and caffe)
Diffstat (limited to 'Pooling.lua')
-rw-r--r-- | Pooling.lua | 21 |
1 files changed, 19 insertions, 2 deletions
diff --git a/Pooling.lua b/Pooling.lua index 97eb617..0ab7cb4 100644 --- a/Pooling.lua +++ b/Pooling.lua @@ -9,6 +9,17 @@ function Pooling:__init(kW, kH, dW, dH) self.dW = dW or kW self.dH = dH or kW self.iSize = torch.LongStorage(4):fill(0) + self.ceil_mode = false +end + +function Pooling:ceil() + self.ceil_mode = true + return self +end + +function Pooling:floor() + self.ceil_mode = false + return self end function Pooling:resetPoolDescriptors() @@ -38,8 +49,14 @@ function Pooling:createIODescriptors(input) -- resize gradInput self.gradInput:resizeAs(input) -- resize output - local oW = math.floor((input:size(4) - self.kW)/self.dW + 1) - local oH = math.floor((input:size(3) - self.kH)/self.dH + 1) + local oW, oH + if self.ceil_mode then + oW = math.ceil((input:size(4) - self.kW)/self.dW + 1) + oH = math.ceil((input:size(3) - self.kH)/self.dH + 1) + else + oW = math.floor((input:size(4) - self.kW)/self.dW + 1) + oH = math.floor((input:size(3) - self.kH)/self.dH + 1) + end self.output:resize(input:size(1), input:size(2), oH, oW) -- create input/output descriptor |