Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-08-02 06:31:36 +0400
committernicholas-leonard <nick@nikopia.org>2014-08-02 06:31:36 +0400
commit5cbd4e605299fac5f1c68848e2549c49da4e44e4 (patch)
tree28f2904132cf72292a2e4695a2076c131244a13f /test
parenteacdc4c7ce924efe6e4055143c38e6be57d932a0 (diff)
NarrowLookupTable ascDelta=true works (unit tested) + doc
Diffstat (limited to 'test')
-rw-r--r--test/test-all.lua27
1 files changed, 27 insertions, 0 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index 1d4f617..a511be3 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -517,6 +517,33 @@ function nnxtest.NarrowLookupTable()
local deltaSize = 4
local lr = 0.1
+ -- 1D input ascDelta = false
+ local input = torch.randperm(dictSize):narrow(1,1,nIndex)
+ local nlt = nn.NarrowLookupTable(deltaSize, dictSize, embedSize, false)
+ local output = nlt:forward(input)
+
+ local output2 = torch.Tensor(120):zero()
+ local narrowSize = embedSize
+ local idx = 121 - narrowSize
+ for i=nIndex,1,-1 do
+ output2:narrow(1, idx, narrowSize):copy(nlt.weight[input[i]]:narrow(1,1,narrowSize))
+ narrowSize = narrowSize - deltaSize
+ idx = idx - narrowSize
+ end
+ mytester:assertTensorEq(output, output2, 0.000001, "1D forward ascDelta = false error")
+
+ nlt:zeroGradParameters()
+ local gradWeight2 = nlt.gradWeight:clone()
+ nlt:backward(input, output)
+ local narrowSize = embedSize
+ local idx = 121 - narrowSize
+ for i=nIndex,1,-1 do
+ gradWeight2[input[i]]:narrow(1, 1, narrowSize):add(output:narrow(1,idx,narrowSize))
+ narrowSize = narrowSize - deltaSize
+ idx = idx - narrowSize
+ end
+ mytester:assertTensorEq(nlt.gradWeight, gradWeight2, 0.000001, "1D backward ascDelta = false error")
+
-- 1D input
local input = torch.randperm(dictSize):narrow(1,1,nIndex)
local nlt = nn.NarrowLookupTable(deltaSize, dictSize, embedSize)