diff options
author | Nicholas Leonard <nick@nikopia.org> | 2014-11-29 04:19:11 +0300 |
---|---|---|
committer | Nicholas Leonard <nick@nikopia.org> | 2014-11-29 04:19:11 +0300 |
commit | 15a4018e4277bee3ff1ba019d83c0214e0407c40 (patch) | |
tree | 2865638a0ce9204347de19b6525486fadcec8114 /test | |
parent | 7d87aa6b5045dc135b9d2ee411253fa9cb97793e (diff) |
Recurrent constructor takes rho argument
Diffstat (limited to 'test')
-rw-r--r-- | test/test-all.lua | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/test/test-all.lua b/test/test-all.lua index a3999b8..ae2d17a 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -331,7 +331,7 @@ function nnxtest.Recurrent() local inputSize = 10 local hiddenSize = 12 local outputSize = 7 - local nSteps = 5 + local nSteps = 5 local inputModule = nn.Linear(inputSize, outputSize) local transferModule = nn.Sigmoid() -- test MLP feedback Module (because of Module:representations()) @@ -339,7 +339,8 @@ function nnxtest.Recurrent() feedbackModule:add(nn.Linear(outputSize, hiddenSize)) feedbackModule:add(nn.Sigmoid()) feedbackModule:add(nn.Linear(hiddenSize, outputSize)) - local mlp = nn.Recurrent(outputSize, inputModule, feedbackModule, transferModule:clone()) + -- rho = nSteps + local mlp = nn.Recurrent(outputSize, inputModule, feedbackModule, transferModule:clone(), nSteps) local gradOutputs, outputs = {}, {} -- inputs = {inputN, {inputN-1, {inputN-2, ...}}}}} @@ -352,6 +353,8 @@ function nnxtest.Recurrent() mlp6:evaluate() mlp:zeroGradParameters() + local mlp7 = mlp:clone() + mlp7.rho = nSteps - 1 for step=1,nSteps do local input = torch.randn(batchSize, inputSize) local gradOutput @@ -368,6 +371,10 @@ function nnxtest.Recurrent() local output6 = mlp6:forward(input) mytester:assertTensorEq(output, output6, 0.000001, "evaluation error "..step) + + local output7 = mlp7:forward(input) + mlp7:backward(input, gradOutput) + mytester:assertTensorEq(output, output7, 0.000001, "rho = nSteps-1 forward error "..step) table.insert(gradOutputs, gradOutput) table.insert(outputs, output:clone()) @@ -380,15 +387,14 @@ function nnxtest.Recurrent() end local mlp4 = mlp:clone() local mlp5 = mlp:clone() - local mlp7 = mlp:clone() -- rho = nSteps - 1 -- backward propagate through time (BPTT) - local gradInput = mlp:backwardThroughTime(nSteps) -- rho = nSteps + local gradInput = mlp:backwardThroughTime() mlp4.fastBackward = false - local gradInput4 = mlp4:backwardThroughTime(nSteps) + local gradInput4 = mlp4:backwardThroughTime() mytester:assertTensorEq(gradInput, gradInput4, 0.000001, 'error slow vs fast backwardThroughTime') - -- rho = nSteps - 1 - mlp7:backwardThroughTime(nSteps-1) -- rho shouldn't update startModule + -- rho = nSteps - 1 : shouldn't update startModule + mlp7:backwardThroughTime() local mlp2 -- this one will simulate rho = nSteps local outputModules = {} |