Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergey Zagoruyko <zagoruyko2@gmail.com>2016-05-22 23:21:23 +0300
committerSergey Zagoruyko <zagoruyko2@gmail.com>2016-05-22 23:21:23 +0300
commitaa54b0dadaec1a478c197da1b2b82b019224d510 (patch)
tree92e3f8ec06feae29651eec892ac46ed6782feffc
parent108113e95499bbdf991a6f9f95e4b03422fd0ee2 (diff)
parentf7ff756060cb53941c1df54fbe1a537a615a98df (diff)
Merge pull request #188 from szagoruyko/R5
Fix avg pooling back-compatibility
-rw-r--r--ffi.lua4
-rw-r--r--test/test.lua2
2 files changed, 4 insertions, 2 deletions
diff --git a/ffi.lua b/ffi.lua
index 7b07124..9e2cd5e 100644
--- a/ffi.lua
+++ b/ffi.lua
@@ -724,7 +724,8 @@ typedef enum
{
CUDNN_POOLING_MAX = 0,
CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING = 1, /* count for average includes padded values*/
- CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING = 2 /* count for average does not include padded values*/
+ CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING = 2, /* count for average does not include padded values*/
+ CUDNN_POOLING_AVERAGE = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING // for backward compatibility
} cudnnPoolingMode_t;
/* Create an instance of pooling descriptor */
@@ -1584,7 +1585,6 @@ cudnnStatus_t cudnnActivationBackward_v4(
const cudnnTensorDescriptor_t dxDesc,
void *dx );
-
]]
local libnames = {'libcudnn.so.5', 'libcudnn.5.dylib'}
diff --git a/test/test.lua b/test/test.lua
index ba1b5a1..3d2521e 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -820,6 +820,8 @@ function cudnntest.SpatialAveragePooling_single()
local sconv = nn.SpatialAveragePooling(ki,kj,si,sj):cuda()
local gconv = cudnn.SpatialAveragePooling(ki,kj,si,sj):cuda()
+ mytester:assert(cudnn.C.CUDNN_POOLING_AVERAGE ~= nil, 'back-compat broken')
+
local function test(sconv, gconv)
local groundtruth = sconv:forward(input):clone()
local groundgrad = sconv:backward(input, gradOutput)