diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-08-02 06:31:36 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-08-02 06:31:36 +0400 |
commit | 5cbd4e605299fac5f1c68848e2549c49da4e44e4 (patch) | |
tree | 28f2904132cf72292a2e4695a2076c131244a13f /test | |
parent | eacdc4c7ce924efe6e4055143c38e6be57d932a0 (diff) |
NarrowLookupTable ascDelta=true works (unit tested) + doc
Diffstat (limited to 'test')
-rw-r--r-- | test/test-all.lua | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/test/test-all.lua b/test/test-all.lua index 1d4f617..a511be3 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -517,6 +517,33 @@ function nnxtest.NarrowLookupTable() local deltaSize = 4 local lr = 0.1 + -- 1D input ascDelta = false + local input = torch.randperm(dictSize):narrow(1,1,nIndex) + local nlt = nn.NarrowLookupTable(deltaSize, dictSize, embedSize, false) + local output = nlt:forward(input) + + local output2 = torch.Tensor(120):zero() + local narrowSize = embedSize + local idx = 121 - narrowSize + for i=nIndex,1,-1 do + output2:narrow(1, idx, narrowSize):copy(nlt.weight[input[i]]:narrow(1,1,narrowSize)) + narrowSize = narrowSize - deltaSize + idx = idx - narrowSize + end + mytester:assertTensorEq(output, output2, 0.000001, "1D forward ascDelta = false error") + + nlt:zeroGradParameters() + local gradWeight2 = nlt.gradWeight:clone() + nlt:backward(input, output) + local narrowSize = embedSize + local idx = 121 - narrowSize + for i=nIndex,1,-1 do + gradWeight2[input[i]]:narrow(1, 1, narrowSize):add(output:narrow(1,idx,narrowSize)) + narrowSize = narrowSize - deltaSize + idx = idx - narrowSize + end + mytester:assertTensorEq(nlt.gradWeight, gradWeight2, 0.000001, "1D backward ascDelta = false error") + -- 1D input local input = torch.randperm(dictSize):narrow(1,1,nIndex) local nlt = nn.NarrowLookupTable(deltaSize, dictSize, embedSize) |