diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-08-02 04:43:57 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-08-02 04:43:57 +0400 |
commit | eacdc4c7ce924efe6e4055143c38e6be57d932a0 (patch) | |
tree | 2fe04591ed479feaec577b8f7c89536271ca4a95 /test | |
parent | f2a9b3e5f896bd60b2504523b9de94f77e0ee5a1 (diff) |
NarrowLookupTable:[accUpdateGradParameters,type] work (unit tested)
Diffstat (limited to 'test')
-rw-r--r-- | test/test-all.lua | 34 |
1 files changed, 32 insertions, 2 deletions
diff --git a/test/test-all.lua b/test/test-all.lua index 28254a7..1d4f617 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -515,6 +515,7 @@ function nnxtest.NarrowLookupTable() local batchSize = 8 local embedSize = 32 local deltaSize = 4 + local lr = 0.1 -- 1D input local input = torch.randperm(dictSize):narrow(1,1,nIndex) @@ -537,16 +538,29 @@ function nnxtest.NarrowLookupTable() local idx = 1 local narrowSize = embedSize for i=1,nIndex do - gradWeight2[input[i]]:narrow(1, 1, narrowSize):copy(output:narrow(1,idx,narrowSize)) + gradWeight2[input[i]]:narrow(1, 1, narrowSize):add(output:narrow(1,idx,narrowSize)) idx = idx + narrowSize narrowSize = narrowSize - deltaSize end mytester:assertTensorEq(nlt.gradWeight, gradWeight2, 0.000001, "1D backward error") + nlt:zeroGradParameters() + local weight2 = nlt.weight:clone() + nlt:backwardUpdate(input, output, lr) + local idx = 1 + local narrowSize = embedSize + for i=1,nIndex do + weight2[input[i]]:narrow(1, 1, narrowSize):add(-lr, output:narrow(1,idx,narrowSize)) + idx = idx + narrowSize + narrowSize = narrowSize - deltaSize + end + mytester:assertTensorEq(nlt.weight, weight2, 0.000001, "1D backwardUpdate error") + -- 2D input + nlt:float() local input = torch.randperm(dictSize):narrow(1,1,nIndex*batchSize):view(8,-1) local output = nlt:forward(input) - local output2 = torch.Tensor(batchSize, 120):zero() + local output2 = torch.FloatTensor(batchSize, 120):zero() for k=1,batchSize do local input = input[k] local output2 = output2[k] @@ -575,6 +589,22 @@ function nnxtest.NarrowLookupTable() end end mytester:assertTensorEq(nlt.gradWeight, gradWeight2, 0.000001, "2D backward error") + + nlt:zeroGradParameters() + local weight2 = nlt.weight:clone() + nlt:backwardUpdate(input, output, lr) + for k=1,batchSize do + local input = input[k] + local output = output[k] + local idx = 1 + local narrowSize = embedSize + for i=1,nIndex do + weight2[input[i]]:narrow(1,1,narrowSize):add(-lr, output:narrow(1,idx,narrowSize)) + idx = idx + narrowSize + narrowSize = narrowSize - deltaSize + end + end + mytester:assertTensorEq(nlt.weight, weight2, 0.000001, "2D backwardUpdate error") end function nnx.test(tests) |