
github.com/clementfarabet/lua---nnx.git
path: root/test
author     nicholas-leonard <nick@nikopia.org>  2014-11-03 18:54:54 +0300
committer  nicholas-leonard <nick@nikopia.org>  2014-11-03 18:54:54 +0300
commit     5ddd5c521bfbdd4d212c6b836bab74959214ed2b (patch)
tree       5e23d10b0b60465d466984cb5f6ead3d1b6ab423 /test
parent     c0a5199d13436df8a15f9723c18557c52dcb86bf (diff)
nn.Recurrent is unit tested
Diffstat (limited to 'test')
-rw-r--r--  test/test-all.lua | 29
1 file changed, 18 insertions(+), 11 deletions(-)
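
For orientation before the diff: the test builds an nn.Recurrent from an input module, a feedback module, and a transfer module, drives it for nSteps, then backpropagates through time. The sketch below is not the test itself; it uses only the constructor and calls visible in the hunks plus standard nn calls. The first feedback layer and the per-step mlp:backward call are assumptions.

-- minimal sketch, assuming the nnx API shown in the diff below:
-- nn.Recurrent(outputSize, inputModule, feedbackModule, transferModule)
-- driven step by step, with a final backwardThroughTime() doing BPTT.
require 'nnx'

local inputSize, hiddenSize, outputSize, nSteps = 10, 12, 7, 5
local inputModule = nn.Linear(inputSize, outputSize)
local transferModule = nn.Sigmoid()
local feedbackModule = nn.Sequential()
feedbackModule:add(nn.Linear(outputSize, hiddenSize)) -- first layer assumed; the diff only shows the later adds
feedbackModule:add(nn.Sigmoid())
feedbackModule:add(nn.Linear(hiddenSize, outputSize))

local mlp = nn.Recurrent(outputSize, inputModule, feedbackModule, transferModule:clone())
mlp:zeroGradParameters()

local outputs, gradOutputs = {}, {}
for step = 1, nSteps do
   local input = torch.randn(inputSize)
   local gradOutput = torch.randn(outputSize)
   outputs[step] = mlp:forward(input):clone()
   mlp:backward(input, gradOutput)   -- per-step call is an assumption, not shown in the hunks
   gradOutputs[step] = gradOutput
end

-- one call backpropagates through all accumulated steps (BPTT);
-- this commit makes the test keep its returned gradInput for comparison
local gradInput = mlp:backwardThroughTime()
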
diff --git a/test/test-all.lua b/test/test-all.lua
index 0d666e0..10f9166 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -307,7 +307,7 @@ function nnxtest.Recurrent()
local inputSize = 10
local hiddenSize = 12
local outputSize = 7
- local nSteps = 5
+ local nSteps = 5
local inputModule = nn.Linear(inputSize, outputSize)
local transferModule = nn.Sigmoid()
-- test MLP feedback Module (because of Module:representations())
@@ -316,12 +316,13 @@ function nnxtest.Recurrent()
feedbackModule:add(nn.Sigmoid())
feedbackModule:add(nn.Linear(hiddenSize, outputSize))
local mlp = nn.Recurrent(outputSize, inputModule, feedbackModule, transferModule:clone())
- inputModule = mlp.inputModule:clone()
- feedbackModule = mlp.feedbackModule:clone()
local gradOutputs, outputs = {}, {}
+ -- inputs = {inputN, {inputN-1, {inputN-2, ...}}}}}
local inputs
local startModule = mlp.startModule:clone()
+ inputModule = mlp.inputModule:clone()
+ feedbackModule = mlp.feedbackModule:clone()
mlp:zeroGradParameters()
for step=1,nSteps do
@@ -348,15 +349,14 @@ function nnxtest.Recurrent()
end
end
-- backward propagate through time (BPTT)
- mlp:backwardThroughTime()
+ local gradInput = mlp:backwardThroughTime()
- -- input = {inputN, {inputN-1, {inputN-2, ...}}}}}
local mlp2
- local outputs2 = {}
+ local outputModules = {}
for step=1,nSteps do
local inputModule_ = inputModule:clone()
local outputModule = transferModule:clone()
- table.insert(outputs2, outputModule.output)
+ table.insert(outputModules, outputModule)
inputModule_:share(inputModule, 'weight', 'gradWeight', 'bias', 'gradBias')
if step == 1 then
local initialModule = nn.Sequential()
@@ -384,21 +384,28 @@ function nnxtest.Recurrent()
local output2 = mlp2:forward(inputs)
mlp2:zeroGradParameters()
local gradInput2 = mlp2:backward(inputs, gradOutputs[#gradOutputs])
+ for step=1,nSteps-1 do
+ gradInput2 = gradInput2[2]
+ end
+ mytester:assertTensorEq(gradInput, gradInput2, 0.000001, "recurrent gradInput")
mytester:assertTensorEq(outputs[#outputs], output2, 0.000001, "recurrent output")
for step=1,nSteps do
- local output, output2 = outputs[step], outputs2[step]
- mytester:assertTensorEq(output, output2, 0.000001, "recurrent output step="..step)
+ local output, outputModule = outputs[step], outputModules[step]
+ mytester:assertTensorEq(output, outputModule.output, 0.000001, "recurrent output step="..step)
end
local mlp3 = nn.Sequential()
- mlp3:add(startModule):add(inputModule):add(recurrentModule)
+ mlp3:add(startModule):add(inputModule):add(feedbackModule)
local params2, gradParams2 = mlp3:parameters()
local params, gradParams = mlp:parameters()
mytester:assert(#params2 == #params, 'missing parameters')
mytester:assert(#gradParams == #params, 'missing gradParameters')
for i=1,#params do
- mytester:assertTensorEq(gradParams[i], gradParams2[i], 0.000001, 'gradParameter error ' .. 1)
+ if i > 1 then
+ gradParams2[i]:div(nSteps)
+ end
+ mytester:assertTensorEq(gradParams[i], gradParams2[i], 0.000001, 'gradParameter error ' .. i)
end
end
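
The reference network the test compares against unrolls the recurrence into one nn.Sequential per step and shares parameters between the per-step clones, so all steps read and write a single set of weights and gradients. The snippet below is a standalone illustration of that clone-and-share pattern with sizes borrowed from the test; it is not the test code.

-- illustration only: a clone whose parameters point at the same storage as
-- the original, so gradients computed through the clone accumulate in the
-- original module's gradWeight/gradBias.
require 'nn'

local inputModule = nn.Linear(10, 7)
local inputModule_ = inputModule:clone()
inputModule_:share(inputModule, 'weight', 'gradWeight', 'bias', 'gradBias')

inputModule:zeroGradParameters()
local x = torch.randn(10)
inputModule_:forward(x)
inputModule_:backward(x, torch.randn(7))

-- the original module sees the gradient computed through its shared clone
print(inputModule.gradWeight:norm() > 0)  -- true

Sharing 'gradWeight' and 'gradBias' along with the weights is what lets the test read accumulated gradients off mlp3 (startModule, inputModule, feedbackModule) and compare them against the gradients reported by mlp:parameters() on the recurrent module.
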