diff options
author | nicholas-leonard <nick@nikopia.org> | 2015-03-18 06:20:07 +0300 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2015-03-18 06:20:07 +0300 |
commit | 0cc77fa5e2e878d13f7103631032ca2fabf8929c (patch) | |
tree | cb2f08958fe47da9824629033c394ae6804db8af /test | |
parent | e974dd27fd1af8cbbaf0c5940e6c4b4dc7438f98 (diff) |
Repeater unit tests
Diffstat (limited to 'test')
-rw-r--r-- | test/test-all.lua | 37 |
1 files changed, 35 insertions, 2 deletions
diff --git a/test/test-all.lua b/test/test-all.lua index 265b14d..bf489a3 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -725,8 +725,41 @@ function nnxtest.Sequencer() mytester:assertTensorEq(outputs3[step], output, 0.00001, "Sequencer output "..step) mytester:assertTensorEq(gradInputs3[step], rnn.gradInputs[step], 0.00001, "Sequencer gradInputs "..step) end - print"" - print(rnn3) +end + +function nnxtest.Repeater() + local batchSize = 4 + local inputSize = 10 + local outputSize = 7 + local nSteps = 5 + local inputModule = nn.Linear(inputSize, outputSize) + local transferModule = nn.Sigmoid() + -- test MLP feedback Module (because of Module:representations()) + local feedbackModule = nn.Linear(outputSize, outputSize) + -- rho = nSteps + local rnn = nn.Recurrent(outputSize, inputModule, feedbackModule, transferModule, nSteps) + local rnn2 = rnn:clone() + + local inputs, outputs, gradOutputs = {}, {}, {} + local input = torch.randn(batchSize, inputSize) + for step=1,nSteps do + outputs[step] = rnn:forward(input) + gradOutputs[step] = torch.randn(batchSize, outputSize) + rnn:backward(input, gradOutputs[step]) + end + rnn:backwardThroughTime() + + local rnn3 = nn.Repeater(rnn2, nSteps) + local outputs3 = rnn3:forward(input) + local gradInput3 = rnn3:backward(input, gradOutputs) + mytester:assert(#outputs3 == #outputs, "Repeater output size err") + mytester:assert(#outputs3 == #rnn.gradInputs, "Repeater gradInputs size err") + local gradInput = rnn.gradInputs[1]:clone():zero() + for step,output in ipairs(outputs) do + mytester:assertTensorEq(outputs3[step], output, 0.00001, "Sequencer output "..step) + gradInput:add(rnn.gradInputs[step]) + end + mytester:assertTensorEq(gradInput3, gradInput, 0.00001, "Repeater gradInput err") end function nnxtest.SpatialNormalization_Gaussian2D() |