Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-07-21 22:02:41 +0300
committerSoumith Chintala <soumith@gmail.com>2015-07-21 22:02:41 +0300
commit81348688b7089c733f88f7c43e875e012db6cdbe (patch)
treecf2fc235411a7a75df0773c28b2a586627ce748d /test.lua
parentb9764d4890a5e05c1ae9fe15a53f6123dc9f2202 (diff)
parent929cfc57c88952b597bec77046582b90d1122380 (diff)
Merge pull request #309 from fmassa/max_pool_pad
SpatialMaxPooling supports padding and ceil mode
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua60
1 files changed, 33 insertions, 27 deletions
diff --git a/test.lua b/test.lua
index 92b686f..ed774bd 100644
--- a/test.lua
+++ b/test.lua
@@ -1926,38 +1926,44 @@ function nntest.SpatialSubSampling()
end
function nntest.SpatialMaxPooling()
- local from = math.random(1,5)
- local ki = math.random(1,4)
- local kj = math.random(1,4)
- local si = math.random(1,3)
- local sj = math.random(1,3)
- local outi = math.random(4,5)
- local outj = math.random(4,5)
- local ini = (outi-1)*si+ki
- local inj = (outj-1)*sj+kj
-
- local module = nn.SpatialMaxPooling(ki,kj,si,sj)
- local input = torch.rand(from,ini,inj)
+ for _,ceil_mode in pairs({true,false}) do
+ local from = math.random(1,5)
+ local ki = math.random(1,4)
+ local kj = math.random(1,4)
+ local si = math.random(1,3)
+ local sj = math.random(1,3)
+ local outi = math.random(4,5)
+ local outj = math.random(4,5)
+ local padW = math.min(math.random(0,1),math.floor(ki/2))
+ local padH = math.min(math.random(0,1),math.floor(kj/2))
+ local ini = (outi-1)*si+ki-2*padW
+ local inj = (outj-1)*sj+kj-2*padH
+
+ local ceil_string = ceil_mode and 'ceil' or 'floor'
+ local module = nn.SpatialMaxPooling(ki,kj,si,sj,padW,padH)
+ if ceil_mode then module:ceil() else module:floor() end
+ local input = torch.rand(from,inj,ini)
- local err = jac.testJacobian(module, input)
- mytester:assertlt(err, precision, 'error on state ')
-
- local ferr, berr = jac.testIO(module, input)
- mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
- mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+ local err = jac.testJacobian(module, input)
+ mytester:assertlt(err, precision, 'error '..ceil_string..' mode on state ')
- -- batch
- local nbatch = math.random(2,5)
- input = torch.rand(nbatch,from,ini,inj)
- module = nn.SpatialMaxPooling(ki,kj,si,sj)
+ local ferr, berr = jac.testIO(module, input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
- local err = jac.testJacobian(module, input)
- mytester:assertlt(err, precision, 'error on state (Batch) ')
+ -- batch
+ local nbatch = math.random(2,5)
+ input = torch.rand(nbatch,from,inj,ini)
+ module = nn.SpatialMaxPooling(ki,kj,si,sj,padW,padH)
+ if ceil_mode then module:ceil() else module:floor() end
- local ferr, berr = jac.testIO(module, input)
- mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ')
- mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ')
+ local err = jac.testJacobian(module, input)
+ mytester:assertlt(err, precision, 'error '..ceil_string..' mode on state (Batch)')
+ local ferr, berr = jac.testIO(module, input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ')
+ end
end
function nntest.SpatialAveragePooling()