diff options
author | Simone Cirillo <simone.cirillo@horus.technology> | 2016-03-22 14:37:36 +0300 |
---|---|---|
committer | Simone Cirillo <simone.cirillo@horus.technology> | 2016-03-29 13:29:24 +0300 |
commit | 6ee33420a98b12fcd65b047bf56a3a98408bdd70 (patch) | |
tree | 5229dbef8ae0e3d19b48bde481f0eccdc8621f33 /convert.lua | |
parent | 41bec1db610316f0e25a229602c933e4901da0c1 (diff) |
Fixed cudnn -> nn avg-pooling conversion
cudnn.convert now properly initializes nn.SpatialAveragePooling.count_include_pad
Diffstat (limited to 'convert.lua')
-rw-r--r-- | convert.lua | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/convert.lua b/convert.lua index a938ddc..a64dfe2 100644 --- a/convert.lua +++ b/convert.lua @@ -43,7 +43,10 @@ function cudnn.convert(net, dst) if v == 'ReLU' then y = dst.ReLU() end -- because parameters for k,u in pairs(x) do y[k] = u end if src == cudnn and x.clearDesc then x.clearDesc(y) end - if src == cudnn and v == 'SpatialAveragePooling' then y.divide = true end + if src == cudnn and v == 'SpatialAveragePooling' then + y.divide = true + y.count_include_pad = v.mode == 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING' + end return y end local t = torch.typename(x) |