Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-08-02 04:13:52 +0400
committernicholas-leonard <nick@nikopia.org>2014-08-02 04:13:52 +0400
commitf2a9b3e5f896bd60b2504523b9de94f77e0ee5a1 (patch)
tree6c0b09d446ee346b1fcd9672a9fb6c46b0c2f95e /test
parent602ea3b6fabdc86005eb7dd249b50a37785e946b (diff)
NarrowLookupTable:accGradParameters works (unit tested)
Diffstat (limited to 'test')
-rw-r--r--test/test-all.lua42
1 files changed, 33 insertions, 9 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index a2832d4..28254a7 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -516,23 +516,34 @@ function nnxtest.NarrowLookupTable()
local embedSize = 32
local deltaSize = 4
+ -- 1D input
local input = torch.randperm(dictSize):narrow(1,1,nIndex)
local nlt = nn.NarrowLookupTable(deltaSize, dictSize, embedSize)
local output = nlt:forward(input)
- local narrowSize = embedSize
local output2 = torch.Tensor(120):zero()
+ local narrowSize = embedSize
local idx = 1
for i=1,nIndex do
output2:narrow(1, idx, narrowSize):copy(nlt.weight[input[i]]:narrow(1,1,narrowSize))
- if i == nIndex then
- break
- end
idx = idx + narrowSize
narrowSize = narrowSize - deltaSize
end
mytester:assertTensorEq(output, output2, 0.000001, "1D forward error")
+ nlt:zeroGradParameters()
+ local gradWeight2 = nlt.gradWeight:clone()
+ nlt:backward(input, output)
+ local idx = 1
+ local narrowSize = embedSize
+ for i=1,nIndex do
+ gradWeight2[input[i]]:narrow(1, 1, narrowSize):copy(output:narrow(1,idx,narrowSize))
+ idx = idx + narrowSize
+ narrowSize = narrowSize - deltaSize
+ end
+ mytester:assertTensorEq(nlt.gradWeight, gradWeight2, 0.000001, "1D backward error")
+
+ -- 2D input
local input = torch.randperm(dictSize):narrow(1,1,nIndex*batchSize):view(8,-1)
local output = nlt:forward(input)
local output2 = torch.Tensor(batchSize, 120):zero()
@@ -542,15 +553,28 @@ function nnxtest.NarrowLookupTable()
local narrowSize = embedSize
local idx = 1
for i=1,nIndex do
- output2:narrow(1, idx, narrowSize):copy(nlt.weight[input[i]]:narrow(1,1,narrowSize))
- if i == nIndex then
- break
- end
+ output2:narrow(1, idx, narrowSize):add(nlt.weight[input[i]]:narrow(1,1,narrowSize))
idx = idx + narrowSize
narrowSize = narrowSize - deltaSize
end
end
- mytester:assertTensorEq(output, output2, 0.000001, "1D forward error")
+ mytester:assertTensorEq(output, output2, 0.000001, "2D forward error")
+
+ nlt:zeroGradParameters()
+ local gradWeight2 = nlt.gradWeight:clone()
+ nlt:backward(input, output)
+ for k=1,batchSize do
+ local input = input[k]
+ local output = output[k]
+ local idx = 1
+ local narrowSize = embedSize
+ for i=1,nIndex do
+ gradWeight2[input[i]]:narrow(1,1,narrowSize):add(output:narrow(1,idx,narrowSize))
+ idx = idx + narrowSize
+ narrowSize = narrowSize - deltaSize
+ end
+ end
+ mytester:assertTensorEq(nlt.gradWeight, gradWeight2, 0.000001, "2D backward error")
end
function nnx.test(tests)