Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2015-03-18 06:20:07 +0300
committernicholas-leonard <nick@nikopia.org>2015-03-18 06:20:07 +0300
commit0cc77fa5e2e878d13f7103631032ca2fabf8929c (patch)
treecb2f08958fe47da9824629033c394ae6804db8af /test
parente974dd27fd1af8cbbaf0c5940e6c4b4dc7438f98 (diff)
Repeater unit tests
Diffstat (limited to 'test')
-rw-r--r--test/test-all.lua37
1 files changed, 35 insertions, 2 deletions
diff --git a/test/test-all.lua b/test/test-all.lua
index 265b14d..bf489a3 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -725,8 +725,41 @@ function nnxtest.Sequencer()
mytester:assertTensorEq(outputs3[step], output, 0.00001, "Sequencer output "..step)
mytester:assertTensorEq(gradInputs3[step], rnn.gradInputs[step], 0.00001, "Sequencer gradInputs "..step)
end
- print""
- print(rnn3)
+end
+
+function nnxtest.Repeater()
+ local batchSize = 4
+ local inputSize = 10
+ local outputSize = 7
+ local nSteps = 5
+ local inputModule = nn.Linear(inputSize, outputSize)
+ local transferModule = nn.Sigmoid()
+ -- test MLP feedback Module (because of Module:representations())
+ local feedbackModule = nn.Linear(outputSize, outputSize)
+ -- rho = nSteps
+ local rnn = nn.Recurrent(outputSize, inputModule, feedbackModule, transferModule, nSteps)
+ local rnn2 = rnn:clone()
+
+ local inputs, outputs, gradOutputs = {}, {}, {}
+ local input = torch.randn(batchSize, inputSize)
+ for step=1,nSteps do
+ outputs[step] = rnn:forward(input)
+ gradOutputs[step] = torch.randn(batchSize, outputSize)
+ rnn:backward(input, gradOutputs[step])
+ end
+ rnn:backwardThroughTime()
+
+ local rnn3 = nn.Repeater(rnn2, nSteps)
+ local outputs3 = rnn3:forward(input)
+ local gradInput3 = rnn3:backward(input, gradOutputs)
+ mytester:assert(#outputs3 == #outputs, "Repeater output size err")
+ mytester:assert(#outputs3 == #rnn.gradInputs, "Repeater gradInputs size err")
+ local gradInput = rnn.gradInputs[1]:clone():zero()
+ for step,output in ipairs(outputs) do
+ mytester:assertTensorEq(outputs3[step], output, 0.00001, "Sequencer output "..step)
+ gradInput:add(rnn.gradInputs[step])
+ end
+ mytester:assertTensorEq(gradInput3, gradInput, 0.00001, "Repeater gradInput err")
end
function nnxtest.SpatialNormalization_Gaussian2D()