Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2015-01-19 11:01:49 +0300
committersoumith <soumith@fb.com>2015-01-19 11:01:49 +0300
commitabee8f6de6294f2aafc3af128fe85009a1264076 (patch)
tree3d2bd1225225f9c42043ce0fc278102b4482f552 /test.lua
parenta9768c444f4a455fc8a7654981350c9ecc6e7aee (diff)
parent98cf6f146be4e032396daa0da12d5f181537d6a9 (diff)
Merge https://github.com/fmassa/nn
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua31
1 files changed, 31 insertions, 0 deletions
diff --git a/test.lua b/test.lua
index 0085cc5..936fccf 100644
--- a/test.lua
+++ b/test.lua
@@ -1678,6 +1678,37 @@ function nntest.SpatialAveragePooling()
mytester:assertTensorEq(gradInput, gradInput2, 0.000001, torch.typename(module) .. ' backward err (Batch) ')
end
+function nntest.SpatialAdaptiveMaxPooling()
+ local from = math.random(1,5)
+ local ki = math.random(1,12)
+ local kj = math.random(1,12)
+ local ini = math.random(1,64)
+ local inj = math.random(1,64)
+
+ local module = nn.SpatialAdaptiveMaxPooling(ki,kj)
+ local input = torch.rand(from,ini,inj)
+
+ local err = jac.testJacobian(module, input)
+ mytester:assertlt(err, precision, 'error on state ')
+
+ local ferr, berr = jac.testIO(module, input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+
+ -- batch
+ local nbatch = math.random(2,5)
+ input = torch.rand(nbatch,from,ini,inj)
+ module = nn.SpatialAdaptiveMaxPooling(ki,kj)
+
+ local err = jac.testJacobian(module, input)
+ mytester:assertlt(err, precision, 'error on state (Batch) ')
+
+ local ferr, berr = jac.testIO(module, input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ')
+
+end
+
function nntest.SpatialLPPooling()
local fanin = math.random(1,4)
local osizex = math.random(1,4)