Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimone Cirillo <simone.cirillo@horus.technology>2016-03-22 14:37:36 +0300
committerSimone Cirillo <simone.cirillo@horus.technology>2016-03-29 13:29:24 +0300
commit6ee33420a98b12fcd65b047bf56a3a98408bdd70 (patch)
tree5229dbef8ae0e3d19b48bde481f0eccdc8621f33 /convert.lua
parent41bec1db610316f0e25a229602c933e4901da0c1 (diff)
Fixed cudnn -> nn avg-pooling conversion
cudnn.convert now properly initializes nn.SpatialAveragePooling.count_include_pad
Diffstat (limited to 'convert.lua')
-rw-r--r--convert.lua5
1 files changed, 4 insertions, 1 deletions
diff --git a/convert.lua b/convert.lua
index a938ddc..a64dfe2 100644
--- a/convert.lua
+++ b/convert.lua
@@ -43,7 +43,10 @@ function cudnn.convert(net, dst)
if v == 'ReLU' then y = dst.ReLU() end -- because parameters
for k,u in pairs(x) do y[k] = u end
if src == cudnn and x.clearDesc then x.clearDesc(y) end
- if src == cudnn and v == 'SpatialAveragePooling' then y.divide = true end
+ if src == cudnn and v == 'SpatialAveragePooling' then
+ y.divide = true
+ y.count_include_pad = v.mode == 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING'
+ end
return y
end
local t = torch.typename(x)