github.com/torch/nn.git
author     Ruotian Luo <rluo@ttic.edu>    2017-01-15 00:27:52 +0300
committer  Ruotian Luo <rluo@ttic.edu>    2017-01-15 00:27:52 +0300
commit     2374627df385b68fd919a573e09517e26fa3c254 (patch)
tree       4b870c74beb5c82b655f2c4b48b58e8c450828ab /test.lua
parent     c489620118f335d83086ebcfcac4532f4ed760e2 (diff)
Add SpatialAdaptiveAveragePooling.
Diffstat (limited to 'test.lua')
-rw-r--r--  test.lua  56
1 file changed, 56 insertions, 0 deletions
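
Note: the new module averages each feature plane down to a fixed spatial size regardless of the input resolution, mirroring the existing nn.SpatialAdaptiveMaxPooling that the test below is modeled on. A minimal usage sketch, assuming the constructor takes (owidth, oheight) in the same order as SpatialAdaptiveMaxPooling; the concrete sizes are only illustrative:

require 'torch'
require 'nn'

-- Pool an arbitrary H x W plane down to a fixed 3 x 4 output
-- (argument order assumed to be (owidth, oheight), as in SpatialAdaptiveMaxPooling).
local pool  = nn.SpatialAdaptiveAveragePooling(4, 3)
local input = torch.rand(16, 10, 12)        -- nInputPlane x height x width
local out   = pool:forward(input)           -- 16 x 3 x 4, independent of the input height/width
local grad  = pool:backward(input, out)     -- gradInput has the same shape as the input
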
diff --git a/test.lua b/test.lua
index b3e1d16..8425aef 100644
--- a/test.lua
+++ b/test.lua
@@ -4009,6 +4009,62 @@ function nntest.SpatialAdaptiveMaxPooling()
end
+function nntest.SpatialAdaptiveAveragePooling()
+ local from = math.random(1,5)
+ local ki = math.random(1,5)
+ local kj = math.random(1,5)
+ local ini = math.random(1,16)
+ local inj = math.random(1,16)
+
+ local module = nn.SpatialAdaptiveAveragePooling(ki,kj)
+ local input = torch.rand(from,ini,inj)
+
+ local err = jac.testJacobian(module, input)
+ mytester:assertlt(err, precision, 'error on state ')
+
+ local ferr, berr = jac.testIO(module, input)
+ mytester:eq(ferr, 0, torch.typename(module) .. ' - i/o forward err ', precision)
+ mytester:eq(berr, 0, torch.typename(module) .. ' - i/o backward err ', precision)
+
+ -- batch
+ local nbatch = math.random(1,3)
+ input = torch.rand(nbatch,from,ini,inj)
+ module = nn.SpatialAdaptiveAveragePooling(ki,kj)
+
+ local err = jac.testJacobian(module, input)
+ mytester:assertlt(err, precision, 'error on state (Batch) ')
+
+ local ferr, berr = jac.testIO(module, input)
+ mytester:eq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ', precision)
+ mytester:eq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ', precision)
+
+ -- non-contiguous
+
+ input = torch.rand(from,ini,inj):transpose(2,3)
+ module = nn.SpatialAdaptiveAveragePooling(ki,kj)
+ local inputc = input:contiguous() -- contiguous
+ local output = module:forward(input):clone()
+ local outputc = module:forward(inputc):clone()
+ mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - non-contiguous err ')
+ local gradInput = module:backward(input, output):clone()
+ local gradInputc = module:backward(inputc, outputc):clone()
+ mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. ' - non-contiguous err ')
+
+ -- non-contiguous batch
+ local nbatch = math.random(1,3)
+ input = torch.rand(nbatch,from,ini,inj):transpose(1,3):transpose(2,4)
+ local inputc = input:contiguous() -- contiguous
+ module = nn.SpatialAdaptiveAveragePooling(ki,kj)
+
+ local output = module:forward(input):clone()
+ local outputc = module:forward(inputc):clone()
+ mytester:asserteq(0, (output-outputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ')
+ local gradInput = module:backward(input, output):clone()
+ local gradInputc = module:backward(inputc, outputc):clone()
+ mytester:asserteq(0, (gradInput-gradInputc):abs():max(), torch.typename(module) .. ' - batch non-contiguous err ')
+
+end
+
function nntest.SpatialLPPooling()
local fanin = math.random(1,4)
local osizex = math.random(1,4)