
Repository: github.com/clementfarabet/lua---nnx.git
author    nicholas-leonard <nick@nikopia.org>  2015-03-16 00:09:23 +0300
committer nicholas-leonard <nick@nikopia.org>  2015-03-16 00:09:23 +0300
commit    a54ac28e4261ee62645e4c80ea4db21b02923b6f (patch)
tree      1ff396c9e99c94f5926854eb6ce2a4a533cbf1cd /test
parent    bbf319f81ea9808b7ab051dce120383e8c56e7e0 (diff)
Recurrent and LSTM fixes
Diffstat (limited to 'test')
-rw-r--r--  test/test-all.lua | 70
1 file changed, 70 insertions(+), 0 deletions(-)
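For context, a minimal sketch of how the added test can be run once this patch is applied (this assumes nnx exposes the conventional torch.Tester entry point that runs the functions registered in nnxtest; the exact name of that entry point is an assumption here):

  require 'nnx'
  -- run only the new LSTM unit test; omit the argument to run the whole suite
  nnx.test{'LSTM'}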
diff --git a/test/test-all.lua b/test/test-all.lua
index 5abd4e2..fd72b6e 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -564,6 +564,76 @@ function nnxtest.Recurrent_TestTable()
mlp:backwardThroughTime(learningRate)
end
+function nnxtest.LSTM()
+   local batchSize = math.random(1,2)
+   local inputSize = math.random(3,4)
+   local outputSize = math.random(5,6)
+   local nStep = 3
+   local input = {}
+   local gradOutput = {}
+   for step=1,nStep do
+      input[step] = torch.randn(batchSize, inputSize)
+      -- for the sake of keeping this unit test simple,
+      -- only the last step will get a gradient from the output
+      if step == nStep then
+         gradOutput[step] = torch.randn(batchSize, outputSize)
+      else
+         gradOutput[step] = torch.zeros(batchSize, outputSize)
+      end
+   end
+   local lstm = nn.LSTM(inputSize, outputSize)
+
+   -- we will use this to build an LSTM step by step (with shared params)
+   local lstmStep = lstm.recurrentModule:clone()
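+   -- (recurrentModule is the module the LSTM applies at each time step,
+   -- so a clone of it can be chained manually into an unrolled network below)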
+
+   -- forward/backward through LSTM
+   local output = {}
+   lstm:zeroGradParameters()
+   for step=1,nStep do
+      output[step] = lstm:forward(input[step])
+      assert(torch.isTensor(input[step]))
+      lstm:backward(input[step], gradOutput[step], 1)
+   end
+   local gradInput = lstm:backwardThroughTime()
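+   -- backwardThroughTime back-propagates through all buffered steps; the
+   -- returned gradInput is compared below against the unrolled network's
+   -- gradient w.r.t. the first step's input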
+
+   local mlp2 -- this one will simulate rho = nSteps
+   local inputs
+   for step=1,nStep do
+      -- iteratively build an LSTM out of non-recurrent components
+      local lstm = lstmStep:clone()
+      lstm:share(lstmStep, 'weight', 'gradWeight', 'bias', 'gradBias')
+      if step == 1 then
+         mlp2 = lstm
+      else
+         local rnn = nn.Sequential()
+         local para = nn.ParallelTable()
+         para:add(nn.Identity()):add(mlp2)
+         rnn:add(para)
+         rnn:add(nn.FlattenTable())
+         rnn:add(lstm)
+         mlp2 = rnn
+      end
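+      -- at this point mlp2 maps the nested input table for steps 1..step to
+      -- the LSTM state {output, cell}: nn.Identity passes the current input
+      -- through, the previous mlp2 supplies the previous state, and
+      -- FlattenTable merges them into the {input, prevOutput, prevCell}
+      -- table consumed by lstmStep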
+
+      -- prepare inputs for mlp2
+      if inputs then
+         inputs = {input[step], inputs}
+      else
+         inputs = {input[step], torch.zeros(batchSize, outputSize), torch.zeros(batchSize, outputSize)}
+      end
+   end
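+   -- inputs is now nested as {input[nStep], {input[nStep-1], ... {input[1],
+   -- zero output, zero cell}}}, matching the structure mlp2 was built around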
+   mlp2:add(nn.SelectTable(1)) -- just output the output (not the cell)
+
+   local output2 = mlp2:forward(inputs)
+
+   mlp2:zeroGradParameters()
+   local gradInput2 = mlp2:backward(inputs, gradOutput[nStep], 1/nStep)
+   mytester:assertTensorEq(gradInput2[2][2][1], gradInput, 0.00001, "LSTM gradInput error")
+   mytester:assertTensorEq(output[nStep], output2, 0.00001, "LSTM output error")
+end
+
function nnxtest.SpatialNormalization_Gaussian2D()
local inputSize = math.random(11,20)
local kersize = 9