diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-10 21:42:01 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-10 21:42:01 +0400 |
commit | 5e5d7f244a00ab12a5e8d5a0171c6f3bc3c4e9cc (patch) | |
tree | beddba7d1a20b179b6e5b356e40e2c9a49a66d55 /test | |
parent | 7d4971d39e2cf8e7f7069260ad57da298e008a2a (diff) |
added accUpdate to nn.LookupTable
Diffstat (limited to 'test')
-rw-r--r-- | test/test.lua | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua index 135624d..1eb571e 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1,5 +1,9 @@ require 'torch' +-- you can easily test specific units like this: +-- luajit -lnn -e "nn.test{'LookupTable'}" +-- luajit -lnn -e "nn.test{'LookupTable', 'Add'}" + local mytester = torch.Tester() local jac local sjac @@ -1821,6 +1825,13 @@ function nntest.LookupTable() local ferr,berr = jac.testIO(module,input,minval,maxval) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') + + -- accUpdate + module:accUpdateOnly() + mytester:assert(not module.gradWeight, 'gradWeight is nil') + module:float() + local output = module:forward(input) + module:backwardUpdate(input, output, 0.1) end function nntest.AddConstant() |