diff options
author | nicholas-leonard <nick@nikopia.org> | 2015-03-18 06:20:07 +0300 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2015-03-18 06:20:07 +0300 |
commit | 0cc77fa5e2e878d13f7103631032ca2fabf8929c (patch) | |
tree | cb2f08958fe47da9824629033c394ae6804db8af | |
parent | e974dd27fd1af8cbbaf0c5940e6c4b4dc7438f98 (diff) |
Repeater unit tests
-rw-r--r-- | AbstractRecurrent.lua | 4 | ||||
-rw-r--r-- | Repeater.lua | 22 | ||||
-rw-r--r-- | Sequencer.lua | 4 | ||||
-rw-r--r-- | test/test-all.lua | 37 |
4 files changed, 56 insertions, 11 deletions
diff --git a/AbstractRecurrent.lua b/AbstractRecurrent.lua index af5cba0..cf40626 100644 --- a/AbstractRecurrent.lua +++ b/AbstractRecurrent.lua @@ -30,7 +30,7 @@ local function recursiveResizeAs(t1,t2) t1[key], t2[key] = recursiveResizeAs(t1[key], t2[key]) end elseif torch.isTensor(t2) then - t1 = t1 or t2.new() + t1 = torch.isTensor(t1) and t1 or t2.new() t1:resizeAs(t2) else error("expecting nested tensors or tables. Got ".. @@ -64,7 +64,7 @@ local function recursiveCopy(t1,t2) t1[key], t2[key] = recursiveCopy(t1[key], t2[key]) end elseif torch.isTensor(t2) then - t1 = t1 or t2.new() + t1 = torch.isTensor(t1) and t1 or t2.new() t1:resizeAs(t2):copy(t2) else error("expecting nested tensors or tables. Got ".. diff --git a/Repeater.lua b/Repeater.lua index 3cf7f86..68ea41b 100644 --- a/Repeater.lua +++ b/Repeater.lua @@ -6,11 +6,12 @@ ------------------------------------------------------------------------ local Repeater, parent = torch.class("nn.Repeater", "nn.Container") -function Repeater:__init(nStep, rnn) +function Repeater:__init(rnn, nStep) parent.__init(self) + assert(torch.type(nStep) == 'number', "expecting number value for arg 2") self.nStep = nStep self.rnn = rnn - assert(rnn.backwardThroughTime, "expecting AbstractRecurrent instance for arg 2") + assert(rnn.backwardThroughTime, "expecting AbstractRecurrent instance for arg 1") self.modules[1] = rnn self.output = {} end @@ -23,6 +24,9 @@ function Repeater:updateOutput(input) return self.output end +local recursiveAdd = nn.AbstractRecurrent.recursiveAdd +local recursiveCopy = nn.AbstractRecurrent.recursiveCopy + function Repeater:updateGradInput(input, gradOutput) assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps") assert(torch.type(gradOutput) == 'table', "expecting gradOutput table") @@ -33,7 +37,15 @@ function Repeater:updateGradInput(input, gradOutput) end -- back-propagate through time (BPTT) self.rnn:updateGradInputThroughTime() - self.gradInput = self.rnn.gradInputs + + for i,currentGradInput in ipairs(self.rnn.gradInputs) do + if i == 1 then + self.gradInput = recursiveCopy(self.gradInput, currentGradInput) + else + recursiveAdd(self.gradInput, currentGradInput) + end + end + return self.gradInput end @@ -66,9 +78,9 @@ function Repeater:__tostring__() local line = '\n' local str = torch.type(self) .. ' {' .. line str = str .. tab .. '[ input, input, ..., input ]'.. line - str = str .. tab .. ' V '.. line + str = str .. tab .. ' V V V '.. line str = str .. tab .. tostring(self.modules[1]):gsub(line, line .. tab) .. line - str = str .. tab .. ' V '.. line + str = str .. tab .. ' V V V '.. line str = str .. tab .. '[output(1),output(2),...,output('..self.nStep..')]' .. line str = str .. '}' return str diff --git a/Sequencer.lua b/Sequencer.lua index f6e8bd0..97f9d6c 100644 --- a/Sequencer.lua +++ b/Sequencer.lua @@ -188,9 +188,9 @@ function Sequencer:__tostring__() local line = '\n' local str = torch.type(self) .. ' {' .. line str = str .. tab .. '[input(1), input(2), ..., input(T)]'.. line - str = str .. tab .. ' V '.. line + str = str .. tab .. ' V V V '.. line str = str .. tab .. tostring(self.modules[1]):gsub(line, line .. tab) .. line - str = str .. tab .. ' V '.. line + str = str .. tab .. ' V V V '.. line str = str .. tab .. '[output(1),output(2),...,output(T)]' .. line str = str .. '}' return str diff --git a/test/test-all.lua b/test/test-all.lua index 265b14d..bf489a3 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -725,8 +725,41 @@ function nnxtest.Sequencer() mytester:assertTensorEq(outputs3[step], output, 0.00001, "Sequencer output "..step) mytester:assertTensorEq(gradInputs3[step], rnn.gradInputs[step], 0.00001, "Sequencer gradInputs "..step) end - print"" - print(rnn3) +end + +function nnxtest.Repeater() + local batchSize = 4 + local inputSize = 10 + local outputSize = 7 + local nSteps = 5 + local inputModule = nn.Linear(inputSize, outputSize) + local transferModule = nn.Sigmoid() + -- test MLP feedback Module (because of Module:representations()) + local feedbackModule = nn.Linear(outputSize, outputSize) + -- rho = nSteps + local rnn = nn.Recurrent(outputSize, inputModule, feedbackModule, transferModule, nSteps) + local rnn2 = rnn:clone() + + local inputs, outputs, gradOutputs = {}, {}, {} + local input = torch.randn(batchSize, inputSize) + for step=1,nSteps do + outputs[step] = rnn:forward(input) + gradOutputs[step] = torch.randn(batchSize, outputSize) + rnn:backward(input, gradOutputs[step]) + end + rnn:backwardThroughTime() + + local rnn3 = nn.Repeater(rnn2, nSteps) + local outputs3 = rnn3:forward(input) + local gradInput3 = rnn3:backward(input, gradOutputs) + mytester:assert(#outputs3 == #outputs, "Repeater output size err") + mytester:assert(#outputs3 == #rnn.gradInputs, "Repeater gradInputs size err") + local gradInput = rnn.gradInputs[1]:clone():zero() + for step,output in ipairs(outputs) do + mytester:assertTensorEq(outputs3[step], output, 0.00001, "Sequencer output "..step) + gradInput:add(rnn.gradInputs[step]) + end + mytester:assertTensorEq(gradInput3, gradInput, 0.00001, "Repeater gradInput err") end function nnxtest.SpatialNormalization_Gaussian2D() |