Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2012-07-25 15:50:37 +0400
committerRonan Collobert <ronan@collobert.com>2012-07-25 15:50:37 +0400
commit0d9a4a50c9253b032a41ad114c41aa82d370d102 (patch)
tree220fffbdeac73edbf7bb85bf39666a4d33cd2ee5 /test
parent37b738909113cd43b2d9a15ff279d737134a9462 (diff)
added temporal max pooling
Diffstat (limited to 'test')
-rw-r--r--test/test.lua17
1 files changed, 17 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index b8536ca..4d4383d 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1013,6 +1013,23 @@ function nntest.TemporalSubSampling()
mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
end
+function nntestx.TemporalMaxPooling()
+ local from = math.random(1,10)
+ local ki = math.random(1,10)
+ local si = math.random(1,4)
+ local outi = math.random(10,20)
+ local ini = (outi-1)*si+ki
+ local module = nn.TemporalMaxPooling(ki, si)
+ local input = torch.Tensor(ini, from):zero()
+
+ local err = jac.testJacobian(module, input)
+ mytester:assertlt(err, precision, 'error on state ')
+
+ local ferr, berr = jac.testIO(module, input)
+ mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
+end
+
function nntest.VolumetricConvolution()
local from = math.random(2,5)
local to = math.random(2,5)