| author    | Nicholas Leonard <nick@nikopia.org> | 2014-11-29 02:01:13 +0300 |
|-----------|-------------------------------------|---------------------------|
| committer | Nicholas Leonard <nick@nikopia.org> | 2014-11-29 02:01:13 +0300 |
| commit    | 7d87aa6b5045dc135b9d2ee411253fa9cb97793e (patch) | |
| tree      | 37a23c41bfa05f25aca6d65bcf119fb168e3566a /test | |
| parent    | 3158ff849cdd835b00da46905b087930bf89b9b3 (diff) | |
added arg rho to BPTT
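Context for the diff below: the commit adds a `rho` argument to `Recurrent:backwardThroughTime()`, bounding how many time steps BPTT propagates through. The following sketch is not part of the commit; it illustrates how the new argument appears to be used, inferred from the test changes. The constructor signature and the per-step forward/backward buffering follow nnx's `Recurrent` conventions as assumptions, and the names (`r`, `hiddenSize`, `nSteps`) are hypothetical.

```lua
-- Minimal, hypothetical sketch of the new rho argument to
-- backwardThroughTime(), inferred from the test diff below.
require 'nnx'

local hiddenSize, nSteps = 10, 5

-- assumed nnx constructor: start (size or module), input, feedback, transfer
local r = nn.Recurrent(
   hiddenSize,
   nn.Linear(hiddenSize, hiddenSize), -- input module
   nn.Linear(hiddenSize, hiddenSize), -- feedback (recurrent) module
   nn.Sigmoid()                       -- transfer module
)

-- forward each step; backward() only buffers gradOutputs until BPTT runs
for step = 1, nSteps do
   local input = torch.randn(hiddenSize)
   r:forward(input)
   r:backward(input, torch.randn(hiddenSize))
end

-- new in this commit: rho bounds how far back through time to propagate
local gradInput = r:backwardThroughTime(nSteps) -- rho = nSteps: full BPTT
-- r:backwardThroughTime(nSteps - 1)            -- rho = nSteps-1: truncated;
                                                -- should not update startModule
r:updateParameters(0.1)
```

Per the test, calling `backwardThroughTime(nSteps-1)` truncates BPTT by one step and should leave the start module's parameters untouched.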
Diffstat (limited to 'test')
| -rw-r--r-- | test/test-all.lua | 38 |

1 file changed, 35 insertions, 3 deletions
```diff
diff --git a/test/test-all.lua b/test/test-all.lua
index b0a6585..a3999b8 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -380,14 +380,17 @@ function nnxtest.Recurrent()
    end
    local mlp4 = mlp:clone()
    local mlp5 = mlp:clone()
+   local mlp7 = mlp:clone() -- rho = nSteps - 1
 
    -- backward propagate through time (BPTT)
-   local gradInput = mlp:backwardThroughTime()
+   local gradInput = mlp:backwardThroughTime(nSteps) -- rho = nSteps
    mlp4.fastBackward = false
-   local gradInput4 = mlp4:backwardThroughTime()
+   local gradInput4 = mlp4:backwardThroughTime(nSteps)
    mytester:assertTensorEq(gradInput, gradInput4, 0.000001, 'error slow vs fast backwardThroughTime')
+   -- rho = nSteps - 1
+   mlp7:backwardThroughTime(nSteps-1) -- rho shouldn't update startModule
 
-   local mlp2
+   local mlp2 -- this one will simulate rho = nSteps
    local outputModules = {}
    for step=1,nSteps do
       local inputModule_ = inputModule:clone()
@@ -417,8 +420,22 @@ function nnxtest.Recurrent()
       end
    end
 
+   local output2 = mlp2:forward(inputs)
    mlp2:zeroGradParameters()
+
+   -- unlike mlp2, mlp8 will simulate rho = nSteps-1
+   local mlp8 = mlp2:clone()
+   local inputModule8 = mlp8.modules[1].modules[1]
+   local m = mlp8.modules[1].modules[2].modules[1].modules[1].modules[2]
+   m = m.modules[1].modules[1].modules[2].modules[1].modules[1].modules[2]
+   local feedbackModule8 = m.modules[2]
+   local startModule8 = m.modules[1].modules[2] -- before clone
+   -- unshare the initialModule:
+   m.modules[1] = m.modules[1]:clone()
+   m.modules[2] = m.modules[2]:clone()
+   mlp8:backward(inputs, gradOutputs[#gradOutputs])
+
    local gradInput2 = mlp2:backward(inputs, gradOutputs[#gradOutputs])
    for step=1,nSteps-1 do
       gradInput2 = gradInput2[2]
@@ -432,6 +449,7 @@ function nnxtest.Recurrent()
    end
 
    local mlp3 = nn.Sequential()
+   -- contains params and grads of mlp2 (the MLP version of the Recurrent)
    mlp3:add(startModule):add(inputModule):add(feedbackModule)
    local params2, gradParams2 = mlp3:parameters()
    local params, gradParams = mlp:parameters()
@@ -444,6 +462,20 @@ function nnxtest.Recurrent()
       mytester:assertTensorEq(gradParams[i], gradParams2[i], 0.000001, 'gradParameter error ' .. i)
    end
 
+   local mlp9 = nn.Sequential()
+   -- contains params and grads of mlp8
+   mlp9:add(startModule8):add(inputModule8):add(feedbackModule8)
+   local params9, gradParams9 = mlp9:parameters()
+   local params7, gradParams7 = mlp7:parameters()
+   mytester:assert(#params9 == #params7, 'missing parameters')
+   mytester:assert(#gradParams7 == #params7, 'missing gradParameters')
+   for i=1,#params do
+      if i > 1 then
+         gradParams9[i]:div(nSteps-1)
+      end
+      mytester:assertTensorEq(gradParams7[i], gradParams9[i], 0.00001, 'gradParameter error ' .. i)
+   end
+
    -- already called backwardThroughTime()
    mlp:updateParameters(0.1)
    mlp4:updateParameters(0.1)
```
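What the new assertions check, in brief: `mlp7` (a clone of the `Recurrent` under test) runs `backwardThroughTime(nSteps-1)`, while `mlp8` clones the hand-unrolled `mlp2` and unshares its initial module so the truncated start module can be inspected on its own. `mlp9` then collects `startModule8`, `inputModule8` and `feedbackModule8`, and their gradients are compared parameter-by-parameter against `mlp7`'s. Because the unrolled network accumulates each shared parameter's gradient once per step, the test divides those gradients by `nSteps-1` before comparing; the first parameter is exempt from the division, presumably because the start module contributes to only a single step. The trailing comment captures the intent: with `rho` smaller than the number of forwarded steps, `startModule` should receive no update.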