Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorNicholas Leonard <nick@nikopia.org>2014-04-10 00:49:43 +0400
committerNicholas Leonard <nick@nikopia.org>2014-04-10 00:49:43 +0400
commit6783768eb6e06ebbf0c1bd2d6b5f4bc9f709b32f (patch)
treee78ffe0e0f7651ef6767a3de09443e9a19ffd474 /test
parent2781f4bee6c6f725800c863c3c2e96a551295328 (diff)
unit test for 1D LookupTable
Diffstat (limited to 'test')
-rw-r--r--test/test.lua54
1 files changed, 54 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index b2b5f92..91b38e8 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1646,6 +1646,60 @@ function nntest.PairwiseDistance()
end
end
+function nntest.LookupTable()
+ local totalIndex = math.random(10,100)
+ local nIndex = math.random(5,7)
+ local entry_size = math.random(5,7)
+ local input = torch.Tensor(nIndex):zero()
+ local module = nn.LookupTable(totalIndex, entry_size)
+ local minval = 1
+ local maxval = totalIndex
+
+ -- 1D
+ local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight, minval, maxval)
+ mytester:assertlt(err,precision, 'error on weight ')
+
+ local err = jac.testJacobianUpdateParameters(module, input, module.weight, minval, maxval)
+ mytester:assertlt(err,precision, 'error on weight [direct update] ')
+
+ module.gradWeight:zero()
+ for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
+ mytester:assertlt(err, precision, string.format(
+ 'error on weight [%s]', t))
+ end
+
+ -- 2D
+ local nframe = math.random(50,70)
+ local input = torch.Tensor(nframe, nIndex):zero()
+
+ local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight)
+ mytester:assertlt(err,precision, 'error on weight ')
+
+ local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias)
+ mytester:assertlt(err,precision, 'error on weight ')
+
+ local err = jac.testJacobianUpdateParameters(module, input, module.weight)
+ mytester:assertlt(err,precision, 'error on weight [direct update] ')
+
+ local err = jac.testJacobianUpdateParameters(module, input, module.bias)
+ mytester:assertlt(err,precision, 'error on bias [direct update] ')
+
+ for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do
+ mytester:assertlt(err, precision, string.format(
+ 'error on weight [%s]', t))
+ end
+
+ for t,err in pairs(jac.testAllUpdate(module, input, 'bias', 'gradBias')) do
+ mytester:assertlt(err, precision, string.format(
+ 'error on bias [%s]', t))
+ end
+
+ -- IO
+ local ferr,berr = jac.testIO(module,input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+end
+
mytester:add(nntest)
if not nn then