Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-08-02 03:26:43 +0400
committernicholas-leonard <nick@nikopia.org>2014-08-02 03:26:43 +0400
commitf97177231aeea2de9d76f65c52bf0b88fb875997 (patch)
tree51f18bb2f51485d13c623d604d7fbab98ca3d0ed /test
parent37d70a3a099aafc0a7e9773c65493ae204452a63 (diff)
NarrowLookupTable:forward 1D works (unit tested)
Diffstat (limited to 'test')
-rw-r--r--test/test-all.lua14
1 files changed, 14 insertions, 0 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index d9e6297..ffcf642 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -517,7 +517,21 @@ function nnxtest.NarrowLookupTable()
local deltaSize = 4
local input = torch.randperm(dictSize):narrow(1,1,nIndex)
+ local nlt = nn.NarrowLookupTable(deltaSize, dictSize, embedSize)
+ local output = nlt:forward(input)
+ local narrowSize = embedSize
+ local output2 = torch.Tensor(120):zero()
+ local idx = 1
+ for i=1,nIndex do
+ output2:narrow(1, idx, narrowSize):copy(nlt.weight[input[i]]:narrow(1,1,narrowSize))
+ if i == nIndex then
+ break
+ end
+ idx = idx + narrowSize
+ narrowSize = narrowSize - deltaSize
+ end
+ mytester:assertTensorEq(output, output2, 0.000001, "1D forward error")
end
function nnx.test(tests)