Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorkoray kavukcuoglu <koray@kavukcuoglu.org>2013-10-15 04:21:26 +0400
committerkoray kavukcuoglu <koray@kavukcuoglu.org>2013-10-15 04:21:26 +0400
commit2a7a45dd813c94fb2831f3197c5f0a388035ab4a (patch)
tree6a16847a3bc60ee92c0826f1a520a1c68a41c9ab /test
parentbea3665f9dbcd2a5248078d2c23903fcc973c6be (diff)
add 3D max pooling
Diffstat (limited to 'test')
-rw-r--r--test/test.lua34
1 files changed, 30 insertions, 4 deletions
diff --git a/test/test.lua b/test/test.lua
index 1eb92ad..dd6be22 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1300,10 +1300,10 @@ function nntest.TemporalSubSampling()
end
function nntest.TemporalMaxPooling()
- local from = math.random(1,10)
- local ki = math.random(1,10)
- local si = math.random(1,4)
- local outi = math.random(10,20)
+ local from = math.random(10,10)
+ local ki = math.random(5,10)
+ local si = math.random(1,2)
+ local outi = math.random(50,90)
local ini = (outi-1)*si+ki
local module = nn.TemporalMaxPooling(ki, si)
local input = torch.Tensor(ini, from):zero()
@@ -1364,6 +1364,32 @@ function nntest.VolumetricConvolution()
mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
end
+function nntest.VolumetricMaxPooling()
+ local from = math.random(2,5)
+ local to = from
+ local kt = math.random(3,7)
+ local ki = math.random(3,7)
+ local kj = math.random(3,7)
+ local st = math.random(2,4)
+ local si = math.random(2,4)
+ local sj = math.random(2,4)
+ local outt = math.random(3,7)
+ local outi = math.random(3,7)
+ local outj = math.random(3,7)
+ local int = (outt-1)*st+kt
+ local ini = (outi-1)*si+ki
+ local inj = (outj-1)*sj+kj
+ local module = nn.VolumetricMaxPooling(kt, ki, kj, st, si, sj)
+ local input = torch.Tensor(from, int, inj, ini):zero()
+
+ local err = jac.testJacobian(module, input)
+ mytester:assertlt(err, precision, 'error on state ')
+
+ local ferr, berr = jac.testIO(module, input)
+ mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
+end
+
function nntest.Module_getParameters_1()
local n = nn.Sequential()
n:add( nn.Linear(10,10) )