author     Nicholas Léonard <nick@nikopia.org>       2015-03-18 04:49:54 +0300
committer  Nicholas Léonard <nick@nikopia.org>       2015-03-18 04:49:54 +0300
commit     c6de4cf8d372a5e507321464a9a98ed2d829798f (patch)
tree       715b1aa80c1b2769ac543a4ca5f93527ae207e82
parent     eaee722aab186e13bcf33479f3724a53fb049ec5 (diff)
parent     edf33956ca750ed790f5275734fa80a2c07b8c7d (diff)
Merge pull request #27 from nicholas-leonard/lstm
LSTM (work in progress)
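For orientation, the recurrence computed by the new nn.LSTM module can be summarized as below. This is a reader's sketch inferred from LSTM:buildModel() in the diff that follows (diagonal cell-to-gate "peephole" weights via nn.CMul, and no squashing of the cell candidate in buildHidden()); W and w denote nn.Linear and nn.CMul weights respectively:

  i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + w_{ci} \odot c_{t-1} + b_i)
  f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + w_{cf} \odot c_{t-1} + b_f)
  c_t = f_t \odot c_{t-1} + i_t \odot (W_{xc} x_t + W_{hc} h_{t-1} + b_c)
  o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + w_{co} \odot c_t + b_o)
  h_t = o_t \odot \tanh(c_t)

Here x_t is the input, h_{t-1} the previous output and c_{t-1} the previous cell state; the recurrentModule returns the pair {h_t, c_t}, while updateOutput() only returns h_t.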
-rw-r--r--   AbstractRecurrent.lua   281
-rw-r--r--   LSTM.lua                353
-rw-r--r--   Recurrent.lua           318
-rw-r--r--   RepeaterCriterion.lua     1
-rw-r--r--   ZeroGrad.lua             28
-rw-r--r--   init.lua                 11
-rw-r--r--   test/test-all.lua       132
7 files changed, 849 insertions(+), 275 deletions(-)
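A minimal usage sketch of the new module, following the pattern exercised in nnxtest.LSTM() in test/test-all.lua below; the sizes, step count, learning rate and random data are illustrative placeholders only:

require 'nnx'

local inputSize, outputSize, rho, nStep = 4, 6, 5, 3
local lstm = nn.LSTM(inputSize, outputSize, rho)

lstm:zeroGradParameters()
for step = 1, nStep do
   local input = torch.randn(inputSize)   -- 1D input; 2D (batchSize x inputSize) also works
   local output = lstm:forward(input)     -- returns output(t); cell(t) is kept internally
   local gradOutput = torch.randn(outputSize)
   lstm:backward(input, gradOutput)       -- only buffers gradOutput; BPTT is deferred
end

-- BPTT over the (at most) last rho steps happens inside updateParameters();
-- call lstm:backwardThroughTime() instead if the gradInputs are needed
lstm:updateParameters(0.1)

lstm:forget()   -- reset the time-step buffers before starting a new sequence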
diff --git a/AbstractRecurrent.lua b/AbstractRecurrent.lua new file mode 100644 index 0000000..af5cba0 --- /dev/null +++ b/AbstractRecurrent.lua @@ -0,0 +1,281 @@ +local AbstractRecurrent, parent = torch.class('nn.AbstractRecurrent', 'nn.Container') + +function AbstractRecurrent:__init(rho) + parent.__init(self) + + self.rho = rho --the maximum number of time steps to BPTT + + self.fastBackward = true + self.copyInputs = true + + self.inputs = {} + self.outputs = {} + self.gradOutputs = {} + self.scales = {} + + self.gradParametersAccumulated = false + self.step = 1 + + -- stores internal states of Modules at different time-steps + self.recurrentOutputs = {} + self.recurrentGradInputs = {} + + self:reset() +end + +local function recursiveResizeAs(t1,t2) + if torch.type(t2) == 'table' then + t1 = (torch.type(t1) == 'table') and t1 or {t1} + for key,_ in pairs(t2) do + t1[key], t2[key] = recursiveResizeAs(t1[key], t2[key]) + end + elseif torch.isTensor(t2) then + t1 = t1 or t2.new() + t1:resizeAs(t2) + else + error("expecting nested tensors or tables. Got ".. + torch.type(t1).." and "..torch.type(t2).." instead") + end + return t1, t2 +end +AbstractRecurrent.recursiveResizeAs = recursiveResizeAs + +local function recursiveSet(t1,t2) + if torch.type(t2) == 'table' then + t1 = (torch.type(t1) == 'table') and t1 or {t1} + for key,_ in pairs(t2) do + t1[key], t2[key] = recursiveSet(t1[key], t2[key]) + end + elseif torch.isTensor(t2) then + t1 = t1 or t2.new() + t1:set(t2) + else + error("expecting nested tensors or tables. Got ".. + torch.type(t1).." and "..torch.type(t2).." instead") + end + return t1, t2 +end +AbstractRecurrent.recursiveSet = recursiveSet + +local function recursiveCopy(t1,t2) + if torch.type(t2) == 'table' then + t1 = (torch.type(t1) == 'table') and t1 or {t1} + for key,_ in pairs(t2) do + t1[key], t2[key] = recursiveCopy(t1[key], t2[key]) + end + elseif torch.isTensor(t2) then + t1 = t1 or t2.new() + t1:resizeAs(t2):copy(t2) + else + error("expecting nested tensors or tables. Got ".. + torch.type(t1).." and "..torch.type(t2).." instead") + end + return t1, t2 +end +AbstractRecurrent.recursiveCopy = recursiveCopy + +local function recursiveAdd(t1, t2) + if torch.type(t2) == 'table' then + t1 = (torch.type(t1) == 'table') and t1 or {t1} + for key,_ in pairs(t2) do + t1[key], t2[key] = recursiveAdd(t1[key], t2[key]) + end + elseif torch.isTensor(t2) and torch.isTensor(t2) then + t1:add(t2) + else + error("expecting nested tensors or tables. Got ".. + torch.type(t1).." and "..torch.type(t2).." instead") + end + return t1, t2 +end +AbstractRecurrent.recursiveAdd = recursiveAdd + +local function recursiveTensorEq(t1, t2) + if torch.type(t2) == 'table' then + local isEqual = true + if torch.type(t1) ~= 'table' then + return false + end + for key,_ in pairs(t2) do + isEqual = isEqual and recursiveTensorEq(t1[key], t2[key]) + end + return isEqual + elseif torch.isTensor(t2) and torch.isTensor(t2) then + local diff = t1-t2 + local err = diff:abs():max() + return err < 0.00001 + else + error("expecting nested tensors or tables. Got ".. + torch.type(t1).." and "..torch.type(t2).." instead") + end +end +AbstractRecurrent.recursiveTensorEq = recursiveTensorEq + +local function recursiveNormal(t2) + if torch.type(t2) == 'table' then + for key,_ in pairs(t2) do + t2[key] = recursiveNormal(t2[key]) + end + elseif torch.isTensor(t2) then + t2:normal() + else + error("expecting tensor or table thereof. Got " + ..torch.type(t2).." 
instead") + end + return t2 +end +AbstractRecurrent.recursiveNormal = recursiveNormal + +function AbstractRecurrent:updateGradInput(input, gradOutput) + -- Back-Propagate Through Time (BPTT) happens in updateParameters() + -- for now we just keep a list of the gradOutputs + self.gradOutputs[self.step-1] = self.recursiveCopy(self.gradOutputs[self.step-1] , gradOutput) +end + +function AbstractRecurrent:accGradParameters(input, gradOutput, scale) + -- Back-Propagate Through Time (BPTT) happens in updateParameters() + -- for now we just keep a list of the scales + self.scales[self.step-1] = scale +end + +function AbstractRecurrent:backwardUpdateThroughTime(learningRate) + local gradInput = self:updateGradInputThroughTime() + self:accUpdateGradParametersThroughTime(learningRate) + return gradInput +end + +function AbstractRecurrent:updateParameters(learningRate) + if self.gradParametersAccumulated then + for i=1,#self.modules do + self.modules[i]:updateParameters(learningRate) + end + else + self:backwardUpdateThroughTime(learningRate) + end +end + +-- goes hand in hand with the next method : forget() +function AbstractRecurrent:recycle() + -- +1 is to skip initialModule + if self.step > self.rho + 1 then + assert(self.recurrentOutputs[self.step] == nil) + assert(self.recurrentOutputs[self.step-self.rho] ~= nil) + self.recurrentOutputs[self.step] = self.recurrentOutputs[self.step-self.rho] + self.recurrentGradInputs[self.step] = self.recurrentGradInputs[self.step-self.rho] + self.recurrentOutputs[self.step-self.rho] = nil + self.recurrentGradInputs[self.step-self.rho] = nil + -- need to keep rho+1 of these + self.outputs[self.step] = self.outputs[self.step-self.rho-1] + self.outputs[self.step-self.rho-1] = nil + end + if self.step > self.rho then + assert(self.inputs[self.step] == nil) + assert(self.inputs[self.step-self.rho] ~= nil) + self.inputs[self.step] = self.inputs[self.step-self.rho] + self.gradOutputs[self.step] = self.gradOutputs[self.step-self.rho] + self.inputs[self.step-self.rho] = nil + self.gradOutputs[self.step-self.rho] = nil + self.scales[self.step-self.rho] = nil + end +end + +function AbstractRecurrent:forget(offset) + offset = offset or 1 + if self.train ~= false then + -- bring all states back to the start of the sequence buffers + local lastStep = self.step - 1 + + if lastStep > self.rho + offset then + local i = 1 + offset + for step = lastStep-self.rho+offset,lastStep do + self.recurrentOutputs[i] = self.recurrentOutputs[step] + self.recurrentGradInputs[i] = self.recurrentGradInputs[step] + self.recurrentOutputs[step] = nil + self.recurrentGradInputs[step] = nil + -- we keep rho+1 of these : outputs[k]=outputs[k+rho+1] + self.outputs[i-1] = self.outputs[step] + self.outputs[step] = nil + i = i + 1 + end + + end + + if lastStep > self.rho then + local i = 1 + for step = lastStep-self.rho+1,lastStep do + self.inputs[i] = self.inputs[step] + self.gradOutputs[i] = self.gradOutputs[step] + self.inputs[step] = nil + self.gradOutputs[step] = nil + self.scales[step] = nil + i = i + 1 + end + + end + end + + -- forget the past inputs; restart from first step + self.step = 1 +end + +-- tests whether or not the mlp can be used internally for recursion. +-- forward A, backward A, forward B, forward A should be consistent with +-- forward B, backward B, backward A where A and B each +-- have their own gradInputs/outputs. 
+function AbstractRecurrent.isRecursable(mlp, input) + local output = recursiveCopy(nil, mlp:forward(input)) --forward A + local gradOutput = recursiveNormal(recursiveCopy(nil, output)) + mlp:zeroGradParameters() + local gradInput = recursiveCopy(nil, mlp:backward(input, gradOutput)) --backward A + local params, gradParams = mlp:parameters() + gradParams = recursiveCopy(nil, gradParams) + + -- output/gradInput are the only internal module states that we track + local recurrentOutputs = {} + local recurrentGradInputs = {} + + local modules = mlp:listModules() + + -- save the output/gradInput states of A + for i,modula in ipairs(modules) do + recurrentOutputs[i] = modula.output + recurrentGradInputs[i] = modula.gradInput + end + -- set the output/gradInput states for B + local recurrentOutputs2 = {} + local recurrentGradInputs2 = {} + for i,modula in ipairs(modules) do + modula.output = recursiveResizeAs(recurrentOutputs2[i], modula.output) + modula.gradInput = recursiveResizeAs(recurrentGradInputs2[i], modula.gradInput) + end + + local input2 = recursiveNormal(recursiveCopy(nil, input)) + local gradOutput2 = recursiveNormal(recursiveCopy(nil, gradOutput)) + local output2 = mlp:forward(input2) --forward B + mlp:zeroGradParameters() + local gradInput2 = mlp:backward(input2, gradOutput2) --backward B + + -- save the output/gradInput state of B + for i,modula in ipairs(modules) do + recurrentOutputs2[i] = modula.output + recurrentGradInputs2[i] = modula.gradInput + end + + -- set the output/gradInput states for A + for i,modula in ipairs(modules) do + modula.output = recursiveResizeAs(recurrentOutputs[i], modula.output) + modula.gradInput = recursiveResizeAs(recurrentGradInputs[i], modula.gradInput) + end + + mlp:zeroGradParameters() + local gradInput3 = mlp:backward(input, gradOutput) --forward A + local gradInputTest = recursiveTensorEq(gradInput, gradInput3) + local params3, gradParams3 = mlp:parameters() + local nEq = 0 + for i,gradParam in ipairs(gradParams) do + nEq = nEq + (recursiveTensorEq(gradParam, gradParams3[i]) and 1 or 0) + end + local gradParamsTest = (nEq == #gradParams3) + mlp:zeroGradParameters() + return gradParamsTest and gradInputTest, gradParamsTest, gradInputTest +end diff --git a/LSTM.lua b/LSTM.lua new file mode 100644 index 0000000..a3541b8 --- /dev/null +++ b/LSTM.lua @@ -0,0 +1,353 @@ +------------------------------------------------------------------------ +--[[ LSTM ]]-- +-- Long Short Term Memory architecture. +-- Ref. A.: http://arxiv.org/pdf/1303.5778v1 (blueprint for this module) +-- B. http://web.eecs.utk.edu/~itamar/courses/ECE-692/Bobby_paper1.pdf +-- C. https://github.com/wojzaremba/lstm +-- Expects 1D or 2D input. 
+-- The first input in sequence uses zero value for cell and hidden state +------------------------------------------------------------------------ +local LSTM, parent = torch.class('nn.LSTM', 'nn.AbstractRecurrent') + +function LSTM:__init(inputSize, outputSize, rho) + parent.__init(self, rho or 999999999999) + self.inputSize = inputSize + self.outputSize = outputSize + -- build the model + self.recurrentModule = self:buildModel() + -- make it work with nn.Container + self.modules[1] = self.recurrentModule + + -- for output(0), cell(0) and gradCell(T) + self.zeroTensor = torch.Tensor() + + self.cells = {} + self.gradCells = {} +end + +-------------------------- factory methods ----------------------------- +function LSTM:buildGate() + -- Note : gate expects an input table : {input, output(t-1), cell(t-1)} + local gate = nn.Sequential() + local input2gate = nn.Linear(self.inputSize, self.outputSize) + local output2gate = nn.Linear(self.outputSize, self.outputSize) + local cell2gate = nn.CMul(self.outputSize) -- diagonal cell to gate weight matrix + --output2gate:noBias() --TODO + local para = nn.ParallelTable() + para:add(input2gate):add(output2gate):add(cell2gate) + gate:add(para) + gate:add(nn.CAddTable()) + gate:add(nn.Sigmoid()) + return gate +end + +function LSTM:buildInputGate() + self.inputGate = self:buildGate() + return self.inputGate +end + +function LSTM:buildForgetGate() + self.forgetGate = self:buildGate() + return self.forgetGate +end + +function LSTM:buildHidden() + local hidden = nn.Sequential() + local input2hidden = nn.Linear(self.inputSize, self.outputSize) + local output2hidden = nn.Linear(self.outputSize, self.outputSize) + local para = nn.ParallelTable() + --output2hidden:noBias() + para:add(input2hidden):add(output2hidden) + -- input is {input, output(t-1), cell(t-1)}, but we only need {input, output(t-1)} + local concat = nn.ConcatTable() + concat:add(nn.SelectTable(1)):add(nn.SelectTable(2)) + hidden:add(concat) + hidden:add(para) + hidden:add(nn.CAddTable()) + self.hiddenLayer = hidden + return hidden +end + +function LSTM:buildCell() + -- build + self.inputGate = self:buildInputGate() + self.forgetGate = self:buildForgetGate() + self.hiddenLayer = self:buildHidden() + -- forget = forgetGate{input, output(t-1), cell(t-1)} * cell(t-1) + local forget = nn.Sequential() + local concat = nn.ConcatTable() + concat:add(self.forgetGate):add(nn.SelectTable(3)) + forget:add(concat) + forget:add(nn.CMulTable()) + -- input = inputGate{input, output(t-1), cell(t-1)} * hiddenLayer{input, output(t-1), cell(t-1)} + local input = nn.Sequential() + local concat2 = nn.ConcatTable() + concat2:add(self.inputGate):add(self.hiddenLayer) + input:add(concat2) + input:add(nn.CMulTable()) + -- cell(t) = forget + input + local cell = nn.Sequential() + local concat3 = nn.ConcatTable() + concat3:add(forget):add(input) + cell:add(concat3) + cell:add(nn.CAddTable()) + self.cellLayer = cell + return cell +end + +function LSTM:buildOutputGate() + self.outputGate = self:buildGate() + return self.outputGate +end + +-- cell(t) = cellLayer{input, output(t-1), cell(t-1)} +-- output(t) = outputGate{input, output(t-1), cell(t)}*tanh(cell(t)) +-- output of Model is table : {output(t), cell(t)} +function LSTM:buildModel() + -- build components + self.cellLayer = self:buildCell() + self.outputGate = self:buildOutputGate() + -- assemble + local concat = nn.ConcatTable() + local concat2 = nn.ConcatTable() + concat2:add(nn.SelectTable(1)):add(nn.SelectTable(2)) + concat:add(concat2):add(self.cellLayer) + local 
model = nn.Sequential() + model:add(concat) + -- output of concat is {{input, output}, cell(t)}, + -- so flatten to {input, output, cell(t)} + model:add(nn.FlattenTable()) + local cellAct = nn.Sequential() + cellAct:add(nn.SelectTable(3)) + cellAct:add(nn.Tanh()) + local concat3 = nn.ConcatTable() + concat3:add(self.outputGate):add(cellAct) + local output = nn.Sequential() + output:add(concat3) + output:add(nn.CMulTable()) + -- we want the model to output : {output(t), cell(t)} + local concat4 = nn.ConcatTable() + concat4:add(output):add(nn.SelectTable(3)) + model:add(concat4) + return model +end + +------------------------- forward backward ----------------------------- +function LSTM:updateOutput(input) + local prevOutput, prevCell + if self.step == 1 then + prevOutput = self.zeroTensor + prevCell = self.zeroTensor + if input:dim() == 2 then + self.zeroTensor:resize(input:size(1), self.outputSize):zero() + else + self.zeroTensor:resize(self.outputSize):zero() + end + self.outputs[0] = self.zeroTensor + self.cells[0] = self.zeroTensor + else + -- previous output and cell of this module + prevOutput = self.output + prevCell = self.cell + end + + -- output(t), cell(t) = lstm{input(t), output(t-1), cell(t-1)} + local output, cell + if self.train ~= false then + -- set/save the output states + local modules = self.recurrentModule:listModules() + self:recycle() + local recurrentOutputs = self.recurrentOutputs[self.step] + if not recurrentOutputs then + recurrentOutputs = {} + self.recurrentOutputs[self.step] = recurrentOutputs + end + for i,modula in ipairs(modules) do + local output_ = self.recursiveResizeAs(recurrentOutputs[i], modula.output) + modula.output = output_ + end + -- the actual forward propagation + output, cell = unpack(self.recurrentModule:updateOutput{input, prevOutput, prevCell}) + + for i,modula in ipairs(modules) do + recurrentOutputs[i] = modula.output + end + else + output, cell = unpack(self.recurrentModule:updateOutput{input, prevOutput, prevCell}) + end + + if self.train ~= false then + local input_ = self.inputs[self.step] + self.inputs[self.step] = self.copyInputs + and self.recursiveCopy(input_, input) + or self.recursiveSet(input_, input) + end + + self.outputs[self.step] = output + self.cells[self.step] = cell + + self.output = output + self.cell = cell + + self.step = self.step + 1 + self.gradParametersAccumulated = false + -- note that we don't return the cell, just the output + return self.output +end + +function LSTM:backwardThroughTime() + assert(self.step > 1, "expecting at least one updateOutput") + self.gradInputs = {} + local rho = math.min(self.rho, self.step-1) + local stop = self.step - rho + if self.fastBackward then + local gradInput, gradPrevOutput, gradCell + for step=self.step-1,math.max(stop,1),-1 do + -- set the output/gradOutput states of current Module + local modules = self.recurrentModule:listModules() + local recurrentOutputs = self.recurrentOutputs[step] + local recurrentGradInputs = self.recurrentGradInputs[step] + if not recurrentGradInputs then + recurrentGradInputs = {} + self.recurrentGradInputs[step] = recurrentGradInputs + end + + for i,modula in ipairs(modules) do + local output, gradInput = modula.output, modula.gradInput + assert(gradInput, "missing gradInput") + local output_ = recurrentOutputs[i] + assert(output_, "backwardThroughTime should be preceded by updateOutput") + modula.output = output_ + modula.gradInput = self.recursiveResizeAs(recurrentGradInputs[i], gradInput) --resize, NOT copy + end + + -- backward propagate 
through this step + local gradOutput = self.gradOutputs[step] + if gradPrevOutput then + self.recursiveAdd(gradOutput, gradPrevOutput) + end + + self.gradCells[step] = gradCell + local scale = self.scales[step]/rho + + local inputTable = {self.inputs[step], self.outputs[step-1], self.cells[step-1]} + local gradInputTable = self.recurrentModule:backward(inputTable, {gradOutput, gradCell}, scale) + gradInput, gradPrevOutput, gradCell = unpack(gradInputTable) + table.insert(self.gradInputs, 1, gradInput) + + for i,modula in ipairs(modules) do + recurrentGradInputs[i] = modula.gradInput + end + end + return gradInput + else + local gradInput = self:updateGradInputThroughTime() + self:accGradParametersThroughTime() + return gradInput + end +end + +function LSTM:updateGradInputThroughTime() + assert(self.step > 1, "expecting at least one updateOutput") + self.gradInputs = {} + local gradInput, gradPrevOutput + local gradCell = self.zeroTensor + local rho = math.min(self.rho, self.step-1) + local stop = self.step - rho + for step=self.step-1,math.max(stop,1),-1 do + -- set the output/gradOutput states of current Module + local modules = self.recurrentModule:listModules() + local recurrentOutputs = self.recurrentOutputs[step] + local recurrentGradInputs = self.recurrentGradInputs[step] + if not recurrentGradInputs then + recurrentGradInputs = {} + self.recurrentGradInputs[step] = recurrentGradInputs + end + for i,modula in ipairs(modules) do + local output, gradInput = modula.output, modula.gradInput + local output_ = recurrentOutputs[i] + assert(output_, "updateGradInputThroughTime should be preceded by updateOutput") + modula.output = output_ + modula.gradInput = self.recursiveResizeAs(recurrentGradInputs[i], gradInput) + end + + -- backward propagate through this step + local gradOutput = self.gradOutputs[step] + if gradPrevOutput then + self.recursiveAdd(gradOutput, gradPrevOutput) + end + + self.gradCells[step] = gradCell + local scale = self.scales[step]/rho + local inputTable = {self.inputs[step], self.outputs[step-1], self.cells[step-1]} + local gradInputTable = self.recurrentModule:updateGradInput(inputTable, {gradOutput, gradCell}, scale) + gradInput, gradPrevOutput, gradCell = unpack(gradInputTable) + table.insert(self.gradInputs, 1, gradInput) + + for i,modula in ipairs(modules) do + recurrentGradInputs[i] = modula.gradInput + end + end + + return gradInput +end + +function LSTM:accGradParametersThroughTime() + local rho = math.min(self.rho, self.step-1) + local stop = self.step - rho + for step=self.step-1,math.max(stop,1),-1 do + -- set the output/gradOutput states of current Module + local modules = self.recurrentModule:listModules() + local recurrentOutputs = self.recurrentOutputs[step] + local recurrentGradInputs = self.recurrentGradInputs[step] + + for i,modula in ipairs(modules) do + local output, gradInput = modula.output, modula.gradInput + local output_ = recurrentOutputs[i] + local gradInput_ = recurrentGradInputs[i] + assert(output_, "accGradParametersThroughTime should be preceded by updateOutput") + assert(gradInput_, "accGradParametersThroughTime should be preceded by updateGradInputThroughTime") + modula.output = output_ + modula.gradInput = gradInput_ + end + + -- backward propagate through this step + local scale = self.scales[step]/rho + local inputTable = {self.inputs[step], self.outputs[step-1], self.cells[step-1]} + local gradOutputTable = {self.gradOutputs[step], self.gradCells[step]} + self.recurrentModule:accGradParameters(inputTable, gradOutputTable, scale) + 
end + + self.gradParametersAccumulated = true + return gradInput +end + +function LSTM:accUpdateGradParametersThroughTime(lr) + local rho = math.min(self.rho, self.step-1) + local stop = self.step - rho + for step=self.step-1,math.max(stop,1),-1 do + -- set the output/gradOutput states of current Module + local modules = self.recurrentModule:listModules() + local recurrentOutputs = self.recurrentOutputs[step] + local recurrentGradInputs = self.recurrentGradInputs[step] + + for i,modula in ipairs(modules) do + local output, gradInput = modula.output, modula.gradInput + local output_ = recurrentOutputs[i] + local gradInput_ = recurrentGradInputs[i] + assert(output_, "accGradParametersThroughTime should be preceded by updateOutput") + assert(gradInput_, "accGradParametersThroughTime should be preceded by updateGradInputThroughTime") + modula.output = output_ + modula.gradInput = gradInput_ + end + + -- backward propagate through this step + local scale = self.scales[step]/rho + local inputTable = {self.inputs[step], self.outputs[step-1], self.cells[step]} + local gradOutputTable = {self.gradOutputs[step], self.gradCells[step]} + self.recurrentModule:accUpdateGradParameters(inputTable, gradOutputTable, lr*scale) + end + + return gradInput +end + diff --git a/Recurrent.lua b/Recurrent.lua index 859ce19..495d3a2 100644 --- a/Recurrent.lua +++ b/Recurrent.lua @@ -14,10 +14,10 @@ -- output attribute to keep track of their internal state between -- forward and backward. ------------------------------------------------------------------------ -local Recurrent, parent = torch.class('nn.Recurrent', 'nn.Module') +local Recurrent, parent = torch.class('nn.Recurrent', 'nn.AbstractRecurrent') function Recurrent:__init(start, input, feedback, transfer, rho, merge) - parent.__init(self) + parent.__init(self, rho or 5) local ts = torch.type(start) if ts == 'torch.LongTensor' or ts == 'number' then @@ -29,7 +29,6 @@ function Recurrent:__init(start, input, feedback, transfer, rho, merge) self.feedbackModule = feedback self.transferModule = transfer or nn.Sigmoid() self.mergeModule = merge or nn.CAddTable() - self.rho = rho or 5 self.modules = {self.startModule, self.inputModule, self.feedbackModule, self.transferModule, self.mergeModule} @@ -38,22 +37,6 @@ function Recurrent:__init(start, input, feedback, transfer, rho, merge) self.initialOutputs = {} self.initialGradInputs = {} - self.recurrentOutputs = {} - self.recurrentGradInputs = {} - - self.fastBackward = true - self.copyInputs = true - - self.inputs = {} - self.outputs = {} - self.gradOutputs = {} - self.gradInputs = {} - self.scales = {} - - self.gradParametersAccumulated = false - self.step = 1 - - self:reset() end -- build module used for the first step (steps == 1) @@ -75,69 +58,6 @@ function Recurrent:buildRecurrentModule() self.recurrentModule:add(self.transferModule) end -local function recursiveResizeAs(t1,t2) - if torch.type(t2) == 'table' then - t1 = (torch.type(t1) == 'table') and t1 or {t1} - for key,_ in pairs(t2) do - t1[key], t2[key] = recursiveResizeAs(t1[key], t2[key]) - end - elseif torch.isTensor(t2) then - t1 = t1 or t2.new() - t1:resizeAs(t2) - else - error("expecting nested tensors or tables. Got ".. - torch.type(t1).." and "..torch.type(t2).." 
instead") - end - return t1, t2 -end - -local function recursiveSet(t1,t2) - if torch.type(t2) == 'table' then - t1 = (torch.type(t1) == 'table') and t1 or {t1} - for key,_ in pairs(t2) do - t1[key], t2[key] = recursiveSet(t1[key], t2[key]) - end - elseif torch.isTensor(t2) then - t1 = t1 or t2.new() - t1:set(t2) - else - error("expecting nested tensors or tables. Got ".. - torch.type(t1).." and "..torch.type(t2).." instead") - end - return t1, t2 -end - -local function recursiveCopy(t1,t2) - if torch.type(t2) == 'table' then - t1 = (torch.type(t1) == 'table') and t1 or {t1} - for key,_ in pairs(t2) do - t1[key], t2[key] = recursiveCopy(t1[key], t2[key]) - end - elseif torch.isTensor(t2) then - t1 = t1 or t2.new() - t1:resizeAs(t2):copy(t2) - else - error("expecting nested tensors or tables. Got ".. - torch.type(t1).." and "..torch.type(t2).." instead") - end - return t1, t2 -end - -local function recursiveAdd(t1, t2) - if torch.type(t2) == 'table' then - t1 = (torch.type(t1) == 'table') and t1 or {t1} - for key,_ in pairs(t2) do - t1[key], t2[key] = recursiveAdd(t1[key], t2[key]) - end - elseif torch.isTensor(t2) and torch.isTensor(t2) then - t1:add(t2) - else - error("expecting nested tensors or tables. Got ".. - torch.type(t1).." and "..torch.type(t2).." instead") - end - return t1, t2 -end - function Recurrent:updateOutput(input) -- output(t) = transfer(feedback(output_(t-1)) + input(input_(t))) local output @@ -145,7 +65,7 @@ function Recurrent:updateOutput(input) -- set/save the output states local modules = self.initialModule:listModules() for i,modula in ipairs(modules) do - local output_ = recursiveResizeAs(self.initialOutputs[i], modula.output) + local output_ = self.recursiveResizeAs(self.initialOutputs[i], modula.output) modula.output = output_ end output = self.initialModule:updateOutput(input) @@ -163,7 +83,7 @@ function Recurrent:updateOutput(input) self.recurrentOutputs[self.step] = recurrentOutputs end for i,modula in ipairs(modules) do - local output_ = recursiveResizeAs(recurrentOutputs[i], modula.output) + local output_ = self.recursiveResizeAs(recurrentOutputs[i], modula.output) modula.output = output_ end -- self.output is the previous output of this module @@ -179,18 +99,9 @@ function Recurrent:updateOutput(input) if self.train ~= false then local input_ = self.inputs[self.step] - if self.copyInputs then - input_ = recursiveCopy(input_, input) - else - input_:set(input) - end - end - - if self.train ~= false then - local input_ = self.inputs[self.step] self.inputs[self.step] = self.copyInputs - and recursiveCopy(input_, input) - or recursiveSet(input_, input) + and self.recursiveCopy(input_, input) + or self.recursiveSet(input_, input) end self.outputs[self.step] = output @@ -200,18 +111,6 @@ function Recurrent:updateOutput(input) return self.output end -function Recurrent:updateGradInput(input, gradOutput) - -- Back-Propagate Through Time (BPTT) happens in updateParameters() - -- for now we just keep a list of the gradOutputs - self.gradOutputs[self.step-1] = recursiveCopy(self.gradOutputs[self.step-1] , gradOutput) -end - -function Recurrent:accGradParameters(input, gradOutput, scale) - -- Back-Propagate Through Time (BPTT) happens in updateParameters() - -- for now we just keep a list of the scales - self.scales[self.step-1] = scale -end - -- not to be confused with the hit movie Back to the Future function Recurrent:backwardThroughTime() assert(self.step > 1, "expecting at least one updateOutput") @@ -219,7 +118,7 @@ function Recurrent:backwardThroughTime() local 
stop = self.step - rho if self.fastBackward then self.gradInputs = {} - local gradInput + local gradInput, gradPrevOutput for step=self.step-1,math.max(stop, 2),-1 do -- set the output/gradOutput states of current Module local modules = self.recurrentModule:listModules() @@ -235,19 +134,19 @@ function Recurrent:backwardThroughTime() local output_ = recurrentOutputs[i] assert(output_, "backwardThroughTime should be preceded by updateOutput") modula.output = output_ - modula.gradInput = recursiveResizeAs(recurrentGradInputs[i], gradInput) + modula.gradInput = self.recursiveResizeAs(recurrentGradInputs[i], gradInput) end -- backward propagate through this step local input = self.inputs[step] local output = self.outputs[step-1] local gradOutput = self.gradOutputs[step] - if gradInput then - recursiveAdd(gradOutput, gradInput) + if gradPrevOutput then + self.recursiveAdd(gradOutput, gradPrevOutput) end local scale = self.scales[step] - gradInput = self.recurrentModule:backward({input, output}, gradOutput, scale/rho)[2] + gradInput, gradPrevOutput = unpack(self.recurrentModule:backward({input, output}, gradOutput, scale/rho)) table.insert(self.gradInputs, 1, gradInput) for i,modula in ipairs(modules) do @@ -260,14 +159,14 @@ function Recurrent:backwardThroughTime() local modules = self.initialModule:listModules() for i,modula in ipairs(modules) do modula.output = self.initialOutputs[i] - modula.gradInput = recursiveResizeAs(self.initialGradInputs[i], modula.gradInput) + modula.gradInput = self.recursiveResizeAs(self.initialGradInputs[i], modula.gradInput) end -- backward propagate through first step local input = self.inputs[1] local gradOutput = self.gradOutputs[1] - if gradInput then - recursiveAdd(gradOutput, gradInput) + if gradPrevOutput then + self.recursiveAdd(gradOutput, gradPrevOutput) end local scale = self.scales[1] gradInput = self.initialModule:backward(input, gradOutput, scale/rho) @@ -296,16 +195,10 @@ function Recurrent:backwardThroughTime() end end -function Recurrent:backwardUpdateThroughTime(learningRate) - local gradInput = self:updateGradInputThroughTime() - self:accUpdateGradParametersThroughTime(learningRate) - return gradInput -end - function Recurrent:updateGradInputThroughTime() assert(self.step > 1, "expecting at least one updateOutput") self.gradInputs = {} - local gradInput + local gradInput, gradPrevOutput local rho = math.min(self.rho, self.step-1) local stop = self.step - rho for step=self.step-1,math.max(stop,2),-1 do @@ -322,21 +215,23 @@ function Recurrent:updateGradInputThroughTime() local output_ = recurrentOutputs[i] assert(output_, "updateGradInputThroughTime should be preceded by updateOutput") modula.output = output_ - modula.gradInput = recursiveResizeAs(recurrentGradInputs[i], gradInput) + modula.gradInput = self.recursiveResizeAs(recurrentGradInputs[i], gradInput) end -- backward propagate through this step local input = self.inputs[step] local output = self.outputs[step-1] local gradOutput = self.gradOutputs[step] - if gradInput then - gradOutput:add(gradInput) + if gradPrevOutput then + self.recursiveAdd(gradOutput, gradPrevOutput) end - gradInput = self.recurrentModule:updateGradInput({input, output}, gradOutput)[2] + + gradInput, gradPrevOutput = unpack(self.recurrentModule:updateGradInput({input, output}, gradOutput)) + table.insert(self.gradInputs, 1, gradInput) + for i,modula in ipairs(modules) do recurrentGradInputs[i] = modula.gradInput end - table.insert(self.gradInputs, 1, gradInput) end if stop <= 1 then @@ -344,21 +239,21 @@ function 
Recurrent:updateGradInputThroughTime() local modules = self.initialModule:listModules() for i,modula in ipairs(modules) do modula.output = self.initialOutputs[i] - modula.gradInput = recursiveResizeAs(self.initialGradInputs[i], modula.gradInput) + modula.gradInput = self.recursiveResizeAs(self.initialGradInputs[i], modula.gradInput) end -- backward propagate through first step local input = self.inputs[1] local gradOutput = self.gradOutputs[1] - if gradInput then - gradOutput:add(gradInput) + if gradPrevOutput then + self.recursiveAdd(gradOutput, gradPrevOutput) end gradInput = self.initialModule:updateGradInput(input, gradOutput) + table.insert(self.gradInputs, 1, gradInput) for i,modula in ipairs(modules) do self.initialGradInputs[i] = modula.gradInput end - table.insert(self.gradInputs, 1, gradInput) end return gradInput @@ -390,7 +285,6 @@ function Recurrent:accGradParametersThroughTime() local scale = self.scales[step] self.recurrentModule:accGradParameters({input, output}, gradOutput, scale/rho) - end if stop <= 1 then @@ -475,154 +369,36 @@ function Recurrent:accUpdateGradParametersThroughTime(lr) return gradInput end -function Recurrent:updateParameters(learningRate) - if self.gradParametersAccumulated then - for i=1,#self.modules do - self.modules[i]:updateParameters(learningRate) - end - else - self:backwardUpdateThroughTime(learningRate) - end -end - --- goes hand in hand with the next method : forget() -function Recurrent:recycle() - -- +1 is to skip initialModule - if self.step > self.rho + 1 then - assert(self.recurrentOutputs[self.step] == nil) - assert(self.recurrentOutputs[self.step-self.rho] ~= nil) - self.recurrentOutputs[self.step] = self.recurrentOutputs[self.step-self.rho] - self.recurrentGradInputs[self.step] = self.recurrentGradInputs[self.step-self.rho] - self.recurrentOutputs[self.step-self.rho] = nil - self.recurrentGradInputs[self.step-self.rho] = nil - -- need to keep rho+1 of these - self.outputs[self.step] = self.outputs[self.step-self.rho-1] - self.outputs[self.step-self.rho-1] = nil - end - if self.step > self.rho then - assert(self.inputs[self.step] == nil) - assert(self.inputs[self.step-self.rho] ~= nil) - self.inputs[self.step] = self.inputs[self.step-self.rho] - self.gradOutputs[self.step] = self.gradOutputs[self.step-self.rho] - self.inputs[self.step-self.rho] = nil - self.gradOutputs[self.step-self.rho] = nil - self.scales[self.step-self.rho] = nil - end -end - function Recurrent:forget() - - if self.train ~= false then - -- bring all states back to the start of the sequence buffers - local lastStep = self.step - 1 - - if lastStep > self.rho + 1 then - local i = 2 - for step = lastStep-self.rho+1,lastStep do - self.recurrentOutputs[i] = self.recurrentOutputs[step] - self.recurrentGradInputs[i] = self.recurrentGradInputs[step] - self.recurrentOutputs[step] = nil - self.recurrentGradInputs[step] = nil - -- we keep rho+1 of these : outputs[k]=outputs[k+rho+1] - self.outputs[i-1] = self.outputs[step] - self.outputs[step] = nil - i = i + 1 - end - - end - - if lastStep > self.rho then - local i = 1 - for step = lastStep-self.rho+1,lastStep do - self.inputs[i] = self.inputs[step] - self.gradOutputs[i] = self.gradOutputs[step] - self.inputs[step] = nil - self.gradOutputs[step] = nil - self.scales[step] = nil - i = i + 1 - end - - end - end - - -- forget the past inputs; restart from first step - self.step = 1 -end - -function Recurrent:size() - return #self.modules -end - -function Recurrent:get(index) - return self.modules[index] -end - -function 
Recurrent:zeroGradParameters() - for i=1,#self.modules do - self.modules[i]:zeroGradParameters() - end -end - -function Recurrent:training() - for i=1,#self.modules do - self.modules[i]:training() - end -end - -function Recurrent:evaluate() - for i=1,#self.modules do - self.modules[i]:evaluate() - end -end - -function Recurrent:share(mlp,...) - for i=1,#self.modules do - self.modules[i]:share(mlp.modules[i],...); - end -end - -function Recurrent:reset(stdv) - self:forget() - for i=1,#self.modules do - self.modules[i]:reset(stdv) - end -end - -function Recurrent:parameters() - local function tinsert(to, from) - if type(from) == 'table' then - for i=1,#from do - tinsert(to,from[i]) - end - else - table.insert(to,from) - end - end - local w = {} - local gw = {} - for i=1,#self.modules do - local mw,mgw = self.modules[i]:parameters() - if mw then - tinsert(w,mw) - tinsert(gw,mgw) - end - end - return w,gw + parent.forget(self, 1) end function Recurrent:__tostring__() local tab = ' ' local line = '\n' local next = ' -> ' - local str = 'nn.Recurrent' - str = str .. ' {' .. line .. tab .. '[input' - for i=1,#self.modules do + local str = torch.type(self) + str = str .. ' {' .. line .. tab .. '[{input(t), output(t-1)}' + for i=1,3 do str = str .. next .. '(' .. i .. ')' end - str = str .. next .. 'output]' - for i=1,#self.modules do - str = str .. line .. tab .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab) - end + str = str .. next .. 'output(t)]' + + local tab = ' ' + local line = '\n ' + local next = ' |`-> ' + local ext = ' | ' + local last = ' ... -> ' + str = str .. line .. '(1): ' .. ' {' .. line .. tab .. 'input(t)' + str = str .. line .. tab .. next .. '(t==0): ' .. tostring(self.startModule):gsub('\n', '\n' .. tab .. ext) + str = str .. line .. tab .. next .. '(t~=0): ' .. tostring(self.inputModule):gsub('\n', '\n' .. tab .. ext) + str = str .. line .. tab .. 'output(t-1)' + str = str .. line .. tab .. next .. tostring(self.feedbackModule):gsub('\n', line .. tab .. ext) + local tab = ' ' + local line = '\n' + local next = ' -> ' + str = str .. line .. tab .. '(' .. 2 .. '): ' .. tostring(self.mergeModule):gsub(line, line .. tab) + str = str .. line .. tab .. '(' .. 3 .. '): ' .. tostring(self.transferModule):gsub(line, line .. tab) str = str .. line .. '}' return str end diff --git a/RepeaterCriterion.lua b/RepeaterCriterion.lua index 44bf078..a6ad078 100644 --- a/RepeaterCriterion.lua +++ b/RepeaterCriterion.lua @@ -13,6 +13,7 @@ function RepeaterCriterion:__init(criterion) end function RepeaterCriterion:forward(inputTable, target) + self.output = 0 for i,input in ipairs(inputTable) do self.output = self.output + self.criterion:forward(input, target) end diff --git a/ZeroGrad.lua b/ZeroGrad.lua new file mode 100644 index 0000000..83f88d2 --- /dev/null +++ b/ZeroGrad.lua @@ -0,0 +1,28 @@ +local ZeroGrad, parent = torch.class("nn.ZeroGrad", "nn.Module") + +local function recursiveZero(t1,t2) + if torch.type(t2) == 'table' then + t1 = (torch.type(t1) == 'table') and t1 or {t1} + for key,_ in pairs(t2) do + t1[key], t2[key] = recursiveZero(t1[key], t2[key]) + end + elseif torch.isTensor(t2) then + t1 = t1 or t2.new() + t1:resizeAs(t2):zero() + else + error("expecting nested tensors or tables. Got ".. + torch.type(t1).." and "..torch.type(t2).." instead") + end + return t1, t2 +end + +function ZeroGrad:updateOutput(input) + self.output:set(input) + return self.output +end + +-- the gradient is simply zeroed. 
+-- useful when you don't want to backpropgate through certain paths. +function ZeroGrad:updateGradInput(input, gradOutput) + self.gradInput = recursiveZero(self.gradInput, gradOutput) +end @@ -75,13 +75,18 @@ torch.include('nnx', 'SoftMaxTree.lua') torch.include('nnx', 'MultiSoftMax.lua') torch.include('nnx', 'Balance.lua') torch.include('nnx', 'NarrowLookupTable.lua') -torch.include('nnx', 'Recurrent.lua') -torch.include('nnx', 'Repeater.lua') -torch.include('nnx', 'Sequencer.lua') torch.include('nnx', 'PushTable.lua') torch.include('nnx', 'PullTable.lua') +torch.include('nnx', 'ZeroGrad.lua') torch.include('nnx', 'Padding.lua') +-- recurrent +torch.include('nnx', 'AbstractRecurrent.lua') +torch.include('nnx', 'Recurrent.lua') +torch.include('nnx', 'LSTM.lua') +torch.include('nnx', 'Repeater.lua') +torch.include('nnx', 'Sequencer.lua') + -- criterions: torch.include('nnx', 'SuperCriterion.lua') torch.include('nnx', 'DistNLLCriterion.lua') diff --git a/test/test-all.lua b/test/test-all.lua index 5abd4e2..16f819f 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -342,6 +342,15 @@ function nnxtest.Recurrent() -- rho = nSteps local mlp = nn.Recurrent(outputSize, inputModule, feedbackModule, transferModule:clone(), nSteps) + -- test that the internal mlps are recursable : + local isRecursable = nn.AbstractRecurrent.isRecursable + mytester:assert(isRecursable(mlp.initialModule, torch.randn(inputSize)), "Recurrent isRecursable() initial error") + mytester:assert(isRecursable(mlp.recurrentModule, {torch.randn(inputSize), torch.randn(outputSize)}), "Recurrent isRecursable() recurrent error") + + -- test that the above test actually works + local euclidean = nn.Euclidean(inputSize, outputSize) + mytester:assert(not isRecursable(euclidean, torch.randn(batchSize, inputSize)), "AbstractRecurrent.isRecursable error") + local gradOutputs, outputs = {}, {} -- inputs = {inputN, {inputN-1, {inputN-2, ...}}}}} local inputs @@ -355,8 +364,10 @@ function nnxtest.Recurrent() mlp:zeroGradParameters() local mlp7 = mlp:clone() mlp7.rho = nSteps - 1 + local inputSequence = {} for step=1,nSteps do local input = torch.randn(batchSize, inputSize) + inputSequence[step] = input local gradOutput if step ~= nSteps then -- for the sake of keeping this unit test simple, @@ -389,7 +400,18 @@ function nnxtest.Recurrent() local mlp5 = mlp:clone() -- backward propagate through time (BPTT) - local gradInput = mlp:backwardThroughTime() + local gradInput = mlp:backwardThroughTime():clone() + mlp:forget() -- test ability to forget + mlp:zeroGradParameters() + local foutputs = {} + for step=1,nSteps do + foutputs[step] = mlp:forward(inputSequence[step]) + mytester:assertTensorEq(foutputs[step], outputs[step], 0.00001, "Recurrent forget output error "..step) + mlp:backward(input, gradOutputs[step]) + end + local fgradInput = mlp:backwardThroughTime():clone() + mytester:assertTensorEq(gradInput, fgradInput, 0.00001, "Recurrent forget gradInput error") + mlp4.fastBackward = false local gradInput4 = mlp4:backwardThroughTime() mytester:assertTensorEq(gradInput, gradInput4, 0.000001, 'error slow vs fast backwardThroughTime') @@ -564,6 +586,114 @@ function nnxtest.Recurrent_TestTable() mlp:backwardThroughTime(learningRate) end +function nnxtest.LSTM() + local batchSize = math.random(1,2) + local inputSize = math.random(3,4) + local outputSize = math.random(5,6) + local nStep = 3 + local input = {} + local gradOutput = {} + for step=1,nStep do + input[step] = torch.randn(batchSize, inputSize) + if step == nStep then + -- for the 
sake of keeping this unit test simple, + gradOutput[step] = torch.randn(batchSize, outputSize) + else + -- only the last step will get a gradient from the output + gradOutput[step] = torch.zeros(batchSize, outputSize) + end + end + local lstm = nn.LSTM(inputSize, outputSize) + + local isRecursable = nn.AbstractRecurrent.isRecursable + local inputTable = {torch.randn(batchSize, inputSize), torch.randn(batchSize, outputSize), torch.randn(batchSize, outputSize)} + mytester:assert(isRecursable(lstm.recurrentModule, inputTable), "LSTM isRecursable() error") + + -- we will use this to build an LSTM step by step (with shared params) + local lstmStep = lstm.recurrentModule:clone() + + -- forward/backward through LSTM + local output = {} + lstm:zeroGradParameters() + for step=1,nStep do + output[step] = lstm:forward(input[step]) + assert(torch.isTensor(input[step])) + lstm:backward(input[step], gradOutput[step], 1) + end + local gradInput = lstm:backwardThroughTime() + + local mlp2 -- this one will simulate rho = nSteps + local inputs + for step=1,nStep do + -- iteratively build an LSTM out of non-recurrent components + local lstm = lstmStep:clone() + lstm:share(lstmStep, 'weight', 'gradWeight', 'bias', 'gradBias') + if step == 1 then + mlp2 = lstm + else + local rnn = nn.Sequential() + local para = nn.ParallelTable() + para:add(nn.Identity()):add(mlp2) + rnn:add(para) + rnn:add(nn.FlattenTable()) + rnn:add(lstm) + mlp2 = rnn + end + + -- prepare inputs for mlp2 + if inputs then + inputs = {input[step], inputs} + else + inputs = {input[step], torch.zeros(batchSize, outputSize), torch.zeros(batchSize, outputSize)} + end + end + mlp2:add(nn.SelectTable(1)) --just output the output (not cell) + local output2 = mlp2:forward(inputs) + + mlp2:zeroGradParameters() + local gradInput2 = mlp2:backward(inputs, gradOutput[nStep], 1/nStep) + mytester:assertTensorEq(gradInput2[2][2][1], gradInput, 0.00001, "LSTM gradInput error") + mytester:assertTensorEq(output[nStep], output2, 0.00001, "LSTM output error") + + local params, gradParams = lstm:parameters() + local params2, gradParams2 = lstmStep:parameters() + mytester:assert(#params == #params2, "LSTM parameters error "..#params.." ~= "..#params2) + for i, gradParam in ipairs(gradParams) do + local gradParam2 = gradParams2[i] + mytester:assertTensorEq(gradParam, gradParam2, 0.000001, + "LSTM gradParam "..i.." error "..tostring(gradParam).." "..tostring(gradParam2)) + end + + gradParams = lstm.recursiveCopy(nil, gradParams) + gradInput = gradInput:clone() + mytester:assert(lstm.zeroTensor:sum() == 0, "zeroTensor error") + lstm:forget() + output = lstm.recursiveCopy(nil, output) + local output3 = {} + lstm:zeroGradParameters() + for step=1,nStep do + output3[step] = lstm:forward(input[step]) + lstm:backward(input[step], gradOutput[step], 1) + end + local gradInput3 = lstm:updateGradInputThroughTime() + lstm:accGradParametersThroughTime() + + mytester:assert(#output == #output3, "LSTM output size error") + for i,output in ipairs(output) do + mytester:assertTensorEq(output, output3[i], 0.00001, "LSTM forget (updateOutput) error "..i) + end + + mytester:assertTensorEq(gradInput, gradInput3, 0.00001, "LSTM updateGradInputThroughTime error") + --if true then return end + local params3, gradParams3 = lstm:parameters() + mytester:assert(#params == #params3, "LSTM parameters error "..#params.." 
~= "..#params3) + for i, gradParam in ipairs(gradParams) do + local gradParam3 = gradParams3[i] + mytester:assertTensorEq(gradParam, gradParam3, 0.000001, + "LSTM gradParam "..i.." error "..tostring(gradParam).." "..tostring(gradParam3)) + end +end + function nnxtest.SpatialNormalization_Gaussian2D() local inputSize = math.random(11,20) local kersize = 9 |