diff options
author | Trevor Killeen <killeentm@gmail.com> | 2017-04-03 23:06:09 +0300 |
---|---|---|
committer | Trevor Killeen <killeentm@gmail.com> | 2017-04-04 17:49:25 +0300 |
commit | 16c2267bf6c0e5032eddb6e2a194200725337137 (patch) | |
tree | d7c3bdf801ea7d1ca55a6991dec0a6ec0bfe86fd | |
parent | 9e3c725c036d26c3fad47a8bdc3367d1eb047f43 (diff) |
test for hidden output as well
-rw-r--r-- | test/test_rnn.lua | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/test/test_rnn.lua b/test/test_rnn.lua index b50eefc..d2e0518 100644 --- a/test/test_rnn.lua +++ b/test/test_rnn.lua @@ -398,15 +398,20 @@ function rnntest.testVariableLengthSequences() -- batched local packed = cudnn.RNN:packPaddedSequence(input, lengths) local packedOutput = lstm:updateOutput(packed) + local packedHiddenOutput = lstm.hiddenOutput:clone() local separate = {} + local hids = {} for i, length in ipairs(lengths) do local inp = indivInputs[i] local output = lstm2:updateOutput(inp):clone() table.insert(separate, output) + local hid = lstm2.hiddenOutput:clone() + table.insert(hids, hid) end separate = torch.cat(separate, 1):squeeze() + hids = torch.cat(hids, 1):squeeze() mytester:asserteq(packedOutput:size(1), separate:size(1)) mytester:asserteq(packedOutput:size(2), separate:size(2)) @@ -432,6 +437,9 @@ function rnntest.testVariableLengthSequences() local diff = torch.csub(separate[sep], packedOutput[batched]):abs():sum() mytester:assert(diff < 1e-7) end + + local hdiff = torch.csub(packedHiddenOutput, hids):abs():sum() + mytester:assert(hdiff < 1e7) end mytester = torch.Tester() |