Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-10 21:42:01 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-10 21:42:01 +0400
commit5e5d7f244a00ab12a5e8d5a0171c6f3bc3c4e9cc (patch)
treebeddba7d1a20b179b6e5b356e40e2c9a49a66d55 /test
parent7d4971d39e2cf8e7f7069260ad57da298e008a2a (diff)
added accUpdate to nn.LookupTable
Diffstat (limited to 'test')
-rw-r--r--test/test.lua11
1 files changed, 11 insertions, 0 deletions
diff --git a/test/test.lua b/test/test.lua
index 135624d..1eb571e 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1,5 +1,9 @@
require 'torch'
+-- you can easily test specific units like this:
+-- luajit -lnn -e "nn.test{'LookupTable'}"
+-- luajit -lnn -e "nn.test{'LookupTable', 'Add'}"
+
local mytester = torch.Tester()
local jac
local sjac
@@ -1821,6 +1825,13 @@ function nntest.LookupTable()
local ferr,berr = jac.testIO(module,input,minval,maxval)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+
+ -- accUpdate
+ module:accUpdateOnly()
+ mytester:assert(not module.gradWeight, 'gradWeight is nil')
+ module:float()
+ local output = module:forward(input)
+ module:backwardUpdate(input, output, 0.1)
end
function nntest.AddConstant()