
github.com/clementfarabet/lua---nnx.git
path: root/test
author     Nicholas Leonard <nick@nikopia.org>  2014-11-29 04:19:11 +0300
committer  Nicholas Leonard <nick@nikopia.org>  2014-11-29 04:19:11 +0300
commit     15a4018e4277bee3ff1ba019d83c0214e0407c40 (patch)
tree       2865638a0ce9204347de19b6525486fadcec8114 /test
parent     7d87aa6b5045dc135b9d2ee411253fa9cb97793e (diff)
Recurrent constructor takes rho argument
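In short: nn.Recurrent now takes rho, the maximum number of time-steps to back-propagate through, as a fifth constructor argument, so backwardThroughTime() is called with no explicit argument. A minimal sketch of the new signature, mirroring the modules built in the test below (the require line and variable names are assumptions for illustration, not part of this commit):

require 'nnx'

local inputSize, hiddenSize, outputSize = 10, 12, 7
local rho = 5  -- BPTT horizon, now fixed at construction time

local inputModule = nn.Linear(inputSize, outputSize)
local feedbackModule = nn.Sequential()
feedbackModule:add(nn.Linear(outputSize, hiddenSize))
feedbackModule:add(nn.Sigmoid())
feedbackModule:add(nn.Linear(hiddenSize, outputSize))

-- rho is the new fifth argument:
local rnn = nn.Recurrent(outputSize, inputModule, feedbackModule, nn.Sigmoid(), rho)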
Diffstat (limited to 'test')
-rw-r--r--  test/test-all.lua | 20 +++++++++++++-------
1 file changed, 13 insertions(+), 7 deletions(-)
diff --git a/test/test-all.lua b/test/test-all.lua
index a3999b8..ae2d17a 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -331,7 +331,7 @@ function nnxtest.Recurrent()
local inputSize = 10
local hiddenSize = 12
local outputSize = 7
- local nSteps = 5
+ local nSteps = 5
local inputModule = nn.Linear(inputSize, outputSize)
local transferModule = nn.Sigmoid()
-- test MLP feedback Module (because of Module:representations())
@@ -339,7 +339,8 @@ function nnxtest.Recurrent()
feedbackModule:add(nn.Linear(outputSize, hiddenSize))
feedbackModule:add(nn.Sigmoid())
feedbackModule:add(nn.Linear(hiddenSize, outputSize))
- local mlp = nn.Recurrent(outputSize, inputModule, feedbackModule, transferModule:clone())
+ -- rho = nSteps
+ local mlp = nn.Recurrent(outputSize, inputModule, feedbackModule, transferModule:clone(), nSteps)
local gradOutputs, outputs = {}, {}
-- inputs = {inputN, {inputN-1, {inputN-2, ...}}}}}
@@ -352,6 +353,8 @@ function nnxtest.Recurrent()
mlp6:evaluate()
mlp:zeroGradParameters()
+ local mlp7 = mlp:clone()
+ mlp7.rho = nSteps - 1
for step=1,nSteps do
local input = torch.randn(batchSize, inputSize)
local gradOutput
@@ -368,6 +371,10 @@ function nnxtest.Recurrent()
local output6 = mlp6:forward(input)
mytester:assertTensorEq(output, output6, 0.000001, "evaluation error "..step)
+
+ local output7 = mlp7:forward(input)
+ mlp7:backward(input, gradOutput)
+ mytester:assertTensorEq(output, output7, 0.000001, "rho = nSteps-1 forward error "..step)
table.insert(gradOutputs, gradOutput)
table.insert(outputs, output:clone())
@@ -380,15 +387,14 @@ function nnxtest.Recurrent()
end
local mlp4 = mlp:clone()
local mlp5 = mlp:clone()
- local mlp7 = mlp:clone() -- rho = nSteps - 1
-- backward propagate through time (BPTT)
- local gradInput = mlp:backwardThroughTime(nSteps) -- rho = nSteps
+ local gradInput = mlp:backwardThroughTime()
mlp4.fastBackward = false
- local gradInput4 = mlp4:backwardThroughTime(nSteps)
+ local gradInput4 = mlp4:backwardThroughTime()
mytester:assertTensorEq(gradInput, gradInput4, 0.000001, 'error slow vs fast backwardThroughTime')
- -- rho = nSteps - 1
- mlp7:backwardThroughTime(nSteps-1) -- rho shouldn't update startModule
+ -- rho = nSteps - 1 : shouldn't update startModule
+ mlp7:backwardThroughTime()
local mlp2 -- this one will simulate rho = nSteps
local outputModules = {}
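Usage note: because rho is stored on the module, a clone can truncate its BPTT horizon by assigning to the rho field, as mlp7 does in the test above; backwardThroughTime() then uses the stored value, and with rho = nSteps - 1 the startModule isn't updated. A hedged sketch continuing the constructor example after the commit header (batchSize and the random data are illustrative):

local batchSize, nSteps = 4, 5
local rnn7 = rnn:clone()
rnn7.rho = nSteps - 1  -- back-propagate one step less than the full history
for step = 1, nSteps do
   local input = torch.randn(batchSize, inputSize)
   rnn7:forward(input)
   rnn7:backward(input, torch.randn(batchSize, outputSize))
end
rnn7:backwardThroughTime()  -- no argument needed; uses rnn7.rho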