Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorPáidí Creed <paidi@swiftkey.net>2014-02-01 20:53:29 +0400
committerPáidí Creed <paidi@swiftkey.net>2014-02-01 21:07:13 +0400
commit31bf7f120ab5ea43e769bb33248a69b12bdd3a25 (patch)
tree0832a3e7af3e0aea275dce751828501be32bdeab /test
parent2152758d904b4cab0ace02817203a65d92acbb10 (diff)
Add extra tests for SparseLinear and fix bug where scale was not being multiplied into bias updates
Diffstat (limited to 'test')
-rw-r--r--test/test.lua12
1 files changed, 9 insertions, 3 deletions
diff --git a/test/test.lua b/test/test.lua
index 27bb114..eb6dede 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -238,8 +238,8 @@ function nntest.Sqrt()
end
function nntest.Linear()
- local ini = math.random(50,70)
- local inj = math.random(50,70)
+ local ini = math.random(5,7)
+ local inj = math.random(5,7)
local input = torch.Tensor(ini):zero()
local module = nn.Linear(ini,inj)
@@ -335,7 +335,7 @@ function nntest.SparseLinear()
local err = (expected - actual):abs():max()
mytester:assertle(err, precision, 'error on result')
- -- Jacobian
+ -- Jacobian 1D
local err = sjac.testJacobian(module,input)
mytester:assertlt(err,precision, 'error on state ')
@@ -345,6 +345,12 @@ function nntest.SparseLinear()
local err = sjac.testJacobianParameters(module, input, module.bias, module.gradBias)
mytester:assertlt(err,precision, 'error on bias ')
+ local err = sjac.testJacobianUpdateParameters(module, input, module.weight)
+ mytester:assertlt(err,precision, 'error on weight [direct update] ')
+
+ local err = sjac.testJacobianUpdateParameters(module, input, module.bias)
+ mytester:assertlt(err,precision, 'error on bias [direct update] ')
+
local ferr, berr = sjac.testIO(module, input)
mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')