diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-08-02 03:26:43 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-08-02 03:26:43 +0400 |
commit | f97177231aeea2de9d76f65c52bf0b88fb875997 (patch) | |
tree | 51f18bb2f51485d13c623d604d7fbab98ca3d0ed /test | |
parent | 37d70a3a099aafc0a7e9773c65493ae204452a63 (diff) |
NarrowLookupTable:forward 1D works (unit tested)
Diffstat (limited to 'test')
-rw-r--r-- | test/test-all.lua | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/test/test-all.lua b/test/test-all.lua index d9e6297..ffcf642 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -517,7 +517,21 @@ function nnxtest.NarrowLookupTable() local deltaSize = 4 local input = torch.randperm(dictSize):narrow(1,1,nIndex) + local nlt = nn.NarrowLookupTable(deltaSize, dictSize, embedSize) + local output = nlt:forward(input) + local narrowSize = embedSize + local output2 = torch.Tensor(120):zero() + local idx = 1 + for i=1,nIndex do + output2:narrow(1, idx, narrowSize):copy(nlt.weight[input[i]]:narrow(1,1,narrowSize)) + if i == nIndex then + break + end + idx = idx + narrowSize + narrowSize = narrowSize - deltaSize + end + mytester:assertTensorEq(output, output2, 0.000001, "1D forward error") end function nnx.test(tests) |