diff options
-rw-r--r-- | test/test_rnn.lua | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/test/test_rnn.lua b/test/test_rnn.lua index b50eefc..d2e0518 100644 --- a/test/test_rnn.lua +++ b/test/test_rnn.lua @@ -398,15 +398,20 @@ function rnntest.testVariableLengthSequences() -- batched local packed = cudnn.RNN:packPaddedSequence(input, lengths) local packedOutput = lstm:updateOutput(packed) + local packedHiddenOutput = lstm.hiddenOutput:clone() local separate = {} + local hids = {} for i, length in ipairs(lengths) do local inp = indivInputs[i] local output = lstm2:updateOutput(inp):clone() table.insert(separate, output) + local hid = lstm2.hiddenOutput:clone() + table.insert(hids, hid) end separate = torch.cat(separate, 1):squeeze() + hids = torch.cat(hids, 1):squeeze() mytester:asserteq(packedOutput:size(1), separate:size(1)) mytester:asserteq(packedOutput:size(2), separate:size(2)) @@ -432,6 +437,9 @@ function rnntest.testVariableLengthSequences() local diff = torch.csub(separate[sep], packedOutput[batched]):abs():sum() mytester:assert(diff < 1e-7) end + + local hdiff = torch.csub(packedHiddenOutput, hids):abs():sum() + mytester:assert(hdiff < 1e7) end mytester = torch.Tester() |