author    | nicholas-leonard <nick@nikopia.org> | 2015-03-16 00:09:23 +0300
committer | nicholas-leonard <nick@nikopia.org> | 2015-03-16 00:09:23 +0300
commit    | a54ac28e4261ee62645e4c80ea4db21b02923b6f (patch)
tree      | 1ff396c9e99c94f5926854eb6ce2a4a533cbf1cd /test
parent    | bbf319f81ea9808b7ab051dce120383e8c56e7e0 (diff)
Recurrent and LSTM fixes
Diffstat (limited to 'test')
-rw-r--r-- | test/test-all.lua | 70
1 file changed, 70 insertions, 0 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index 5abd4e2..fd72b6e 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -564,6 +564,76 @@ function nnxtest.Recurrent_TestTable()
    mlp:backwardThroughTime(learningRate)
 end
 
+function nnxtest.LSTM()
+   local batchSize = math.random(1,2)
+   local inputSize = math.random(3,4)
+   local outputSize = math.random(5,6)
+   local nStep = 3
+   local input = {}
+   local gradOutput = {}
+   for step=1,nStep do
+      input[step] = torch.randn(batchSize, inputSize)
+      -- for the sake of keeping this unit test simple,
+      -- only the last step will get a gradient from the output
+      if step == nStep then
+         gradOutput[step] = torch.randn(batchSize, outputSize)
+      else
+         gradOutput[step] = torch.zeros(batchSize, outputSize)
+      end
+   end
+   local lstm = nn.LSTM(inputSize, outputSize)
+
+   -- we will use this to build an LSTM step by step (with shared params)
+   local lstmStep = lstm.recurrentModule:clone()
+
+   -- forward/backward through the recurrent LSTM
+   local output = {}
+   lstm:zeroGradParameters()
+   for step=1,nStep do
+      output[step] = lstm:forward(input[step])
+      assert(torch.isTensor(input[step]))
+      lstm:backward(input[step], gradOutput[step], 1)
+   end
+   local gradInput = lstm:backwardThroughTime()
+
+   -- mlp2 will simulate rho = nStep by unrolling the LSTM
+   local mlp2
+   local inputs
+   for step=1,nStep do
+      -- iteratively build an LSTM out of non-recurrent components
+      local stepClone = lstmStep:clone()
+      stepClone:share(lstmStep, 'weight', 'gradWeight', 'bias', 'gradBias')
+      if step == 1 then
+         mlp2 = stepClone
+      else
+         local rnn = nn.Sequential()
+         local para = nn.ParallelTable()
+         para:add(nn.Identity()):add(mlp2)
+         rnn:add(para)
+         rnn:add(nn.FlattenTable())
+         rnn:add(stepClone)
+         mlp2 = rnn
+      end
+
+      -- prepare inputs for mlp2 : {x[step], {x[step-1], ... {x[1], h0, c0}}}
+      if inputs then
+         inputs = {input[step], inputs}
+      else
+         inputs = {input[step], torch.zeros(batchSize, outputSize), torch.zeros(batchSize, outputSize)}
+      end
+   end
+   mlp2:add(nn.SelectTable(1)) -- just output the output (not the cell)
+
+   local output2 = mlp2:forward(inputs)
+
+   mlp2:zeroGradParameters()
+   local gradInput2 = mlp2:backward(inputs, gradOutput[nStep], 1/nStep)
+   -- inputs is nested as {x[3], {x[2], {x[1], h0, c0}}}, so
+   -- gradInput2[2][2][1] is the gradient w.r.t. the first input x[1]
+   mytester:assertTensorEq(gradInput2[2][2][1], gradInput, 0.00001, "LSTM gradInput error")
+   mytester:assertTensorEq(output[nStep], output2, 0.00001, "LSTM output error")
+end
+
 function nnxtest.SpatialNormalization_Gaussian2D()
    local inputSize = math.random(11,20)
    local kersize = 9
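
The unrolled mlp2 only matches the recurrent nn.LSTM because every per-step clone shares its parameters and gradient buffers with a single prototype. Below is a minimal sketch of that clone-and-share idiom in isolation, using nn.Linear as a stand-in for the LSTM's recurrentModule; the names (proto, clones) are illustrative and not part of the commit:

-- clone-and-share: each clone owns its activations but aliases the
-- prototype's parameter and gradient storage
require 'nn'

local proto = nn.Linear(4, 3)
local clones = {}
for step = 1, 3 do
   clones[step] = proto:clone()
   clones[step]:share(proto, 'weight', 'gradWeight', 'bias', 'gradBias')
end

-- writing through the prototype is visible in every clone
proto.weight:fill(0.5)
assert(clones[2].weight[1][1] == 0.5)

Gradients behave the same way: backward passes through the clones all accumulate into the prototype's gradWeight/gradBias, summing over time steps, which is what backwardThroughTime relies on. The test itself would presumably be run through torch.Tester via test-all.lua's entry point (e.g. nnx.test('LSTM'), assuming such a wrapper is exposed).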