diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-11-03 18:54:54 +0300 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-11-03 18:54:54 +0300 |
commit | 5ddd5c521bfbdd4d212c6b836bab74959214ed2b (patch) | |
tree | 5e23d10b0b60465d466984cb5f6ead3d1b6ab423 /test | |
parent | c0a5199d13436df8a15f9723c18557c52dcb86bf (diff) |
nn.Recurrent is unit tested
Diffstat (limited to 'test')
-rw-r--r-- | test/test-all.lua | 29 |
1 files changed, 18 insertions, 11 deletions
diff --git a/test/test-all.lua b/test/test-all.lua index 0d666e0..10f9166 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -307,7 +307,7 @@ function nnxtest.Recurrent() local inputSize = 10 local hiddenSize = 12 local outputSize = 7 - local nSteps = 5 + local nSteps = 5 local inputModule = nn.Linear(inputSize, outputSize) local transferModule = nn.Sigmoid() -- test MLP feedback Module (because of Module:representations()) @@ -316,12 +316,13 @@ function nnxtest.Recurrent() feedbackModule:add(nn.Sigmoid()) feedbackModule:add(nn.Linear(hiddenSize, outputSize)) local mlp = nn.Recurrent(outputSize, inputModule, feedbackModule, transferModule:clone()) - inputModule = mlp.inputModule:clone() - feedbackModule = mlp.feedbackModule:clone() local gradOutputs, outputs = {}, {} + -- inputs = {inputN, {inputN-1, {inputN-2, ...}}}}} local inputs local startModule = mlp.startModule:clone() + inputModule = mlp.inputModule:clone() + feedbackModule = mlp.feedbackModule:clone() mlp:zeroGradParameters() for step=1,nSteps do @@ -348,15 +349,14 @@ function nnxtest.Recurrent() end end -- backward propagate through time (BPTT) - mlp:backwardThroughTime() + local gradInput = mlp:backwardThroughTime() - -- input = {inputN, {inputN-1, {inputN-2, ...}}}}} local mlp2 - local outputs2 = {} + local outputModules = {} for step=1,nSteps do local inputModule_ = inputModule:clone() local outputModule = transferModule:clone() - table.insert(outputs2, outputModule.output) + table.insert(outputModules, outputModule) inputModule_:share(inputModule, 'weight', 'gradWeight', 'bias', 'gradBias') if step == 1 then local initialModule = nn.Sequential() @@ -384,21 +384,28 @@ function nnxtest.Recurrent() local output2 = mlp2:forward(inputs) mlp2:zeroGradParameters() local gradInput2 = mlp2:backward(inputs, gradOutputs[#gradOutputs]) + for step=1,nSteps-1 do + gradInput2 = gradInput2[2] + end + mytester:assertTensorEq(gradInput, gradInput2, 0.000001, "recurrent gradInput") mytester:assertTensorEq(outputs[#outputs], output2, 0.000001, "recurrent output") for step=1,nSteps do - local output, output2 = outputs[step], outputs2[step] - mytester:assertTensorEq(output, output2, 0.000001, "recurrent output step="..step) + local output, outputModule = outputs[step], outputModules[step] + mytester:assertTensorEq(output, outputModule.output, 0.000001, "recurrent output step="..step) end local mlp3 = nn.Sequential() - mlp3:add(startModule):add(inputModule):add(recurrentModule) + mlp3:add(startModule):add(inputModule):add(feedbackModule) local params2, gradParams2 = mlp3:parameters() local params, gradParams = mlp:parameters() mytester:assert(#params2 == #params, 'missing parameters') mytester:assert(#gradParams == #params, 'missing gradParameters') for i=1,#params do - mytester:assertTensorEq(gradParams[i], gradParams2[i], 0.000001, 'gradParameter error ' .. 1) + if i > 1 then + gradParams2[i]:div(nSteps) + end + mytester:assertTensorEq(gradParams[i], gradParams2[i], 0.000001, 'gradParameter error ' .. i) end end |