github.com/clementfarabet/lua---nnx.git
path: root/test
author    Nicholas Leonard <nick@nikopia.org>   2014-11-29 02:01:13 +0300
committer Nicholas Leonard <nick@nikopia.org>   2014-11-29 02:01:13 +0300
commit    7d87aa6b5045dc135b9d2ee411253fa9cb97793e (patch)
tree      37a23c41bfa05f25aca6d65bcf119fb168e3566a /test
parent    3158ff849cdd835b00da46905b087930bf89b9b3 (diff)
added arg rho to BPTT
Diffstat (limited to 'test')
-rw-r--r--   test/test-all.lua   38
1 file changed, 35 insertions(+), 3 deletions(-)
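
For context, this commit adds an optional rho argument to backwardThroughTime, which caps how many time steps BPTT unrolls. Below is a minimal usage sketch in Lua; the nn.Recurrent constructor arguments and module names are assumptions based on this test file, not taken verbatim from the commit.

    -- minimal sketch, assuming nn.Recurrent(start, input, feedback, transfer)
    -- builds the recurrence exercised by this test; names are illustrative only
    require 'nnx'

    local inputSize, outputSize, nSteps = 4, 4, 3
    local inputModule    = nn.Linear(inputSize, outputSize)
    local feedbackModule = nn.Linear(outputSize, outputSize)
    local startModule    = nn.Add(outputSize)
    local transferModule = nn.Sigmoid()
    local mlp = nn.Recurrent(startModule, inputModule, feedbackModule, transferModule)

    for step = 1, nSteps do
       local input = torch.randn(inputSize)
       mlp:forward(input)
       -- backward() only stores the gradOutput for each step; the actual
       -- backpropagation happens in backwardThroughTime()
       mlp:backward(input, torch.randn(outputSize))
    end

    -- new in this commit: rho caps how many steps BPTT unrolls
    local gradInput = mlp:backwardThroughTime(nSteps)  -- full BPTT, as before
    -- mlp:backwardThroughTime(nSteps - 1)             -- truncated: startModule not updated
    mlp:updateParameters(0.1)
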
diff --git a/test/test-all.lua b/test/test-all.lua
index b0a6585..a3999b8 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -380,14 +380,17 @@ function nnxtest.Recurrent()
end
local mlp4 = mlp:clone()
local mlp5 = mlp:clone()
+ local mlp7 = mlp:clone() -- rho = nSteps - 1
-- backward propagate through time (BPTT)
- local gradInput = mlp:backwardThroughTime()
+ local gradInput = mlp:backwardThroughTime(nSteps) -- rho = nSteps
mlp4.fastBackward = false
- local gradInput4 = mlp4:backwardThroughTime()
+ local gradInput4 = mlp4:backwardThroughTime(nSteps)
mytester:assertTensorEq(gradInput, gradInput4, 0.000001, 'error slow vs fast backwardThroughTime')
+ -- rho = nSteps - 1
+ mlp7:backwardThroughTime(nSteps-1) -- with rho < nSteps, the startModule shouldn't be updated
- local mlp2
+ local mlp2 -- this one will simulate rho = nSteps
local outputModules = {}
for step=1,nSteps do
local inputModule_ = inputModule:clone()
@@ -417,8 +420,22 @@ function nnxtest.Recurrent()
end
end
+
local output2 = mlp2:forward(inputs)
mlp2:zeroGradParameters()
+
+ -- unlike mlp2, mlp8 will simulate rho = nSteps - 1
+ local mlp8 = mlp2:clone()
+ local inputModule8 = mlp8.modules[1].modules[1]
+ local m = mlp8.modules[1].modules[2].modules[1].modules[1].modules[2]
+ m = m.modules[1].modules[1].modules[2].modules[1].modules[1].modules[2]
+ local feedbackModule8 = m.modules[2]
+ local startModule8 = m.modules[1].modules[2] -- before clone
+ -- unshare the initialModule:
+ m.modules[1] = m.modules[1]:clone()
+ m.modules[2] = m.modules[2]:clone()
+ mlp8:backward(inputs, gradOutputs[#gradOutputs])
+
local gradInput2 = mlp2:backward(inputs, gradOutputs[#gradOutputs])
for step=1,nSteps-1 do
gradInput2 = gradInput2[2]
@@ -432,6 +449,7 @@ function nnxtest.Recurrent()
end
local mlp3 = nn.Sequential()
+ -- contains params and grads of mlp2 (the MLP version of the Recurrent)
mlp3:add(startModule):add(inputModule):add(feedbackModule)
local params2, gradParams2 = mlp3:parameters()
local params, gradParams = mlp:parameters()
@@ -444,6 +462,20 @@ function nnxtest.Recurrent()
mytester:assertTensorEq(gradParams[i], gradParams2[i], 0.000001, 'gradParameter error ' .. i)
end
+ local mlp9 = nn.Sequential()
+ -- contains params and grads of mlp8
+ mlp9:add(startModule8):add(inputModule8):add(feedbackModule8)
+ local params9, gradParams9 = mlp9:parameters()
+ local params7, gradParams7 = mlp7:parameters()
+ mytester:assert(#params9 == #params7, 'missing parameters')
+ mytester:assert(#gradParams7 == #params7, 'missing gradParameters')
+ for i=1,#params do
+ if i > 1 then
+ gradParams9[i]:div(nSteps-1)
+ end
+ mytester:assertTensorEq(gradParams7[i], gradParams9[i], 0.00001, 'gradParameter error ' .. i)
+ end
+
-- already called backwardThroughTime()
mlp:updateParameters(0.1)
mlp4:updateParameters(0.1)