Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-12-01 01:37:40 +0300
committerSoumith Chintala <soumith@gmail.com>2014-12-01 01:40:01 +0300
commitf58f3453c968d842efc1c4cadb019b4d8fe3e655 (patch)
tree31bb2c2b09589269726a767b135cea5f0fa1041d /Pooling.lua
parent875067f4eb6f5c8eab77ee1acd030fd5e2225fc5 (diff)
ceil mode for Pooling (for compatibility with ccn2 and caffe)
Diffstat (limited to 'Pooling.lua')
-rw-r--r--Pooling.lua21
1 files changed, 19 insertions, 2 deletions
diff --git a/Pooling.lua b/Pooling.lua
index 97eb617..0ab7cb4 100644
--- a/Pooling.lua
+++ b/Pooling.lua
@@ -9,6 +9,17 @@ function Pooling:__init(kW, kH, dW, dH)
self.dW = dW or kW
self.dH = dH or kW
self.iSize = torch.LongStorage(4):fill(0)
+ self.ceil_mode = false
+end
+
+function Pooling:ceil()
+ self.ceil_mode = true
+ return self
+end
+
+function Pooling:floor()
+ self.ceil_mode = false
+ return self
end
function Pooling:resetPoolDescriptors()
@@ -38,8 +49,14 @@ function Pooling:createIODescriptors(input)
-- resize gradInput
self.gradInput:resizeAs(input)
-- resize output
- local oW = math.floor((input:size(4) - self.kW)/self.dW + 1)
- local oH = math.floor((input:size(3) - self.kH)/self.dH + 1)
+ local oW, oH
+ if self.ceil_mode then
+ oW = math.ceil((input:size(4) - self.kW)/self.dW + 1)
+ oH = math.ceil((input:size(3) - self.kH)/self.dH + 1)
+ else
+ oW = math.floor((input:size(4) - self.kW)/self.dW + 1)
+ oH = math.floor((input:size(3) - self.kH)/self.dH + 1)
+ end
self.output:resize(input:size(1), input:size(2), oH, oW)
-- create input/output descriptor