Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrancisco Massa <fvsmassa@gmail.com>2016-01-17 14:57:55 +0300
committerfsuzanomassa <fvsmassa@gmail.com>2016-01-28 01:07:34 +0300
commite2ad1f33fe59c6d31a2b3491efb7506b82bdf5b2 (patch)
tree9738d207edb2babfc33a0bc60a7b4589c794649c /SpatialAveragePooling.lua
parentaf15ab04503791f21e593f97ea88cc97d6f9140a (diff)
Add THNN conversion of {Spatial(AdaptiveMax,Average,Max)Pooling}
Diffstat (limited to 'SpatialAveragePooling.lua')
-rw-r--r--SpatialAveragePooling.lua21
1 files changed, 19 insertions, 2 deletions
diff --git a/SpatialAveragePooling.lua b/SpatialAveragePooling.lua
index d8ef41f..231177d 100644
--- a/SpatialAveragePooling.lua
+++ b/SpatialAveragePooling.lua
@@ -45,7 +45,15 @@ end
function SpatialAveragePooling:updateOutput(input)
backwardCompatible(self)
- input.nn.SpatialAveragePooling_updateOutput(self, input)
+ input.THNN.SpatialAveragePooling_updateOutput(
+ input:cdata(),
+ self.output:cdata(),
+ self.kW, self.kH,
+ self.dW, self.dH,
+ self.padW, self.padH,
+ self.ceil_mode,
+ self.count_include_pad
+ )
-- for backward compatibility with saved models
-- which are not supposed to have "divide" field
if not self.divide then
@@ -56,7 +64,16 @@ end
function SpatialAveragePooling:updateGradInput(input, gradOutput)
if self.gradInput then
- input.nn.SpatialAveragePooling_updateGradInput(self, input, gradOutput)
+ input.THNN.SpatialAveragePooling_updateGradInput(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradInput:cdata(),
+ self.kW, self.kH,
+ self.dW, self.dH,
+ self.padW, self.padH,
+ self.ceil_mode,
+ self.count_include_pad
+ )
-- for backward compatibility
if not self.divide then
self.gradInput:mul(self.kW*self.kH)