From 9e3c725c036d26c3fad47a8bdc3367d1eb047f43 Mon Sep 17 00:00:00 2001
From: Trevor Killeen
Date: Mon, 3 Apr 2017 12:57:00 -0700
Subject: outputs for forward pass working + test

---
 RNN.lua           |   7 +++-
 test/test_rnn.lua | 122 +++++++++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 126 insertions(+), 3 deletions(-)

diff --git a/RNN.lua b/RNN.lua
index 45e137e..c61d84d 100644
--- a/RNN.lua
+++ b/RNN.lua
@@ -180,7 +180,7 @@ function RNN:resetOutputDescriptor(output, batchSizes)
       local dim = torch.IntTensor({batchSizes[i+1], self.hiddenSize * self.numDirections, 1})
       local stride = torch.IntTensor({dim[3] * dim[2], dim[3],1})
       errcheck('cudnnSetTensorNdDescriptor',
-               self.xDescs[i],
+               self.yDescs[i],
                self.datatype,
                3,
                dim:data(),
@@ -468,7 +468,9 @@ function RNN:updateOutput(input)
    -- Make sure input is contiguous
    local x = self:makeContiguous(self.inputPacked and input[1] or input)
    local oSize = self:deriveOutputSize(x)
-   local oStride = torch.LongStorage({oSize[2] * oSize[3], oSize[3], 1})
+   local oStride = self.inputPacked and
+      torch.LongStorage({oSize[2], 1}) or
+      torch.LongStorage({oSize[2] * oSize[3], oSize[3], 1})
    self.output:resize(oSize, oStride)
    local y = self.output
    local w = self.weight
@@ -547,6 +549,7 @@ function RNN:updateOutput(input)
       local elemSize = self.reserve:elementSize()
       reserveSize = math.floor((reserveSize + elemSize - 1) / elemSize)
       self.reserve:resize(reserveSize)
+
       errcheck('cudnnRNNForwardTraining',
                cudnn.getHandle(),
                self.rnnDesc[0],
diff --git a/test/test_rnn.lua b/test/test_rnn.lua
index b761507..b50eefc 100644
--- a/test/test_rnn.lua
+++ b/test/test_rnn.lua
@@ -314,6 +314,126 @@ function cudnntest.testPackPadSequences()
    mytester:assertTableEq(lengths, actualLengths)
 end
 
+-- clone the parameters of src into dest, assumes both RNNs were created with
+-- the same options (e.g. same input size, hidden size, layers, etc.)
+local function deepcopyRNN(dest, src)
+   dest.weight = src.weight:clone() -- encompasses W_hh, W_xh etc.
+   dest.gradWeight = src.gradWeight:clone()
+end
+
+function rnntest.testVariableLengthSequences()
+   local input = torch.CudaTensor({
+      {{1, 2, 2, 1},
+       {2, 1, 2, 2},
+       {1, 1, 1, 2},
+       {2, 2, 2, 1}},
+      {{4, 1, 3, 1},
+       {3, 1, 2, 1},
+       {1, 1, 2, 1},
+       {0, 0, 0, 0}},
+      {{1, 1, 2, 1},
+       {2, 1, 2, 2},
+       {1, 2, 2, 1},
+       {0, 0, 0, 0}},
+      {{1, 2, 1, 1},
+       {0, 0, 0, 0},
+       {0, 0, 0, 0},
+       {0, 0, 0, 0}}
+   })
+
+   -- same as above
+   local indivInputs = {
+      torch.CudaTensor({
+         {{1, 2, 2, 1}},
+         {{4, 1, 3, 1}},
+         {{1, 1, 2, 1}},
+         {{1, 2, 1, 1}},
+      }),
+      torch.CudaTensor({
+         {{2, 1, 2, 2}},
+         {{3, 1, 2, 1}},
+         {{2, 1, 2, 2}},
+      }),
+      torch.CudaTensor({
+         {{1, 1, 1, 2}},
+         {{1, 1, 2, 1}},
+         {{1, 2, 2, 1}},
+      }),
+      torch.CudaTensor({
+         {{2, 2, 2, 1}},
+      }),
+   }
+
+   local lengths = {4, 3, 3, 1}
+   local maxLength = 4
+
+   local inputSize = 4
+   local hiddenSize = 10
+   local numLayers = 1
+   local batchFirst = false
+   local dropout = false
+   local rememberStates = false
+
+   local lstm = cudnn.LSTM(
+      inputSize,
+      hiddenSize,
+      numLayers,
+      batchFirst,
+      dropout,
+      rememberStates)
+
+   local lstm2 = cudnn.LSTM(
+      inputSize,
+      hiddenSize,
+      numLayers,
+      batchFirst,
+      dropout,
+      rememberStates)
+
+   deepcopyRNN(lstm2, lstm)
+
+   -- Step 1: Pass Sequences as batch and individually, verify weights, outputs
+   -- are the same in both instances
+
+   -- batched
+   local packed = cudnn.RNN:packPaddedSequence(input, lengths)
+   local packedOutput = lstm:updateOutput(packed)
+
+   local separate = {}
+
+   for i, length in ipairs(lengths) do
+      local inp = indivInputs[i]
+      local output = lstm2:updateOutput(inp):clone()
+      table.insert(separate, output)
+   end
+   separate = torch.cat(separate, 1):squeeze()
+
+   mytester:asserteq(packedOutput:size(1), separate:size(1))
+   mytester:asserteq(packedOutput:size(2), separate:size(2))
+
+   -- packedOutput has format where all 4 from first batch, then all 3 from
+   -- second batch, etc. while separate has all 4 from first sequence,
+   -- all 3 from next sequence, etc.
+   -- I manually map the matches here
+   local corresponding = {
+      {1, 1},
+      {2, 5},
+      {3, 8},
+      {4, 11},
+      {5, 2},
+      {6, 6},
+      {7, 9},
+      {8, 3},
+      {9, 7},
+      {10, 10},
+      {11, 4}
+   }
+   for _, pair in ipairs(corresponding) do
+      local sep, batched = unpack(pair)
+      local diff = torch.csub(separate[sep], packedOutput[batched]):abs():sum()
+      mytester:assert(diff < 1e-7)
+   end
+end
+
 mytester = torch.Tester()
-mytester:add(cudnntest)
+mytester:add(rnntest)
 mytester:run()
-- 
cgit v1.2.3