github.com/clementfarabet/lua---nnx.git

author     Nicholas Léonard <nick@nikopia.org>   2015-03-18 04:49:54 +0300
committer  Nicholas Léonard <nick@nikopia.org>   2015-03-18 04:49:54 +0300
commit     c6de4cf8d372a5e507321464a9a98ed2d829798f (patch)
tree       715b1aa80c1b2769ac543a4ca5f93527ae207e82
parent     eaee722aab186e13bcf33479f3724a53fb049ec5 (diff)
parent     edf33956ca750ed790f5275734fa80a2c07b8c7d (diff)
Merge pull request #27 from nicholas-leonard/lstm
LSTM (work in progress)
-rw-r--r--  AbstractRecurrent.lua    281
-rw-r--r--  LSTM.lua                 353
-rw-r--r--  Recurrent.lua            318
-rw-r--r--  RepeaterCriterion.lua      1
-rw-r--r--  ZeroGrad.lua              28
-rw-r--r--  init.lua                  11
-rw-r--r--  test/test-all.lua        132
7 files changed, 849 insertions(+), 275 deletions(-)
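
For reference, a minimal usage sketch of the nn.LSTM module introduced by this merge, modelled on the new nnxtest.LSTM test further down (the sizes, sequence length and learning rate below are arbitrary):

require 'nnx'

local lstm = nn.LSTM(10, 20)           -- inputSize = 10, outputSize = 20
lstm:zeroGradParameters()
for step = 1, 5 do
   local input = torch.randn(10)
   local output = lstm:forward(input)  -- output(t) for this step
   -- backward() only stores the gradOutput and scale; BPTT is deferred
   lstm:backward(input, torch.randn(20))
end
local gradInput = lstm:backwardThroughTime()  -- the actual BPTT over the stored steps
lstm:updateParameters(0.1)
lstm:forget()                          -- reset state before the next sequence
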
diff --git a/AbstractRecurrent.lua b/AbstractRecurrent.lua
new file mode 100644
index 0000000..af5cba0
--- /dev/null
+++ b/AbstractRecurrent.lua
@@ -0,0 +1,281 @@
+local AbstractRecurrent, parent = torch.class('nn.AbstractRecurrent', 'nn.Container')
+
+function AbstractRecurrent:__init(rho)
+ parent.__init(self)
+
+   self.rho = rho -- the maximum number of time-steps to back-propagate through time (BPTT)
+
+ self.fastBackward = true
+ self.copyInputs = true
+
+ self.inputs = {}
+ self.outputs = {}
+ self.gradOutputs = {}
+ self.scales = {}
+
+ self.gradParametersAccumulated = false
+ self.step = 1
+
+ -- stores internal states of Modules at different time-steps
+ self.recurrentOutputs = {}
+ self.recurrentGradInputs = {}
+
+ self:reset()
+end
+
+local function recursiveResizeAs(t1,t2)
+ if torch.type(t2) == 'table' then
+ t1 = (torch.type(t1) == 'table') and t1 or {t1}
+ for key,_ in pairs(t2) do
+ t1[key], t2[key] = recursiveResizeAs(t1[key], t2[key])
+ end
+ elseif torch.isTensor(t2) then
+ t1 = t1 or t2.new()
+ t1:resizeAs(t2)
+ else
+ error("expecting nested tensors or tables. Got "..
+ torch.type(t1).." and "..torch.type(t2).." instead")
+ end
+ return t1, t2
+end
+AbstractRecurrent.recursiveResizeAs = recursiveResizeAs
+
+local function recursiveSet(t1,t2)
+ if torch.type(t2) == 'table' then
+ t1 = (torch.type(t1) == 'table') and t1 or {t1}
+ for key,_ in pairs(t2) do
+ t1[key], t2[key] = recursiveSet(t1[key], t2[key])
+ end
+ elseif torch.isTensor(t2) then
+ t1 = t1 or t2.new()
+ t1:set(t2)
+ else
+ error("expecting nested tensors or tables. Got "..
+ torch.type(t1).." and "..torch.type(t2).." instead")
+ end
+ return t1, t2
+end
+AbstractRecurrent.recursiveSet = recursiveSet
+
+local function recursiveCopy(t1,t2)
+ if torch.type(t2) == 'table' then
+ t1 = (torch.type(t1) == 'table') and t1 or {t1}
+ for key,_ in pairs(t2) do
+ t1[key], t2[key] = recursiveCopy(t1[key], t2[key])
+ end
+ elseif torch.isTensor(t2) then
+ t1 = t1 or t2.new()
+ t1:resizeAs(t2):copy(t2)
+ else
+ error("expecting nested tensors or tables. Got "..
+ torch.type(t1).." and "..torch.type(t2).." instead")
+ end
+ return t1, t2
+end
+AbstractRecurrent.recursiveCopy = recursiveCopy
+
+local function recursiveAdd(t1, t2)
+ if torch.type(t2) == 'table' then
+ t1 = (torch.type(t1) == 'table') and t1 or {t1}
+ for key,_ in pairs(t2) do
+ t1[key], t2[key] = recursiveAdd(t1[key], t2[key])
+ end
+   elseif torch.isTensor(t1) and torch.isTensor(t2) then
+ t1:add(t2)
+ else
+ error("expecting nested tensors or tables. Got "..
+ torch.type(t1).." and "..torch.type(t2).." instead")
+ end
+ return t1, t2
+end
+AbstractRecurrent.recursiveAdd = recursiveAdd
+
+local function recursiveTensorEq(t1, t2)
+ if torch.type(t2) == 'table' then
+ local isEqual = true
+ if torch.type(t1) ~= 'table' then
+ return false
+ end
+ for key,_ in pairs(t2) do
+ isEqual = isEqual and recursiveTensorEq(t1[key], t2[key])
+ end
+ return isEqual
+   elseif torch.isTensor(t1) and torch.isTensor(t2) then
+ local diff = t1-t2
+ local err = diff:abs():max()
+ return err < 0.00001
+ else
+ error("expecting nested tensors or tables. Got "..
+ torch.type(t1).." and "..torch.type(t2).." instead")
+ end
+end
+AbstractRecurrent.recursiveTensorEq = recursiveTensorEq
+
+local function recursiveNormal(t2)
+ if torch.type(t2) == 'table' then
+ for key,_ in pairs(t2) do
+ t2[key] = recursiveNormal(t2[key])
+ end
+ elseif torch.isTensor(t2) then
+ t2:normal()
+ else
+ error("expecting tensor or table thereof. Got "
+ ..torch.type(t2).." instead")
+ end
+ return t2
+end
+AbstractRecurrent.recursiveNormal = recursiveNormal
+
+function AbstractRecurrent:updateGradInput(input, gradOutput)
+ -- Back-Propagate Through Time (BPTT) happens in updateParameters()
+ -- for now we just keep a list of the gradOutputs
+ self.gradOutputs[self.step-1] = self.recursiveCopy(self.gradOutputs[self.step-1] , gradOutput)
+end
+
+function AbstractRecurrent:accGradParameters(input, gradOutput, scale)
+ -- Back-Propagate Through Time (BPTT) happens in updateParameters()
+ -- for now we just keep a list of the scales
+ self.scales[self.step-1] = scale
+end
+
+function AbstractRecurrent:backwardUpdateThroughTime(learningRate)
+ local gradInput = self:updateGradInputThroughTime()
+ self:accUpdateGradParametersThroughTime(learningRate)
+ return gradInput
+end
+
+function AbstractRecurrent:updateParameters(learningRate)
+ if self.gradParametersAccumulated then
+ for i=1,#self.modules do
+ self.modules[i]:updateParameters(learningRate)
+ end
+ else
+ self:backwardUpdateThroughTime(learningRate)
+ end
+end
+
+-- goes hand in hand with the next method : forget()
+function AbstractRecurrent:recycle()
+ -- +1 is to skip initialModule
+ if self.step > self.rho + 1 then
+ assert(self.recurrentOutputs[self.step] == nil)
+ assert(self.recurrentOutputs[self.step-self.rho] ~= nil)
+ self.recurrentOutputs[self.step] = self.recurrentOutputs[self.step-self.rho]
+ self.recurrentGradInputs[self.step] = self.recurrentGradInputs[self.step-self.rho]
+ self.recurrentOutputs[self.step-self.rho] = nil
+ self.recurrentGradInputs[self.step-self.rho] = nil
+ -- need to keep rho+1 of these
+ self.outputs[self.step] = self.outputs[self.step-self.rho-1]
+ self.outputs[self.step-self.rho-1] = nil
+ end
+ if self.step > self.rho then
+ assert(self.inputs[self.step] == nil)
+ assert(self.inputs[self.step-self.rho] ~= nil)
+ self.inputs[self.step] = self.inputs[self.step-self.rho]
+ self.gradOutputs[self.step] = self.gradOutputs[self.step-self.rho]
+ self.inputs[self.step-self.rho] = nil
+ self.gradOutputs[self.step-self.rho] = nil
+ self.scales[self.step-self.rho] = nil
+ end
+end
+
+function AbstractRecurrent:forget(offset)
+ offset = offset or 1
+ if self.train ~= false then
+ -- bring all states back to the start of the sequence buffers
+ local lastStep = self.step - 1
+
+ if lastStep > self.rho + offset then
+ local i = 1 + offset
+ for step = lastStep-self.rho+offset,lastStep do
+ self.recurrentOutputs[i] = self.recurrentOutputs[step]
+ self.recurrentGradInputs[i] = self.recurrentGradInputs[step]
+ self.recurrentOutputs[step] = nil
+ self.recurrentGradInputs[step] = nil
+ -- we keep rho+1 of these : outputs[k]=outputs[k+rho+1]
+ self.outputs[i-1] = self.outputs[step]
+ self.outputs[step] = nil
+ i = i + 1
+ end
+
+ end
+
+ if lastStep > self.rho then
+ local i = 1
+ for step = lastStep-self.rho+1,lastStep do
+ self.inputs[i] = self.inputs[step]
+ self.gradOutputs[i] = self.gradOutputs[step]
+ self.inputs[step] = nil
+ self.gradOutputs[step] = nil
+ self.scales[step] = nil
+ i = i + 1
+ end
+
+ end
+ end
+
+ -- forget the past inputs; restart from first step
+ self.step = 1
+end
+
+-- Tests whether or not the mlp can be used internally for recursion, i.e.
+-- the sequence forward A, backward A, forward B, backward B, backward A
+-- should reproduce the gradients of the first backward A, provided A and B
+-- each have their own output/gradInput buffers.
+function AbstractRecurrent.isRecursable(mlp, input)
+ local output = recursiveCopy(nil, mlp:forward(input)) --forward A
+ local gradOutput = recursiveNormal(recursiveCopy(nil, output))
+ mlp:zeroGradParameters()
+ local gradInput = recursiveCopy(nil, mlp:backward(input, gradOutput)) --backward A
+ local params, gradParams = mlp:parameters()
+ gradParams = recursiveCopy(nil, gradParams)
+
+ -- output/gradInput are the only internal module states that we track
+ local recurrentOutputs = {}
+ local recurrentGradInputs = {}
+
+ local modules = mlp:listModules()
+
+ -- save the output/gradInput states of A
+ for i,modula in ipairs(modules) do
+ recurrentOutputs[i] = modula.output
+ recurrentGradInputs[i] = modula.gradInput
+ end
+ -- set the output/gradInput states for B
+ local recurrentOutputs2 = {}
+ local recurrentGradInputs2 = {}
+ for i,modula in ipairs(modules) do
+ modula.output = recursiveResizeAs(recurrentOutputs2[i], modula.output)
+ modula.gradInput = recursiveResizeAs(recurrentGradInputs2[i], modula.gradInput)
+ end
+
+ local input2 = recursiveNormal(recursiveCopy(nil, input))
+ local gradOutput2 = recursiveNormal(recursiveCopy(nil, gradOutput))
+ local output2 = mlp:forward(input2) --forward B
+ mlp:zeroGradParameters()
+ local gradInput2 = mlp:backward(input2, gradOutput2) --backward B
+
+ -- save the output/gradInput state of B
+ for i,modula in ipairs(modules) do
+ recurrentOutputs2[i] = modula.output
+ recurrentGradInputs2[i] = modula.gradInput
+ end
+
+ -- set the output/gradInput states for A
+ for i,modula in ipairs(modules) do
+ modula.output = recursiveResizeAs(recurrentOutputs[i], modula.output)
+ modula.gradInput = recursiveResizeAs(recurrentGradInputs[i], modula.gradInput)
+ end
+
+ mlp:zeroGradParameters()
+   local gradInput3 = mlp:backward(input, gradOutput) --backward A (again)
+ local gradInputTest = recursiveTensorEq(gradInput, gradInput3)
+ local params3, gradParams3 = mlp:parameters()
+ local nEq = 0
+ for i,gradParam in ipairs(gradParams) do
+ nEq = nEq + (recursiveTensorEq(gradParam, gradParams3[i]) and 1 or 0)
+ end
+ local gradParamsTest = (nEq == #gradParams3)
+ mlp:zeroGradParameters()
+ return gradParamsTest and gradInputTest, gradParamsTest, gradInputTest
+end
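
A quick sketch of how the new isRecursable helper is called (the updated tests below exercise it the same way); the wrapped module here is purely illustrative:

require 'nnx'

local isRecursable = nn.AbstractRecurrent.isRecursable
local mlp = nn.Sequential():add(nn.Linear(4, 3)):add(nn.Tanh())
-- true when a second backward through A, after a full forward/backward through B
-- using separate output/gradInput buffers, reproduces A's gradients
print(isRecursable(mlp, torch.randn(4)))
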
diff --git a/LSTM.lua b/LSTM.lua
new file mode 100644
index 0000000..a3541b8
--- /dev/null
+++ b/LSTM.lua
@@ -0,0 +1,353 @@
+------------------------------------------------------------------------
+--[[ LSTM ]]--
+-- Long Short Term Memory architecture.
+-- Ref. A.: http://arxiv.org/pdf/1303.5778v1 (blueprint for this module)
+-- B. http://web.eecs.utk.edu/~itamar/courses/ECE-692/Bobby_paper1.pdf
+-- C. https://github.com/wojzaremba/lstm
+-- Expects 1D or 2D input.
+-- The first input in the sequence uses zero values for the previous cell and hidden state.
+------------------------------------------------------------------------
+local LSTM, parent = torch.class('nn.LSTM', 'nn.AbstractRecurrent')
+
+function LSTM:__init(inputSize, outputSize, rho)
+ parent.__init(self, rho or 999999999999)
+ self.inputSize = inputSize
+ self.outputSize = outputSize
+ -- build the model
+ self.recurrentModule = self:buildModel()
+ -- make it work with nn.Container
+ self.modules[1] = self.recurrentModule
+
+ -- for output(0), cell(0) and gradCell(T)
+ self.zeroTensor = torch.Tensor()
+
+ self.cells = {}
+ self.gradCells = {}
+end
+
+-------------------------- factory methods -----------------------------
+function LSTM:buildGate()
+ -- Note : gate expects an input table : {input, output(t-1), cell(t-1)}
+ local gate = nn.Sequential()
+ local input2gate = nn.Linear(self.inputSize, self.outputSize)
+ local output2gate = nn.Linear(self.outputSize, self.outputSize)
+ local cell2gate = nn.CMul(self.outputSize) -- diagonal cell to gate weight matrix
+ --output2gate:noBias() --TODO
+ local para = nn.ParallelTable()
+ para:add(input2gate):add(output2gate):add(cell2gate)
+ gate:add(para)
+ gate:add(nn.CAddTable())
+ gate:add(nn.Sigmoid())
+ return gate
+end
+
+function LSTM:buildInputGate()
+ self.inputGate = self:buildGate()
+ return self.inputGate
+end
+
+function LSTM:buildForgetGate()
+ self.forgetGate = self:buildGate()
+ return self.forgetGate
+end
+
+function LSTM:buildHidden()
+ local hidden = nn.Sequential()
+ local input2hidden = nn.Linear(self.inputSize, self.outputSize)
+ local output2hidden = nn.Linear(self.outputSize, self.outputSize)
+ local para = nn.ParallelTable()
+ --output2hidden:noBias()
+ para:add(input2hidden):add(output2hidden)
+ -- input is {input, output(t-1), cell(t-1)}, but we only need {input, output(t-1)}
+ local concat = nn.ConcatTable()
+ concat:add(nn.SelectTable(1)):add(nn.SelectTable(2))
+ hidden:add(concat)
+ hidden:add(para)
+ hidden:add(nn.CAddTable())
+ self.hiddenLayer = hidden
+ return hidden
+end
+
+function LSTM:buildCell()
+ -- build
+ self.inputGate = self:buildInputGate()
+ self.forgetGate = self:buildForgetGate()
+ self.hiddenLayer = self:buildHidden()
+ -- forget = forgetGate{input, output(t-1), cell(t-1)} * cell(t-1)
+ local forget = nn.Sequential()
+ local concat = nn.ConcatTable()
+ concat:add(self.forgetGate):add(nn.SelectTable(3))
+ forget:add(concat)
+ forget:add(nn.CMulTable())
+ -- input = inputGate{input, output(t-1), cell(t-1)} * hiddenLayer{input, output(t-1), cell(t-1)}
+ local input = nn.Sequential()
+ local concat2 = nn.ConcatTable()
+ concat2:add(self.inputGate):add(self.hiddenLayer)
+ input:add(concat2)
+ input:add(nn.CMulTable())
+ -- cell(t) = forget + input
+ local cell = nn.Sequential()
+ local concat3 = nn.ConcatTable()
+ concat3:add(forget):add(input)
+ cell:add(concat3)
+ cell:add(nn.CAddTable())
+ self.cellLayer = cell
+ return cell
+end
+
+function LSTM:buildOutputGate()
+ self.outputGate = self:buildGate()
+ return self.outputGate
+end
+
+-- cell(t) = cellLayer{input, output(t-1), cell(t-1)}
+-- output(t) = outputGate{input, output(t-1), cell(t)}*tanh(cell(t))
+-- output of Model is table : {output(t), cell(t)}
+function LSTM:buildModel()
+ -- build components
+ self.cellLayer = self:buildCell()
+ self.outputGate = self:buildOutputGate()
+ -- assemble
+ local concat = nn.ConcatTable()
+ local concat2 = nn.ConcatTable()
+ concat2:add(nn.SelectTable(1)):add(nn.SelectTable(2))
+ concat:add(concat2):add(self.cellLayer)
+ local model = nn.Sequential()
+ model:add(concat)
+ -- output of concat is {{input, output}, cell(t)},
+ -- so flatten to {input, output, cell(t)}
+ model:add(nn.FlattenTable())
+ local cellAct = nn.Sequential()
+ cellAct:add(nn.SelectTable(3))
+ cellAct:add(nn.Tanh())
+ local concat3 = nn.ConcatTable()
+ concat3:add(self.outputGate):add(cellAct)
+ local output = nn.Sequential()
+ output:add(concat3)
+ output:add(nn.CMulTable())
+ -- we want the model to output : {output(t), cell(t)}
+ local concat4 = nn.ConcatTable()
+ concat4:add(output):add(nn.SelectTable(3))
+ model:add(concat4)
+ return model
+end
+
+------------------------- forward backward -----------------------------
+function LSTM:updateOutput(input)
+ local prevOutput, prevCell
+ if self.step == 1 then
+ prevOutput = self.zeroTensor
+ prevCell = self.zeroTensor
+ if input:dim() == 2 then
+ self.zeroTensor:resize(input:size(1), self.outputSize):zero()
+ else
+ self.zeroTensor:resize(self.outputSize):zero()
+ end
+ self.outputs[0] = self.zeroTensor
+ self.cells[0] = self.zeroTensor
+ else
+ -- previous output and cell of this module
+ prevOutput = self.output
+ prevCell = self.cell
+ end
+
+ -- output(t), cell(t) = lstm{input(t), output(t-1), cell(t-1)}
+ local output, cell
+ if self.train ~= false then
+ -- set/save the output states
+ local modules = self.recurrentModule:listModules()
+ self:recycle()
+ local recurrentOutputs = self.recurrentOutputs[self.step]
+ if not recurrentOutputs then
+ recurrentOutputs = {}
+ self.recurrentOutputs[self.step] = recurrentOutputs
+ end
+ for i,modula in ipairs(modules) do
+ local output_ = self.recursiveResizeAs(recurrentOutputs[i], modula.output)
+ modula.output = output_
+ end
+ -- the actual forward propagation
+ output, cell = unpack(self.recurrentModule:updateOutput{input, prevOutput, prevCell})
+
+ for i,modula in ipairs(modules) do
+ recurrentOutputs[i] = modula.output
+ end
+ else
+ output, cell = unpack(self.recurrentModule:updateOutput{input, prevOutput, prevCell})
+ end
+
+ if self.train ~= false then
+ local input_ = self.inputs[self.step]
+ self.inputs[self.step] = self.copyInputs
+ and self.recursiveCopy(input_, input)
+ or self.recursiveSet(input_, input)
+ end
+
+ self.outputs[self.step] = output
+ self.cells[self.step] = cell
+
+ self.output = output
+ self.cell = cell
+
+ self.step = self.step + 1
+ self.gradParametersAccumulated = false
+ -- note that we don't return the cell, just the output
+ return self.output
+end
+
+function LSTM:backwardThroughTime()
+ assert(self.step > 1, "expecting at least one updateOutput")
+ self.gradInputs = {}
+ local rho = math.min(self.rho, self.step-1)
+ local stop = self.step - rho
+ if self.fastBackward then
+      local gradInput, gradPrevOutput
+      local gradCell = self.zeroTensor -- gradient w.r.t. cell(T) starts at zero (as in updateGradInputThroughTime)
+ for step=self.step-1,math.max(stop,1),-1 do
+ -- set the output/gradOutput states of current Module
+ local modules = self.recurrentModule:listModules()
+ local recurrentOutputs = self.recurrentOutputs[step]
+ local recurrentGradInputs = self.recurrentGradInputs[step]
+ if not recurrentGradInputs then
+ recurrentGradInputs = {}
+ self.recurrentGradInputs[step] = recurrentGradInputs
+ end
+
+ for i,modula in ipairs(modules) do
+ local output, gradInput = modula.output, modula.gradInput
+ assert(gradInput, "missing gradInput")
+ local output_ = recurrentOutputs[i]
+ assert(output_, "backwardThroughTime should be preceded by updateOutput")
+ modula.output = output_
+ modula.gradInput = self.recursiveResizeAs(recurrentGradInputs[i], gradInput) --resize, NOT copy
+ end
+
+ -- backward propagate through this step
+ local gradOutput = self.gradOutputs[step]
+ if gradPrevOutput then
+ self.recursiveAdd(gradOutput, gradPrevOutput)
+ end
+
+ self.gradCells[step] = gradCell
+ local scale = self.scales[step]/rho
+
+ local inputTable = {self.inputs[step], self.outputs[step-1], self.cells[step-1]}
+ local gradInputTable = self.recurrentModule:backward(inputTable, {gradOutput, gradCell}, scale)
+ gradInput, gradPrevOutput, gradCell = unpack(gradInputTable)
+ table.insert(self.gradInputs, 1, gradInput)
+
+ for i,modula in ipairs(modules) do
+ recurrentGradInputs[i] = modula.gradInput
+ end
+ end
+ return gradInput
+ else
+ local gradInput = self:updateGradInputThroughTime()
+ self:accGradParametersThroughTime()
+ return gradInput
+ end
+end
+
+function LSTM:updateGradInputThroughTime()
+ assert(self.step > 1, "expecting at least one updateOutput")
+ self.gradInputs = {}
+ local gradInput, gradPrevOutput
+ local gradCell = self.zeroTensor
+ local rho = math.min(self.rho, self.step-1)
+ local stop = self.step - rho
+ for step=self.step-1,math.max(stop,1),-1 do
+ -- set the output/gradOutput states of current Module
+ local modules = self.recurrentModule:listModules()
+ local recurrentOutputs = self.recurrentOutputs[step]
+ local recurrentGradInputs = self.recurrentGradInputs[step]
+ if not recurrentGradInputs then
+ recurrentGradInputs = {}
+ self.recurrentGradInputs[step] = recurrentGradInputs
+ end
+ for i,modula in ipairs(modules) do
+ local output, gradInput = modula.output, modula.gradInput
+ local output_ = recurrentOutputs[i]
+ assert(output_, "updateGradInputThroughTime should be preceded by updateOutput")
+ modula.output = output_
+ modula.gradInput = self.recursiveResizeAs(recurrentGradInputs[i], gradInput)
+ end
+
+ -- backward propagate through this step
+ local gradOutput = self.gradOutputs[step]
+ if gradPrevOutput then
+ self.recursiveAdd(gradOutput, gradPrevOutput)
+ end
+
+ self.gradCells[step] = gradCell
+ local scale = self.scales[step]/rho
+ local inputTable = {self.inputs[step], self.outputs[step-1], self.cells[step-1]}
+ local gradInputTable = self.recurrentModule:updateGradInput(inputTable, {gradOutput, gradCell}, scale)
+ gradInput, gradPrevOutput, gradCell = unpack(gradInputTable)
+ table.insert(self.gradInputs, 1, gradInput)
+
+ for i,modula in ipairs(modules) do
+ recurrentGradInputs[i] = modula.gradInput
+ end
+ end
+
+ return gradInput
+end
+
+function LSTM:accGradParametersThroughTime()
+ local rho = math.min(self.rho, self.step-1)
+ local stop = self.step - rho
+ for step=self.step-1,math.max(stop,1),-1 do
+ -- set the output/gradOutput states of current Module
+ local modules = self.recurrentModule:listModules()
+ local recurrentOutputs = self.recurrentOutputs[step]
+ local recurrentGradInputs = self.recurrentGradInputs[step]
+
+ for i,modula in ipairs(modules) do
+ local output, gradInput = modula.output, modula.gradInput
+ local output_ = recurrentOutputs[i]
+ local gradInput_ = recurrentGradInputs[i]
+ assert(output_, "accGradParametersThroughTime should be preceded by updateOutput")
+ assert(gradInput_, "accGradParametersThroughTime should be preceded by updateGradInputThroughTime")
+ modula.output = output_
+ modula.gradInput = gradInput_
+ end
+
+ -- backward propagate through this step
+ local scale = self.scales[step]/rho
+ local inputTable = {self.inputs[step], self.outputs[step-1], self.cells[step-1]}
+ local gradOutputTable = {self.gradOutputs[step], self.gradCells[step]}
+ self.recurrentModule:accGradParameters(inputTable, gradOutputTable, scale)
+ end
+
+ self.gradParametersAccumulated = true
+ return gradInput
+end
+
+function LSTM:accUpdateGradParametersThroughTime(lr)
+ local rho = math.min(self.rho, self.step-1)
+ local stop = self.step - rho
+ for step=self.step-1,math.max(stop,1),-1 do
+ -- set the output/gradOutput states of current Module
+ local modules = self.recurrentModule:listModules()
+ local recurrentOutputs = self.recurrentOutputs[step]
+ local recurrentGradInputs = self.recurrentGradInputs[step]
+
+ for i,modula in ipairs(modules) do
+ local output, gradInput = modula.output, modula.gradInput
+ local output_ = recurrentOutputs[i]
+ local gradInput_ = recurrentGradInputs[i]
+ assert(output_, "accGradParametersThroughTime should be preceded by updateOutput")
+ assert(gradInput_, "accGradParametersThroughTime should be preceded by updateGradInputThroughTime")
+ modula.output = output_
+ modula.gradInput = gradInput_
+ end
+
+ -- backward propagate through this step
+ local scale = self.scales[step]/rho
+      local inputTable = {self.inputs[step], self.outputs[step-1], self.cells[step-1]}
+ local gradOutputTable = {self.gradOutputs[step], self.gradCells[step]}
+ self.recurrentModule:accUpdateGradParameters(inputTable, gradOutputTable, lr*scale)
+ end
+
+ return gradInput
+end
+
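To summarize what buildModel() assembles, here is a sketch (not part of the patch) of the single time-step computed by recurrentModule, written with plain tensor ops. The weight names (Wxi, Whi, wci, ...) are hypothetical, nn.Linear biases are omitted for brevity, and 1D (non-batch) tensors are assumed:

local function lstmStep(x, hPrev, cPrev, p)
   -- gates see {input, output(t-1), cell(t-1)}; cell-to-gate connections are
   -- diagonal (nn.CMul), hence the element-wise products with the cell
   local i = torch.sigmoid(p.Wxi*x + p.Whi*hPrev + torch.cmul(p.wci, cPrev))  -- inputGate
   local f = torch.sigmoid(p.Wxf*x + p.Whf*hPrev + torch.cmul(p.wcf, cPrev))  -- forgetGate
   local z = p.Wxz*x + p.Whz*hPrev              -- hiddenLayer (no extra nonlinearity)
   local c = torch.cmul(f, cPrev) + torch.cmul(i, z)                          -- cellLayer
   local o = torch.sigmoid(p.Wxo*x + p.Who*hPrev + torch.cmul(p.wco, c))      -- outputGate peeks at cell(t)
   local h = torch.cmul(o, torch.tanh(c))
   return h, c  -- matches the module's output table {output(t), cell(t)}
end
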
diff --git a/Recurrent.lua b/Recurrent.lua
index 859ce19..495d3a2 100644
--- a/Recurrent.lua
+++ b/Recurrent.lua
@@ -14,10 +14,10 @@
-- output attribute to keep track of their internal state between
-- forward and backward.
------------------------------------------------------------------------
-local Recurrent, parent = torch.class('nn.Recurrent', 'nn.Module')
+local Recurrent, parent = torch.class('nn.Recurrent', 'nn.AbstractRecurrent')
function Recurrent:__init(start, input, feedback, transfer, rho, merge)
- parent.__init(self)
+ parent.__init(self, rho or 5)
local ts = torch.type(start)
if ts == 'torch.LongTensor' or ts == 'number' then
@@ -29,7 +29,6 @@ function Recurrent:__init(start, input, feedback, transfer, rho, merge)
self.feedbackModule = feedback
self.transferModule = transfer or nn.Sigmoid()
self.mergeModule = merge or nn.CAddTable()
- self.rho = rho or 5
self.modules = {self.startModule, self.inputModule, self.feedbackModule, self.transferModule, self.mergeModule}
@@ -38,22 +37,6 @@ function Recurrent:__init(start, input, feedback, transfer, rho, merge)
self.initialOutputs = {}
self.initialGradInputs = {}
- self.recurrentOutputs = {}
- self.recurrentGradInputs = {}
-
- self.fastBackward = true
- self.copyInputs = true
-
- self.inputs = {}
- self.outputs = {}
- self.gradOutputs = {}
- self.gradInputs = {}
- self.scales = {}
-
- self.gradParametersAccumulated = false
- self.step = 1
-
- self:reset()
end
-- build module used for the first step (steps == 1)
@@ -75,69 +58,6 @@ function Recurrent:buildRecurrentModule()
self.recurrentModule:add(self.transferModule)
end
-local function recursiveResizeAs(t1,t2)
- if torch.type(t2) == 'table' then
- t1 = (torch.type(t1) == 'table') and t1 or {t1}
- for key,_ in pairs(t2) do
- t1[key], t2[key] = recursiveResizeAs(t1[key], t2[key])
- end
- elseif torch.isTensor(t2) then
- t1 = t1 or t2.new()
- t1:resizeAs(t2)
- else
- error("expecting nested tensors or tables. Got "..
- torch.type(t1).." and "..torch.type(t2).." instead")
- end
- return t1, t2
-end
-
-local function recursiveSet(t1,t2)
- if torch.type(t2) == 'table' then
- t1 = (torch.type(t1) == 'table') and t1 or {t1}
- for key,_ in pairs(t2) do
- t1[key], t2[key] = recursiveSet(t1[key], t2[key])
- end
- elseif torch.isTensor(t2) then
- t1 = t1 or t2.new()
- t1:set(t2)
- else
- error("expecting nested tensors or tables. Got "..
- torch.type(t1).." and "..torch.type(t2).." instead")
- end
- return t1, t2
-end
-
-local function recursiveCopy(t1,t2)
- if torch.type(t2) == 'table' then
- t1 = (torch.type(t1) == 'table') and t1 or {t1}
- for key,_ in pairs(t2) do
- t1[key], t2[key] = recursiveCopy(t1[key], t2[key])
- end
- elseif torch.isTensor(t2) then
- t1 = t1 or t2.new()
- t1:resizeAs(t2):copy(t2)
- else
- error("expecting nested tensors or tables. Got "..
- torch.type(t1).." and "..torch.type(t2).." instead")
- end
- return t1, t2
-end
-
-local function recursiveAdd(t1, t2)
- if torch.type(t2) == 'table' then
- t1 = (torch.type(t1) == 'table') and t1 or {t1}
- for key,_ in pairs(t2) do
- t1[key], t2[key] = recursiveAdd(t1[key], t2[key])
- end
- elseif torch.isTensor(t2) and torch.isTensor(t2) then
- t1:add(t2)
- else
- error("expecting nested tensors or tables. Got "..
- torch.type(t1).." and "..torch.type(t2).." instead")
- end
- return t1, t2
-end
-
function Recurrent:updateOutput(input)
-- output(t) = transfer(feedback(output_(t-1)) + input(input_(t)))
local output
@@ -145,7 +65,7 @@ function Recurrent:updateOutput(input)
-- set/save the output states
local modules = self.initialModule:listModules()
for i,modula in ipairs(modules) do
- local output_ = recursiveResizeAs(self.initialOutputs[i], modula.output)
+ local output_ = self.recursiveResizeAs(self.initialOutputs[i], modula.output)
modula.output = output_
end
output = self.initialModule:updateOutput(input)
@@ -163,7 +83,7 @@ function Recurrent:updateOutput(input)
self.recurrentOutputs[self.step] = recurrentOutputs
end
for i,modula in ipairs(modules) do
- local output_ = recursiveResizeAs(recurrentOutputs[i], modula.output)
+ local output_ = self.recursiveResizeAs(recurrentOutputs[i], modula.output)
modula.output = output_
end
-- self.output is the previous output of this module
@@ -179,18 +99,9 @@ function Recurrent:updateOutput(input)
if self.train ~= false then
local input_ = self.inputs[self.step]
- if self.copyInputs then
- input_ = recursiveCopy(input_, input)
- else
- input_:set(input)
- end
- end
-
- if self.train ~= false then
- local input_ = self.inputs[self.step]
self.inputs[self.step] = self.copyInputs
- and recursiveCopy(input_, input)
- or recursiveSet(input_, input)
+ and self.recursiveCopy(input_, input)
+ or self.recursiveSet(input_, input)
end
self.outputs[self.step] = output
@@ -200,18 +111,6 @@ function Recurrent:updateOutput(input)
return self.output
end
-function Recurrent:updateGradInput(input, gradOutput)
- -- Back-Propagate Through Time (BPTT) happens in updateParameters()
- -- for now we just keep a list of the gradOutputs
- self.gradOutputs[self.step-1] = recursiveCopy(self.gradOutputs[self.step-1] , gradOutput)
-end
-
-function Recurrent:accGradParameters(input, gradOutput, scale)
- -- Back-Propagate Through Time (BPTT) happens in updateParameters()
- -- for now we just keep a list of the scales
- self.scales[self.step-1] = scale
-end
-
-- not to be confused with the hit movie Back to the Future
function Recurrent:backwardThroughTime()
assert(self.step > 1, "expecting at least one updateOutput")
@@ -219,7 +118,7 @@ function Recurrent:backwardThroughTime()
local stop = self.step - rho
if self.fastBackward then
self.gradInputs = {}
- local gradInput
+ local gradInput, gradPrevOutput
for step=self.step-1,math.max(stop, 2),-1 do
-- set the output/gradOutput states of current Module
local modules = self.recurrentModule:listModules()
@@ -235,19 +134,19 @@ function Recurrent:backwardThroughTime()
local output_ = recurrentOutputs[i]
assert(output_, "backwardThroughTime should be preceded by updateOutput")
modula.output = output_
- modula.gradInput = recursiveResizeAs(recurrentGradInputs[i], gradInput)
+ modula.gradInput = self.recursiveResizeAs(recurrentGradInputs[i], gradInput)
end
-- backward propagate through this step
local input = self.inputs[step]
local output = self.outputs[step-1]
local gradOutput = self.gradOutputs[step]
- if gradInput then
- recursiveAdd(gradOutput, gradInput)
+ if gradPrevOutput then
+ self.recursiveAdd(gradOutput, gradPrevOutput)
end
local scale = self.scales[step]
- gradInput = self.recurrentModule:backward({input, output}, gradOutput, scale/rho)[2]
+ gradInput, gradPrevOutput = unpack(self.recurrentModule:backward({input, output}, gradOutput, scale/rho))
table.insert(self.gradInputs, 1, gradInput)
for i,modula in ipairs(modules) do
@@ -260,14 +159,14 @@ function Recurrent:backwardThroughTime()
local modules = self.initialModule:listModules()
for i,modula in ipairs(modules) do
modula.output = self.initialOutputs[i]
- modula.gradInput = recursiveResizeAs(self.initialGradInputs[i], modula.gradInput)
+ modula.gradInput = self.recursiveResizeAs(self.initialGradInputs[i], modula.gradInput)
end
-- backward propagate through first step
local input = self.inputs[1]
local gradOutput = self.gradOutputs[1]
- if gradInput then
- recursiveAdd(gradOutput, gradInput)
+ if gradPrevOutput then
+ self.recursiveAdd(gradOutput, gradPrevOutput)
end
local scale = self.scales[1]
gradInput = self.initialModule:backward(input, gradOutput, scale/rho)
@@ -296,16 +195,10 @@ function Recurrent:backwardThroughTime()
end
end
-function Recurrent:backwardUpdateThroughTime(learningRate)
- local gradInput = self:updateGradInputThroughTime()
- self:accUpdateGradParametersThroughTime(learningRate)
- return gradInput
-end
-
function Recurrent:updateGradInputThroughTime()
assert(self.step > 1, "expecting at least one updateOutput")
self.gradInputs = {}
- local gradInput
+ local gradInput, gradPrevOutput
local rho = math.min(self.rho, self.step-1)
local stop = self.step - rho
for step=self.step-1,math.max(stop,2),-1 do
@@ -322,21 +215,23 @@ function Recurrent:updateGradInputThroughTime()
local output_ = recurrentOutputs[i]
assert(output_, "updateGradInputThroughTime should be preceded by updateOutput")
modula.output = output_
- modula.gradInput = recursiveResizeAs(recurrentGradInputs[i], gradInput)
+ modula.gradInput = self.recursiveResizeAs(recurrentGradInputs[i], gradInput)
end
-- backward propagate through this step
local input = self.inputs[step]
local output = self.outputs[step-1]
local gradOutput = self.gradOutputs[step]
- if gradInput then
- gradOutput:add(gradInput)
+ if gradPrevOutput then
+ self.recursiveAdd(gradOutput, gradPrevOutput)
end
- gradInput = self.recurrentModule:updateGradInput({input, output}, gradOutput)[2]
+
+ gradInput, gradPrevOutput = unpack(self.recurrentModule:updateGradInput({input, output}, gradOutput))
+ table.insert(self.gradInputs, 1, gradInput)
+
for i,modula in ipairs(modules) do
recurrentGradInputs[i] = modula.gradInput
end
- table.insert(self.gradInputs, 1, gradInput)
end
if stop <= 1 then
@@ -344,21 +239,21 @@ function Recurrent:updateGradInputThroughTime()
local modules = self.initialModule:listModules()
for i,modula in ipairs(modules) do
modula.output = self.initialOutputs[i]
- modula.gradInput = recursiveResizeAs(self.initialGradInputs[i], modula.gradInput)
+ modula.gradInput = self.recursiveResizeAs(self.initialGradInputs[i], modula.gradInput)
end
-- backward propagate through first step
local input = self.inputs[1]
local gradOutput = self.gradOutputs[1]
- if gradInput then
- gradOutput:add(gradInput)
+ if gradPrevOutput then
+ self.recursiveAdd(gradOutput, gradPrevOutput)
end
gradInput = self.initialModule:updateGradInput(input, gradOutput)
+ table.insert(self.gradInputs, 1, gradInput)
for i,modula in ipairs(modules) do
self.initialGradInputs[i] = modula.gradInput
end
- table.insert(self.gradInputs, 1, gradInput)
end
return gradInput
@@ -390,7 +285,6 @@ function Recurrent:accGradParametersThroughTime()
local scale = self.scales[step]
self.recurrentModule:accGradParameters({input, output}, gradOutput, scale/rho)
-
end
if stop <= 1 then
@@ -475,154 +369,36 @@ function Recurrent:accUpdateGradParametersThroughTime(lr)
return gradInput
end
-function Recurrent:updateParameters(learningRate)
- if self.gradParametersAccumulated then
- for i=1,#self.modules do
- self.modules[i]:updateParameters(learningRate)
- end
- else
- self:backwardUpdateThroughTime(learningRate)
- end
-end
-
--- goes hand in hand with the next method : forget()
-function Recurrent:recycle()
- -- +1 is to skip initialModule
- if self.step > self.rho + 1 then
- assert(self.recurrentOutputs[self.step] == nil)
- assert(self.recurrentOutputs[self.step-self.rho] ~= nil)
- self.recurrentOutputs[self.step] = self.recurrentOutputs[self.step-self.rho]
- self.recurrentGradInputs[self.step] = self.recurrentGradInputs[self.step-self.rho]
- self.recurrentOutputs[self.step-self.rho] = nil
- self.recurrentGradInputs[self.step-self.rho] = nil
- -- need to keep rho+1 of these
- self.outputs[self.step] = self.outputs[self.step-self.rho-1]
- self.outputs[self.step-self.rho-1] = nil
- end
- if self.step > self.rho then
- assert(self.inputs[self.step] == nil)
- assert(self.inputs[self.step-self.rho] ~= nil)
- self.inputs[self.step] = self.inputs[self.step-self.rho]
- self.gradOutputs[self.step] = self.gradOutputs[self.step-self.rho]
- self.inputs[self.step-self.rho] = nil
- self.gradOutputs[self.step-self.rho] = nil
- self.scales[self.step-self.rho] = nil
- end
-end
-
function Recurrent:forget()
-
- if self.train ~= false then
- -- bring all states back to the start of the sequence buffers
- local lastStep = self.step - 1
-
- if lastStep > self.rho + 1 then
- local i = 2
- for step = lastStep-self.rho+1,lastStep do
- self.recurrentOutputs[i] = self.recurrentOutputs[step]
- self.recurrentGradInputs[i] = self.recurrentGradInputs[step]
- self.recurrentOutputs[step] = nil
- self.recurrentGradInputs[step] = nil
- -- we keep rho+1 of these : outputs[k]=outputs[k+rho+1]
- self.outputs[i-1] = self.outputs[step]
- self.outputs[step] = nil
- i = i + 1
- end
-
- end
-
- if lastStep > self.rho then
- local i = 1
- for step = lastStep-self.rho+1,lastStep do
- self.inputs[i] = self.inputs[step]
- self.gradOutputs[i] = self.gradOutputs[step]
- self.inputs[step] = nil
- self.gradOutputs[step] = nil
- self.scales[step] = nil
- i = i + 1
- end
-
- end
- end
-
- -- forget the past inputs; restart from first step
- self.step = 1
-end
-
-function Recurrent:size()
- return #self.modules
-end
-
-function Recurrent:get(index)
- return self.modules[index]
-end
-
-function Recurrent:zeroGradParameters()
- for i=1,#self.modules do
- self.modules[i]:zeroGradParameters()
- end
-end
-
-function Recurrent:training()
- for i=1,#self.modules do
- self.modules[i]:training()
- end
-end
-
-function Recurrent:evaluate()
- for i=1,#self.modules do
- self.modules[i]:evaluate()
- end
-end
-
-function Recurrent:share(mlp,...)
- for i=1,#self.modules do
- self.modules[i]:share(mlp.modules[i],...);
- end
-end
-
-function Recurrent:reset(stdv)
- self:forget()
- for i=1,#self.modules do
- self.modules[i]:reset(stdv)
- end
-end
-
-function Recurrent:parameters()
- local function tinsert(to, from)
- if type(from) == 'table' then
- for i=1,#from do
- tinsert(to,from[i])
- end
- else
- table.insert(to,from)
- end
- end
- local w = {}
- local gw = {}
- for i=1,#self.modules do
- local mw,mgw = self.modules[i]:parameters()
- if mw then
- tinsert(w,mw)
- tinsert(gw,mgw)
- end
- end
- return w,gw
+ parent.forget(self, 1)
end
function Recurrent:__tostring__()
local tab = ' '
local line = '\n'
local next = ' -> '
- local str = 'nn.Recurrent'
- str = str .. ' {' .. line .. tab .. '[input'
- for i=1,#self.modules do
+ local str = torch.type(self)
+ str = str .. ' {' .. line .. tab .. '[{input(t), output(t-1)}'
+ for i=1,3 do
str = str .. next .. '(' .. i .. ')'
end
- str = str .. next .. 'output]'
- for i=1,#self.modules do
- str = str .. line .. tab .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab)
- end
+ str = str .. next .. 'output(t)]'
+
+ local tab = ' '
+ local line = '\n '
+ local next = ' |`-> '
+ local ext = ' | '
+ local last = ' ... -> '
+ str = str .. line .. '(1): ' .. ' {' .. line .. tab .. 'input(t)'
+ str = str .. line .. tab .. next .. '(t==0): ' .. tostring(self.startModule):gsub('\n', '\n' .. tab .. ext)
+ str = str .. line .. tab .. next .. '(t~=0): ' .. tostring(self.inputModule):gsub('\n', '\n' .. tab .. ext)
+ str = str .. line .. tab .. 'output(t-1)'
+ str = str .. line .. tab .. next .. tostring(self.feedbackModule):gsub('\n', line .. tab .. ext)
+ local tab = ' '
+ local line = '\n'
+ local next = ' -> '
+ str = str .. line .. tab .. '(' .. 2 .. '): ' .. tostring(self.mergeModule):gsub(line, line .. tab)
+ str = str .. line .. tab .. '(' .. 3 .. '): ' .. tostring(self.transferModule):gsub(line, line .. tab)
str = str .. line .. '}'
return str
end
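
For context, the refactored nn.Recurrent keeps its constructor signature (start, input, feedback, transfer, rho, merge); a minimal construction sketch with illustrative sub-modules, mirroring the existing test:

require 'nnx'

local inputSize, outputSize, rho = 10, 20, 5
local r = nn.Recurrent(
   outputSize,                          -- start : output size (or a start module)
   nn.Linear(inputSize, outputSize),    -- input module, applied to input(t)
   nn.Linear(outputSize, outputSize),   -- feedback module, applied to output(t-1)
   nn.Sigmoid(),                        -- transfer module
   rho                                  -- BPTT horizon
)
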
diff --git a/RepeaterCriterion.lua b/RepeaterCriterion.lua
index 44bf078..a6ad078 100644
--- a/RepeaterCriterion.lua
+++ b/RepeaterCriterion.lua
@@ -13,6 +13,7 @@ function RepeaterCriterion:__init(criterion)
end
function RepeaterCriterion:forward(inputTable, target)
+ self.output = 0
for i,input in ipairs(inputTable) do
self.output = self.output + self.criterion:forward(input, target)
end
diff --git a/ZeroGrad.lua b/ZeroGrad.lua
new file mode 100644
index 0000000..83f88d2
--- /dev/null
+++ b/ZeroGrad.lua
@@ -0,0 +1,28 @@
+local ZeroGrad, parent = torch.class("nn.ZeroGrad", "nn.Module")
+
+local function recursiveZero(t1,t2)
+ if torch.type(t2) == 'table' then
+ t1 = (torch.type(t1) == 'table') and t1 or {t1}
+ for key,_ in pairs(t2) do
+ t1[key], t2[key] = recursiveZero(t1[key], t2[key])
+ end
+ elseif torch.isTensor(t2) then
+ t1 = t1 or t2.new()
+ t1:resizeAs(t2):zero()
+ else
+ error("expecting nested tensors or tables. Got "..
+ torch.type(t1).." and "..torch.type(t2).." instead")
+ end
+ return t1, t2
+end
+
+function ZeroGrad:updateOutput(input)
+ self.output:set(input)
+ return self.output
+end
+
+-- The gradient is simply zeroed;
+-- useful when you don't want to backpropagate through certain paths.
+function ZeroGrad:updateGradInput(input, gradOutput)
+   self.gradInput = recursiveZero(self.gradInput, gradOutput)
+   return self.gradInput
+end
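
A small sketch of what the new nn.ZeroGrad buys you (the surrounding modules are illustrative): it is an identity on the forward pass but returns a zero gradInput, so nothing below it receives a gradient:

require 'nnx'

local branch = nn.Sequential()
branch:add(nn.ZeroGrad())      -- identity forward, zero gradInput backward
branch:add(nn.Linear(10, 10))

local x = torch.randn(10)
local y = branch:forward(x)
local gradX = branch:backward(x, torch.randn(10))  -- gradX is all zeros
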
diff --git a/init.lua b/init.lua
index 3d6c5df..278a07b 100644
--- a/init.lua
+++ b/init.lua
@@ -75,13 +75,18 @@ torch.include('nnx', 'SoftMaxTree.lua')
torch.include('nnx', 'MultiSoftMax.lua')
torch.include('nnx', 'Balance.lua')
torch.include('nnx', 'NarrowLookupTable.lua')
-torch.include('nnx', 'Recurrent.lua')
-torch.include('nnx', 'Repeater.lua')
-torch.include('nnx', 'Sequencer.lua')
torch.include('nnx', 'PushTable.lua')
torch.include('nnx', 'PullTable.lua')
+torch.include('nnx', 'ZeroGrad.lua')
torch.include('nnx', 'Padding.lua')
+-- recurrent
+torch.include('nnx', 'AbstractRecurrent.lua')
+torch.include('nnx', 'Recurrent.lua')
+torch.include('nnx', 'LSTM.lua')
+torch.include('nnx', 'Repeater.lua')
+torch.include('nnx', 'Sequencer.lua')
+
-- criterions:
torch.include('nnx', 'SuperCriterion.lua')
torch.include('nnx', 'DistNLLCriterion.lua')
diff --git a/test/test-all.lua b/test/test-all.lua
index 5abd4e2..16f819f 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -342,6 +342,15 @@ function nnxtest.Recurrent()
-- rho = nSteps
local mlp = nn.Recurrent(outputSize, inputModule, feedbackModule, transferModule:clone(), nSteps)
+ -- test that the internal mlps are recursable :
+ local isRecursable = nn.AbstractRecurrent.isRecursable
+ mytester:assert(isRecursable(mlp.initialModule, torch.randn(inputSize)), "Recurrent isRecursable() initial error")
+ mytester:assert(isRecursable(mlp.recurrentModule, {torch.randn(inputSize), torch.randn(outputSize)}), "Recurrent isRecursable() recurrent error")
+
+ -- test that the above test actually works
+ local euclidean = nn.Euclidean(inputSize, outputSize)
+ mytester:assert(not isRecursable(euclidean, torch.randn(batchSize, inputSize)), "AbstractRecurrent.isRecursable error")
+
local gradOutputs, outputs = {}, {}
-- inputs = {inputN, {inputN-1, {inputN-2, ...}}}}}
local inputs
@@ -355,8 +364,10 @@ function nnxtest.Recurrent()
mlp:zeroGradParameters()
local mlp7 = mlp:clone()
mlp7.rho = nSteps - 1
+ local inputSequence = {}
for step=1,nSteps do
local input = torch.randn(batchSize, inputSize)
+ inputSequence[step] = input
local gradOutput
if step ~= nSteps then
-- for the sake of keeping this unit test simple,
@@ -389,7 +400,18 @@ function nnxtest.Recurrent()
local mlp5 = mlp:clone()
-- backward propagate through time (BPTT)
- local gradInput = mlp:backwardThroughTime()
+ local gradInput = mlp:backwardThroughTime():clone()
+ mlp:forget() -- test ability to forget
+ mlp:zeroGradParameters()
+ local foutputs = {}
+ for step=1,nSteps do
+ foutputs[step] = mlp:forward(inputSequence[step])
+ mytester:assertTensorEq(foutputs[step], outputs[step], 0.00001, "Recurrent forget output error "..step)
+ mlp:backward(input, gradOutputs[step])
+ end
+ local fgradInput = mlp:backwardThroughTime():clone()
+ mytester:assertTensorEq(gradInput, fgradInput, 0.00001, "Recurrent forget gradInput error")
+
mlp4.fastBackward = false
local gradInput4 = mlp4:backwardThroughTime()
mytester:assertTensorEq(gradInput, gradInput4, 0.000001, 'error slow vs fast backwardThroughTime')
@@ -564,6 +586,114 @@ function nnxtest.Recurrent_TestTable()
mlp:backwardThroughTime(learningRate)
end
+function nnxtest.LSTM()
+ local batchSize = math.random(1,2)
+ local inputSize = math.random(3,4)
+ local outputSize = math.random(5,6)
+ local nStep = 3
+ local input = {}
+ local gradOutput = {}
+ for step=1,nStep do
+ input[step] = torch.randn(batchSize, inputSize)
+ if step == nStep then
+ -- for the sake of keeping this unit test simple,
+ gradOutput[step] = torch.randn(batchSize, outputSize)
+ else
+ -- only the last step will get a gradient from the output
+ gradOutput[step] = torch.zeros(batchSize, outputSize)
+ end
+ end
+ local lstm = nn.LSTM(inputSize, outputSize)
+
+ local isRecursable = nn.AbstractRecurrent.isRecursable
+ local inputTable = {torch.randn(batchSize, inputSize), torch.randn(batchSize, outputSize), torch.randn(batchSize, outputSize)}
+ mytester:assert(isRecursable(lstm.recurrentModule, inputTable), "LSTM isRecursable() error")
+
+ -- we will use this to build an LSTM step by step (with shared params)
+ local lstmStep = lstm.recurrentModule:clone()
+
+ -- forward/backward through LSTM
+ local output = {}
+ lstm:zeroGradParameters()
+ for step=1,nStep do
+ output[step] = lstm:forward(input[step])
+ assert(torch.isTensor(input[step]))
+ lstm:backward(input[step], gradOutput[step], 1)
+ end
+ local gradInput = lstm:backwardThroughTime()
+
+ local mlp2 -- this one will simulate rho = nSteps
+ local inputs
+ for step=1,nStep do
+ -- iteratively build an LSTM out of non-recurrent components
+ local lstm = lstmStep:clone()
+ lstm:share(lstmStep, 'weight', 'gradWeight', 'bias', 'gradBias')
+ if step == 1 then
+ mlp2 = lstm
+ else
+ local rnn = nn.Sequential()
+ local para = nn.ParallelTable()
+ para:add(nn.Identity()):add(mlp2)
+ rnn:add(para)
+ rnn:add(nn.FlattenTable())
+ rnn:add(lstm)
+ mlp2 = rnn
+ end
+
+ -- prepare inputs for mlp2
+ if inputs then
+ inputs = {input[step], inputs}
+ else
+ inputs = {input[step], torch.zeros(batchSize, outputSize), torch.zeros(batchSize, outputSize)}
+ end
+ end
+ mlp2:add(nn.SelectTable(1)) --just output the output (not cell)
+ local output2 = mlp2:forward(inputs)
+
+ mlp2:zeroGradParameters()
+ local gradInput2 = mlp2:backward(inputs, gradOutput[nStep], 1/nStep)
+ mytester:assertTensorEq(gradInput2[2][2][1], gradInput, 0.00001, "LSTM gradInput error")
+ mytester:assertTensorEq(output[nStep], output2, 0.00001, "LSTM output error")
+
+ local params, gradParams = lstm:parameters()
+ local params2, gradParams2 = lstmStep:parameters()
+ mytester:assert(#params == #params2, "LSTM parameters error "..#params.." ~= "..#params2)
+ for i, gradParam in ipairs(gradParams) do
+ local gradParam2 = gradParams2[i]
+ mytester:assertTensorEq(gradParam, gradParam2, 0.000001,
+ "LSTM gradParam "..i.." error "..tostring(gradParam).." "..tostring(gradParam2))
+ end
+
+ gradParams = lstm.recursiveCopy(nil, gradParams)
+ gradInput = gradInput:clone()
+ mytester:assert(lstm.zeroTensor:sum() == 0, "zeroTensor error")
+ lstm:forget()
+ output = lstm.recursiveCopy(nil, output)
+ local output3 = {}
+ lstm:zeroGradParameters()
+ for step=1,nStep do
+ output3[step] = lstm:forward(input[step])
+ lstm:backward(input[step], gradOutput[step], 1)
+ end
+ local gradInput3 = lstm:updateGradInputThroughTime()
+ lstm:accGradParametersThroughTime()
+
+ mytester:assert(#output == #output3, "LSTM output size error")
+ for i,output in ipairs(output) do
+ mytester:assertTensorEq(output, output3[i], 0.00001, "LSTM forget (updateOutput) error "..i)
+ end
+
+ mytester:assertTensorEq(gradInput, gradInput3, 0.00001, "LSTM updateGradInputThroughTime error")
+ --if true then return end
+ local params3, gradParams3 = lstm:parameters()
+ mytester:assert(#params == #params3, "LSTM parameters error "..#params.." ~= "..#params3)
+ for i, gradParam in ipairs(gradParams) do
+ local gradParam3 = gradParams3[i]
+ mytester:assertTensorEq(gradParam, gradParam3, 0.000001,
+ "LSTM gradParam "..i.." error "..tostring(gradParam).." "..tostring(gradParam3))
+ end
+end
+
function nnxtest.SpatialNormalization_Gaussian2D()
local inputSize = math.random(11,20)
local kersize = 9