
github.com/clementfarabet/lua---nnx.git
author     nicholas-leonard <nick@nikopia.org>  2015-03-18 06:20:07 +0300
committer  nicholas-leonard <nick@nikopia.org>  2015-03-18 06:20:07 +0300
commit     0cc77fa5e2e878d13f7103631032ca2fabf8929c (patch)
tree       cb2f08958fe47da9824629033c394ae6804db8af
parent     e974dd27fd1af8cbbaf0c5940e6c4b4dc7438f98 (diff)

Repeater unit tests
-rw-r--r--  AbstractRecurrent.lua |  4
-rw-r--r--  Repeater.lua          | 22
-rw-r--r--  Sequencer.lua         |  4
-rw-r--r--  test/test-all.lua     | 37
4 files changed, 56 insertions(+), 11 deletions(-)
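This commit reorders nn.Repeater's constructor to (rnn, nStep), sums the per-step input gradients into a single gradInput instead of returning the raw self.rnn.gradInputs table, widens the __tostring__ diagrams, and adds a Repeater unit test. A minimal usage sketch of the new interface, mirroring the test added below (module sizes are illustrative, not from the commit):

    require 'nnx'

    -- rho = nSteps; same constructor shape as in the new test
    local rnn = nn.Recurrent(7, nn.Linear(10, 7), nn.Linear(7, 7), nn.Sigmoid(), 5)
    local rep = nn.Repeater(rnn, 5)       -- new argument order: (rnn, nStep)

    local input = torch.randn(4, 10)      -- one input, repeated for 5 steps
    local outputs = rep:forward(input)    -- table of 5 output tensors
    local gradOutputs = {}
    for step = 1, 5 do gradOutputs[step] = torch.randn(4, 7) end
    -- gradInput is now a single tensor: the sum of the 5 per-step gradients
    local gradInput = rep:backward(input, gradOutputs)
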
diff --git a/AbstractRecurrent.lua b/AbstractRecurrent.lua
index af5cba0..cf40626 100644
--- a/AbstractRecurrent.lua
+++ b/AbstractRecurrent.lua
@@ -30,7 +30,7 @@ local function recursiveResizeAs(t1,t2)
t1[key], t2[key] = recursiveResizeAs(t1[key], t2[key])
end
elseif torch.isTensor(t2) then
- t1 = t1 or t2.new()
+ t1 = torch.isTensor(t1) and t1 or t2.new()
t1:resizeAs(t2)
else
error("expecting nested tensors or tables. Got "..
@@ -64,7 +64,7 @@ local function recursiveCopy(t1,t2)
t1[key], t2[key] = recursiveCopy(t1[key], t2[key])
end
elseif torch.isTensor(t2) then
- t1 = t1 or t2.new()
+ t1 = torch.isTensor(t1) and t1 or t2.new()
t1:resizeAs(t2):copy(t2)
else
error("expecting nested tensors or tables. Got "..
diff --git a/Repeater.lua b/Repeater.lua
index 3cf7f86..68ea41b 100644
--- a/Repeater.lua
+++ b/Repeater.lua
@@ -6,11 +6,12 @@
------------------------------------------------------------------------
local Repeater, parent = torch.class("nn.Repeater", "nn.Container")
-function Repeater:__init(nStep, rnn)
+function Repeater:__init(rnn, nStep)
parent.__init(self)
+ assert(torch.type(nStep) == 'number', "expecting number value for arg 2")
self.nStep = nStep
self.rnn = rnn
- assert(rnn.backwardThroughTime, "expecting AbstractRecurrent instance for arg 2")
+ assert(rnn.backwardThroughTime, "expecting AbstractRecurrent instance for arg 1")
self.modules[1] = rnn
self.output = {}
end
@@ -23,6 +24,9 @@ function Repeater:updateOutput(input)
return self.output
end
+local recursiveAdd = nn.AbstractRecurrent.recursiveAdd
+local recursiveCopy = nn.AbstractRecurrent.recursiveCopy
+
function Repeater:updateGradInput(input, gradOutput)
assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
@@ -33,7 +37,15 @@ function Repeater:updateGradInput(input, gradOutput)
end
-- back-propagate through time (BPTT)
self.rnn:updateGradInputThroughTime()
- self.gradInput = self.rnn.gradInputs
+
+ for i,currentGradInput in ipairs(self.rnn.gradInputs) do
+ if i == 1 then
+ self.gradInput = recursiveCopy(self.gradInput, currentGradInput)
+ else
+ recursiveAdd(self.gradInput, currentGradInput)
+ end
+ end
+
return self.gradInput
end
@@ -66,9 +78,9 @@ function Repeater:__tostring__()
local line = '\n'
local str = torch.type(self) .. ' {' .. line
str = str .. tab .. '[ input, input, ..., input ]'.. line
- str = str .. tab .. ' V '.. line
+ str = str .. tab .. ' V V V '.. line
str = str .. tab .. tostring(self.modules[1]):gsub(line, line .. tab) .. line
- str = str .. tab .. ' V '.. line
+ str = str .. tab .. ' V V V '.. line
str = str .. tab .. '[output(1),output(2),...,output('..self.nStep..')]' .. line
str = str .. '}'
return str
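The updateGradInput change follows from the chain rule: the same input x feeds every step, so dL/dx is the sum over steps of dL_t/dx, i.e. the sum of the per-step gradients that BPTT leaves in self.rnn.gradInputs. Step 1 is copied (reusing self.gradInput's storage where possible) and the remaining steps are added in place. When every gradInput is a plain tensor, the loop reduces to this restatement (hypothetical helper, not from the diff):

    -- what the recursiveCopy/recursiveAdd loop does on flat tensors
    local function sumGradInputs(gradInput, gradInputs)
       for i, g in ipairs(gradInputs) do
          if i == 1 then
             gradInput = torch.isTensor(gradInput) and gradInput or g.new()
             gradInput:resizeAs(g):copy(g)  -- recursiveCopy on tensors
          else
             gradInput:add(g)               -- recursiveAdd on tensors
          end
       end
       return gradInput
    end
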
diff --git a/Sequencer.lua b/Sequencer.lua
index f6e8bd0..97f9d6c 100644
--- a/Sequencer.lua
+++ b/Sequencer.lua
@@ -188,9 +188,9 @@ function Sequencer:__tostring__()
local line = '\n'
local str = torch.type(self) .. ' {' .. line
str = str .. tab .. '[input(1), input(2), ..., input(T)]'.. line
- str = str .. tab .. ' V '.. line
+ str = str .. tab .. ' V V V '.. line
str = str .. tab .. tostring(self.modules[1]):gsub(line, line .. tab) .. line
- str = str .. tab .. ' V '.. line
+ str = str .. tab .. ' V V V '.. line
str = str .. tab .. '[output(1),output(2),...,output(T)]' .. line
str = str .. '}'
return str
diff --git a/test/test-all.lua b/test/test-all.lua
index 265b14d..bf489a3 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -725,8 +725,41 @@ function nnxtest.Sequencer()
mytester:assertTensorEq(outputs3[step], output, 0.00001, "Sequencer output "..step)
mytester:assertTensorEq(gradInputs3[step], rnn.gradInputs[step], 0.00001, "Sequencer gradInputs "..step)
end
- print""
- print(rnn3)
+end
+
+function nnxtest.Repeater()
+ local batchSize = 4
+ local inputSize = 10
+ local outputSize = 7
+ local nSteps = 5
+ local inputModule = nn.Linear(inputSize, outputSize)
+ local transferModule = nn.Sigmoid()
+ -- test MLP feedback Module (because of Module:representations())
+ local feedbackModule = nn.Linear(outputSize, outputSize)
+ -- rho = nSteps
+ local rnn = nn.Recurrent(outputSize, inputModule, feedbackModule, transferModule, nSteps)
+ local rnn2 = rnn:clone()
+
+ local inputs, outputs, gradOutputs = {}, {}, {}
+ local input = torch.randn(batchSize, inputSize)
+ for step=1,nSteps do
+ outputs[step] = rnn:forward(input)
+ gradOutputs[step] = torch.randn(batchSize, outputSize)
+ rnn:backward(input, gradOutputs[step])
+ end
+ rnn:backwardThroughTime()
+
+ local rnn3 = nn.Repeater(rnn2, nSteps)
+ local outputs3 = rnn3:forward(input)
+ local gradInput3 = rnn3:backward(input, gradOutputs)
+ mytester:assert(#outputs3 == #outputs, "Repeater output size err")
+ mytester:assert(#outputs3 == #rnn.gradInputs, "Repeater gradInputs size err")
+ local gradInput = rnn.gradInputs[1]:clone():zero()
+ for step,output in ipairs(outputs) do
+ mytester:assertTensorEq(outputs3[step], output, 0.00001, "Repeater output "..step)
+ gradInput:add(rnn.gradInputs[step])
+ end
+ mytester:assertTensorEq(gradInput3, gradInput, 0.00001, "Repeater gradInput err")
end
function nnxtest.SpatialNormalization_Gaussian2D()
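
The new test registers itself alongside the other nnxtest functions in test/test-all.lua. Assuming the suite is exposed through the package's test entry point (nnx.test(), following the usual torch package convention; the entry point itself is not shown in this diff), it can be run with:

    require 'nnx'
    nnx.test()   -- runs every nnxtest function, including the new Repeater test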