Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-08-02 04:43:57 +0400
committernicholas-leonard <nick@nikopia.org>2014-08-02 04:43:57 +0400
commiteacdc4c7ce924efe6e4055143c38e6be57d932a0 (patch)
tree2fe04591ed479feaec577b8f7c89536271ca4a95 /test
parentf2a9b3e5f896bd60b2504523b9de94f77e0ee5a1 (diff)
NarrowLookupTable:[accUpdateGradParameters,type] work (unit tested)
Diffstat (limited to 'test')
-rw-r--r--test/test-all.lua34
1 files changed, 32 insertions, 2 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index 28254a7..1d4f617 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -515,6 +515,7 @@ function nnxtest.NarrowLookupTable()
local batchSize = 8
local embedSize = 32
local deltaSize = 4
+ local lr = 0.1
-- 1D input
local input = torch.randperm(dictSize):narrow(1,1,nIndex)
@@ -537,16 +538,29 @@ function nnxtest.NarrowLookupTable()
local idx = 1
local narrowSize = embedSize
for i=1,nIndex do
- gradWeight2[input[i]]:narrow(1, 1, narrowSize):copy(output:narrow(1,idx,narrowSize))
+ gradWeight2[input[i]]:narrow(1, 1, narrowSize):add(output:narrow(1,idx,narrowSize))
idx = idx + narrowSize
narrowSize = narrowSize - deltaSize
end
mytester:assertTensorEq(nlt.gradWeight, gradWeight2, 0.000001, "1D backward error")
+ nlt:zeroGradParameters()
+ local weight2 = nlt.weight:clone()
+ nlt:backwardUpdate(input, output, lr)
+ local idx = 1
+ local narrowSize = embedSize
+ for i=1,nIndex do
+ weight2[input[i]]:narrow(1, 1, narrowSize):add(-lr, output:narrow(1,idx,narrowSize))
+ idx = idx + narrowSize
+ narrowSize = narrowSize - deltaSize
+ end
+ mytester:assertTensorEq(nlt.weight, weight2, 0.000001, "1D backwardUpdate error")
+
-- 2D input
+ nlt:float()
local input = torch.randperm(dictSize):narrow(1,1,nIndex*batchSize):view(8,-1)
local output = nlt:forward(input)
- local output2 = torch.Tensor(batchSize, 120):zero()
+ local output2 = torch.FloatTensor(batchSize, 120):zero()
for k=1,batchSize do
local input = input[k]
local output2 = output2[k]
@@ -575,6 +589,22 @@ function nnxtest.NarrowLookupTable()
end
end
mytester:assertTensorEq(nlt.gradWeight, gradWeight2, 0.000001, "2D backward error")
+
+ nlt:zeroGradParameters()
+ local weight2 = nlt.weight:clone()
+ nlt:backwardUpdate(input, output, lr)
+ for k=1,batchSize do
+ local input = input[k]
+ local output = output[k]
+ local idx = 1
+ local narrowSize = embedSize
+ for i=1,nIndex do
+ weight2[input[i]]:narrow(1,1,narrowSize):add(-lr, output:narrow(1,idx,narrowSize))
+ idx = idx + narrowSize
+ narrowSize = narrowSize - deltaSize
+ end
+ end
+ mytester:assertTensorEq(nlt.weight, weight2, 0.000001, "2D backwardUpdate error")
end
function nnx.test(tests)