Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2014-11-25 05:17:52 +0300
committernicholas-leonard <nick@nikopia.org>2014-11-25 07:08:09 +0300
commit70f542492cbde0cf7f16afc915a3b8b674b77bd0 (patch)
tree477c4c8e9e50c0ecae31ca4f1214abc52148cbe6 /test.lua
parentd232b6f3bedda4f5151248e81d2cc75123ab509b (diff)
SpatialAveragePooling
Diffstat (limited to 'test.lua')
-rw-r--r--test.lua69
1 files changed, 64 insertions, 5 deletions
diff --git a/test.lua b/test.lua
index ed7fd21..4615b23 100644
--- a/test.lua
+++ b/test.lua
@@ -1349,7 +1349,6 @@ function nntest.SpatialSubSampling()
'error on bias [%s]', t))
end
- --verbose = true
local batch = math.random(2,5)
outi = math.random(4,8)
outj = math.random(4,8)
@@ -1358,10 +1357,6 @@ function nntest.SpatialSubSampling()
module = nn.SpatialSubSampling(from, ki, kj, si, sj)
input = torch.Tensor(batch,from,inj,ini):zero()
--- print(from, to, ki, kj, si, sj, batch, ini, inj)
--- print(module.weight:size())
--- print(module.gradWeight:size())
-
local err = jac.testJacobian(module, input)
mytester:assertlt(err, precision, 'batch error on state ')
@@ -1427,6 +1422,70 @@ function nntest.SpatialMaxPooling()
end
+function nntest.SpatialAveragePooling()
+ local from = math.random(1,6)
+ local ki = math.random(1,5)
+ local kj = math.random(1,5)
+ local si = math.random(1,4)
+ local sj = math.random(1,4)
+ local outi = math.random(6,10)
+ local outj = math.random(6,10)
+ local ini = (outi-1)*si+ki
+ local inj = (outj-1)*sj+kj
+ local module = nn.SpatialAveragePooling(ki, kj, si, sj)
+ local input = torch.Tensor(from, inj, ini):zero()
+
+ local err = jac.testJacobian(module, input)
+ mytester:assertlt(err, precision, 'error on state ')
+
+ local ferr, berr = jac.testIO(module, input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+
+ local sap = nn.SpatialSubSampling(from, ki, kj, si, sj)
+ sap.weight:fill(1.0)
+ sap.bias:fill(0.0)
+
+ local output = module:forward(input)
+ local gradInput = module:backward(input, output)
+ local output2 = sap:forward(input)
+ local gradInput2 = sap:updateGradInput(input, output)
+
+ mytester:assertTensorEq(output, output2, 0.000001, torch.typename(module) .. ' forward err ')
+ mytester:assertTensorEq(gradInput, gradInput2, 0.000001, torch.typename(module) .. ' backward err ')
+
+ local batch = math.random(2,5)
+ outi = math.random(4,8)
+ outj = math.random(4,8)
+ ini = (outi-1)*si+ki
+ inj = (outj-1)*sj+kj
+ module = nn.SpatialAveragePooling(ki, kj, si, sj)
+ input = torch.Tensor(batch,from,inj,ini):zero()
+
+ local err = jac.testJacobian(module, input)
+ mytester:assertlt(err, precision, 'batch error on state ')
+
+ local ferr, berr = jac.testIO(module, input)
+ mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
+
+ local ferr, berr = jac.testIO(module, input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err (Batch) ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ')
+
+ local sap = nn.SpatialSubSampling(from, ki, kj, si, sj)
+ sap.weight:fill(1.0)
+ sap.bias:fill(0.0)
+
+ local output = module:forward(input)
+ local gradInput = module:backward(input, output)
+ local output2 = sap:forward(input)
+ local gradInput2 = sap:updateGradInput(input, output)
+
+ mytester:assertTensorEq(output, output2, 0.000001, torch.typename(module) .. ' forward err (Batch) ')
+ mytester:assertTensorEq(gradInput, gradInput2, 0.000001, torch.typename(module) .. ' backward err (Batch) ')
+end
+
function nntest.SpatialLPPooling()
local fanin = math.random(1,4)
local osizex = math.random(1,4)