diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-05-10 23:31:48 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-05-10 23:31:48 +0400 |
commit | 5703059443c6f5a5bfdf5c6ab035d2a377a821e5 (patch) | |
tree | dfb23afd24297736e6d0acd0adb5664325ce8b13 /test/test.lua | |
parent | d1bab2f635accc00f3c871c16674853d848b44fc (diff) |
TemporalConvolution unit test
Diffstat (limited to 'test/test.lua')
-rw-r--r-- | test/test.lua | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index b342c36..f565bd3 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1283,6 +1283,7 @@ function nntest.Tanh() end function nntest.TemporalConvolution() + -- 1D local from = math.random(1,10) local to = math.random(1,10) local ki = math.random(1,10) @@ -1317,6 +1318,35 @@ function nntest.TemporalConvolution() 'error on bias [%s]', t)) end + -- 2D + local nBatchFrame = 8 + local input = torch.Tensor(nBatchFrame, ini, from):zero() + + local err = jac.testJacobian(module, input) + mytester:assertlt(err, precision, 'error on state ') + + local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight) + mytester:assertlt(err , precision, 'error on weight ') + + local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias) + mytester:assertlt(err , precision, 'error on bias ') + + local err = jac.testJacobianUpdateParameters(module, input, module.weight) + mytester:assertlt(err , precision, 'error on weight [direct update]') + + local err = jac.testJacobianUpdateParameters(module, input, module.bias) + mytester:assertlt(err , precision, 'error on bias [direct update]') + + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do + mytester:assertlt(err, precision, string.format( + 'error on weight [%s]', t)) + end + + for t,err in pairs(jac.testAllUpdate(module, input, 'bias', 'gradBias')) do + mytester:assertlt(err, precision, string.format( + 'error on bias [%s]', t)) + end + local ferr, berr = jac.testIO(module, input) mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ') |