Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/soumith/cudnn.torch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevor Killeen <killeentm@gmail.com>2017-04-03 23:06:09 +0300
committerTrevor Killeen <killeentm@gmail.com>2017-04-04 17:49:25 +0300
commit16c2267bf6c0e5032eddb6e2a194200725337137 (patch)
treed7c3bdf801ea7d1ca55a6991dec0a6ec0bfe86fd
parent9e3c725c036d26c3fad47a8bdc3367d1eb047f43 (diff)
test for hidden output as well
-rw-r--r--test/test_rnn.lua8
1 files changed, 8 insertions, 0 deletions
diff --git a/test/test_rnn.lua b/test/test_rnn.lua
index b50eefc..d2e0518 100644
--- a/test/test_rnn.lua
+++ b/test/test_rnn.lua
@@ -398,15 +398,20 @@ function rnntest.testVariableLengthSequences()
-- batched
local packed = cudnn.RNN:packPaddedSequence(input, lengths)
local packedOutput = lstm:updateOutput(packed)
+ local packedHiddenOutput = lstm.hiddenOutput:clone()
local separate = {}
+ local hids = {}
for i, length in ipairs(lengths) do
local inp = indivInputs[i]
local output = lstm2:updateOutput(inp):clone()
table.insert(separate, output)
+ local hid = lstm2.hiddenOutput:clone()
+ table.insert(hids, hid)
end
separate = torch.cat(separate, 1):squeeze()
+ hids = torch.cat(hids, 1):squeeze()
mytester:asserteq(packedOutput:size(1), separate:size(1))
mytester:asserteq(packedOutput:size(2), separate:size(2))
@@ -432,6 +437,9 @@ function rnntest.testVariableLengthSequences()
local diff = torch.csub(separate[sep], packedOutput[batched]):abs():sum()
mytester:assert(diff < 1e-7)
end
+
+ local hdiff = torch.csub(packedHiddenOutput, hids):abs():sum()
+ mytester:assert(hdiff < 1e7)
end
mytester = torch.Tester()