diff options
author | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2015-03-04 17:39:05 +0300 |
---|---|---|
committer | Sergey Zagoruyko <zagoruyko2@gmail.com> | 2015-03-04 17:39:05 +0300 |
commit | cb303329bd82f1bb80d3c7eebc14d41195acec83 (patch) | |
tree | cf3fc5fcafe6392e5c11e657c3809e9e070de855 /SpatialAveragePooling.lua | |
parent | 7470cc81d458066d751156df52d05d3482411e6c (diff) |
SpatialAveragePooling divides by kW*kH
Diffstat (limited to 'SpatialAveragePooling.lua')
-rw-r--r-- | SpatialAveragePooling.lua | 16 |
1 files changed, 14 insertions, 2 deletions
diff --git a/SpatialAveragePooling.lua b/SpatialAveragePooling.lua index 13b6b45..90b79aa 100644 --- a/SpatialAveragePooling.lua +++ b/SpatialAveragePooling.lua @@ -7,14 +7,26 @@ function SpatialAveragePooling:__init(kW, kH, dW, dH) self.kH = kH self.dW = dW or 1 self.dH = dH or 1 + self.divide = true end function SpatialAveragePooling:updateOutput(input) - return input.nn.SpatialAveragePooling_updateOutput(self, input) + input.nn.SpatialAveragePooling_updateOutput(self, input) + -- for backward compatibility with saved models + -- which are not supposed to have "divide" field + if not self.divide then + self.output:mul(self.kW*self.kH) + end + return self.output end function SpatialAveragePooling:updateGradInput(input, gradOutput) if self.gradInput then - return input.nn.SpatialAveragePooling_updateGradInput(self, input, gradOutput) + input.nn.SpatialAveragePooling_updateGradInput(self, input, gradOutput) + -- for backward compatibility + if not self.divide then + self.gradInput:mul(self.kW*self.kH) + end + return self.gradInput end end |