diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-08-02 04:13:52 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-08-02 04:13:52 +0400 |
commit | f2a9b3e5f896bd60b2504523b9de94f77e0ee5a1 (patch) | |
tree | 6c0b09d446ee346b1fcd9672a9fb6c46b0c2f95e /test | |
parent | 602ea3b6fabdc86005eb7dd249b50a37785e946b (diff) |
NarrowLookupTable:accGradParameters works (unit tested)
Diffstat (limited to 'test')
-rw-r--r-- | test/test-all.lua | 42 |
1 files changed, 33 insertions, 9 deletions
diff --git a/test/test-all.lua b/test/test-all.lua index a2832d4..28254a7 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -516,23 +516,34 @@ function nnxtest.NarrowLookupTable() local embedSize = 32 local deltaSize = 4 + -- 1D input local input = torch.randperm(dictSize):narrow(1,1,nIndex) local nlt = nn.NarrowLookupTable(deltaSize, dictSize, embedSize) local output = nlt:forward(input) - local narrowSize = embedSize local output2 = torch.Tensor(120):zero() + local narrowSize = embedSize local idx = 1 for i=1,nIndex do output2:narrow(1, idx, narrowSize):copy(nlt.weight[input[i]]:narrow(1,1,narrowSize)) - if i == nIndex then - break - end idx = idx + narrowSize narrowSize = narrowSize - deltaSize end mytester:assertTensorEq(output, output2, 0.000001, "1D forward error") + nlt:zeroGradParameters() + local gradWeight2 = nlt.gradWeight:clone() + nlt:backward(input, output) + local idx = 1 + local narrowSize = embedSize + for i=1,nIndex do + gradWeight2[input[i]]:narrow(1, 1, narrowSize):copy(output:narrow(1,idx,narrowSize)) + idx = idx + narrowSize + narrowSize = narrowSize - deltaSize + end + mytester:assertTensorEq(nlt.gradWeight, gradWeight2, 0.000001, "1D backward error") + + -- 2D input local input = torch.randperm(dictSize):narrow(1,1,nIndex*batchSize):view(8,-1) local output = nlt:forward(input) local output2 = torch.Tensor(batchSize, 120):zero() @@ -542,15 +553,28 @@ function nnxtest.NarrowLookupTable() local narrowSize = embedSize local idx = 1 for i=1,nIndex do - output2:narrow(1, idx, narrowSize):copy(nlt.weight[input[i]]:narrow(1,1,narrowSize)) - if i == nIndex then - break - end + output2:narrow(1, idx, narrowSize):add(nlt.weight[input[i]]:narrow(1,1,narrowSize)) idx = idx + narrowSize narrowSize = narrowSize - deltaSize end end - mytester:assertTensorEq(output, output2, 0.000001, "1D forward error") + mytester:assertTensorEq(output, output2, 0.000001, "2D forward error") + + nlt:zeroGradParameters() + local gradWeight2 = nlt.gradWeight:clone() + nlt:backward(input, output) + for k=1,batchSize do + local input = input[k] + local output = output[k] + local idx = 1 + local narrowSize = embedSize + for i=1,nIndex do + gradWeight2[input[i]]:narrow(1,1,narrowSize):add(output:narrow(1,idx,narrowSize)) + idx = idx + narrowSize + narrowSize = narrowSize - deltaSize + end + end + mytester:assertTensorEq(nlt.gradWeight, gradWeight2, 0.000001, "2D backward error") end function nnx.test(tests) |