Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2015-03-17 07:39:36 +0300
committernicholas-leonard <nick@nikopia.org>2015-03-17 07:39:36 +0300
commit4a3f749d7bd19db2f313879c3b7013968eb76a33 (patch)
tree97da3d165cad4273b240883f4fe15785f8853146 /test
parent1c09f06acc222f50dedb24c0a8c953333b22207a (diff)
LSTM backwardThroughTime unit tested
Diffstat (limited to 'test')
-rw-r--r--test/test-all.lua10
1 files changed, 6 insertions, 4 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index df51302..a824737 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -347,6 +347,10 @@ function nnxtest.Recurrent()
mytester:assert(isRecursable(mlp.initialModule, torch.randn(inputSize)), "Recurrent isRecursable() initial error")
mytester:assert(isRecursable(mlp.recurrentModule, {torch.randn(inputSize), torch.randn(outputSize)}), "Recurrent isRecursable() recurrent error")
+ -- test that the above test actually works
+ local euclidean = nn.Euclidean(inputSize, outputSize)
+ mytester:assert(not isRecursable(euclidean, torch.randn(batchSize, inputSize)), "AbstractRecurrent.isRecursable error")
+
local gradOutputs, outputs = {}, {}
-- inputs = {inputN, {inputN-1, {inputN-2, ...}}}}}
local inputs
@@ -610,7 +614,6 @@ function nnxtest.LSTM()
for step=1,nStep do
-- iteratively build an LSTM out of non-recurrent components
local lstm = lstmStep:clone()
- lstm:share(lstmStep)
lstm:share(lstmStep, 'weight', 'gradWeight', 'bias', 'gradBias')
if step == 1 then
mlp2 = lstm
@@ -624,7 +627,6 @@ function nnxtest.LSTM()
mlp2 = rnn
end
-
-- prepare inputs for mlp2
if inputs then
inputs = {input[step], inputs}
@@ -633,7 +635,6 @@ function nnxtest.LSTM()
end
end
mlp2:add(nn.SelectTable(1)) --just output the output (not cell)
-
local output2 = mlp2:forward(inputs)
mlp2:zeroGradParameters()
@@ -646,7 +647,8 @@ function nnxtest.LSTM()
mytester:assert(#params == #params2, "LSTM parameters error "..#params.." ~= "..#params2)
for i, gradParam in ipairs(gradParams) do
local gradParam2 = gradParams2[i]
- mytester:assertTensorEq(gradParam, gradParam2, 0.000001, "LSTM gradParam "..i.." error "..tostring(gradParam).." "..tostring(gradParam2))
+ mytester:assertTensorEq(gradParam, gradParam2, 0.000001,
+ "LSTM gradParam "..i.." error "..tostring(gradParam).." "..tostring(gradParam2))
end
end