diff options
author | nicholas-leonard <nick@nikopia.org> | 2015-03-17 07:39:36 +0300 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2015-03-17 07:39:36 +0300 |
commit | 4a3f749d7bd19db2f313879c3b7013968eb76a33 (patch) | |
tree | 97da3d165cad4273b240883f4fe15785f8853146 /test | |
parent | 1c09f06acc222f50dedb24c0a8c953333b22207a (diff) |
LSTM backwardThroughTime unit tested
Diffstat (limited to 'test')
-rw-r--r-- | test/test-all.lua | 10 |
1 file changed, 6 insertions, 4 deletions
diff --git a/test/test-all.lua b/test/test-all.lua index df51302..a824737 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -347,6 +347,10 @@ function nnxtest.Recurrent() mytester:assert(isRecursable(mlp.initialModule, torch.randn(inputSize)), "Recurrent isRecursable() initial error") mytester:assert(isRecursable(mlp.recurrentModule, {torch.randn(inputSize), torch.randn(outputSize)}), "Recurrent isRecursable() recurrent error") + -- test that the above test actually works + local euclidean = nn.Euclidean(inputSize, outputSize) + mytester:assert(not isRecursable(euclidean, torch.randn(batchSize, inputSize)), "AbstractRecurrent.isRecursable error") + local gradOutputs, outputs = {}, {} -- inputs = {inputN, {inputN-1, {inputN-2, ...}}}}} local inputs @@ -610,7 +614,6 @@ function nnxtest.LSTM() for step=1,nStep do -- iteratively build an LSTM out of non-recurrent components local lstm = lstmStep:clone() - lstm:share(lstmStep) lstm:share(lstmStep, 'weight', 'gradWeight', 'bias', 'gradBias') if step == 1 then mlp2 = lstm @@ -624,7 +627,6 @@ function nnxtest.LSTM() mlp2 = rnn end - -- prepare inputs for mlp2 if inputs then inputs = {input[step], inputs} @@ -633,7 +635,6 @@ function nnxtest.LSTM() end end mlp2:add(nn.SelectTable(1)) --just output the output (not cell) - local output2 = mlp2:forward(inputs) mlp2:zeroGradParameters() @@ -646,7 +647,8 @@ function nnxtest.LSTM() mytester:assert(#params == #params2, "LSTM parameters error "..#params.." ~= "..#params2) for i, gradParam in ipairs(gradParams) do local gradParam2 = gradParams2[i] - mytester:assertTensorEq(gradParam, gradParam2, 0.000001, "LSTM gradParam "..i.." error "..tostring(gradParam).." "..tostring(gradParam2)) + mytester:assertTensorEq(gradParam, gradParam2, 0.000001, + "LSTM gradParam "..i.." error "..tostring(gradParam).." "..tostring(gradParam2)) end end |