author     nicholas-leonard <nick@nikopia.org>       2015-03-16 00:09:23 +0300
committer  nicholas-leonard <nick@nikopia.org>       2015-03-16 00:09:23 +0300
commit     a54ac28e4261ee62645e4c80ea4db21b02923b6f (patch)
tree       1ff396c9e99c94f5926854eb6ce2a4a533cbf1cd
parent     bbf319f81ea9808b7ab051dce120383e8c56e7e0 (diff)
Recurrent and LSTM fixes
-rw-r--r--   AbstractRecurrent.lua     8
-rw-r--r--   LSTM.lua                 90
-rw-r--r--   Recurrent.lua            36
-rw-r--r--   test/test-all.lua        70
4 files changed, 143 insertions, 61 deletions
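
In short, rho (the maximum number of BPTT time steps) now lives in nn.AbstractRecurrent and is forwarded by the nn.LSTM and nn.Recurrent constructors, the LSTM's recurrentModule now returns the table {output(t), cell(t)}, and a new nnxtest.LSTM unit test checks the LSTM against a manually unrolled network. The sketch below mirrors the usage exercised by that test; the require lines and the sizes are illustrative assumptions, not part of the commit:

    require 'torch'
    require 'nnx'   -- assumed: provides the nn.LSTM / nn.Recurrent patched by this commit

    -- rho (maximum number of BPTT steps) is now passed up to nn.AbstractRecurrent
    local nStep, batchSize = 5, 4
    local lstm = nn.LSTM(10, 20, nStep)   -- inputSize, outputSize, rho (defaults to a very large value)

    lstm:zeroGradParameters()
    for step = 1, nStep do
       local input = torch.randn(batchSize, 10)
       local output = lstm:forward(input)   -- returns output(t); cell(t) is kept internally
       -- as in the new unit test, only the last step carries a non-zero gradient
       local gradOutput = (step == nStep) and torch.randn(batchSize, 20) or torch.zeros(batchSize, 20)
       lstm:backward(input, gradOutput)
    end
    lstm:backwardThroughTime()              -- back-propagation through time over at most rho steps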
diff --git a/AbstractRecurrent.lua b/AbstractRecurrent.lua
index b241aad..d11eb32 100644
--- a/AbstractRecurrent.lua
+++ b/AbstractRecurrent.lua
@@ -1,8 +1,10 @@
local AbstractRecurrent, parent = torch.class('nn.AbstractRecurrent', 'nn.Container')

-function AbstractRecurrent:__init(start, input, feedback, transfer, rho, merge)
+function AbstractRecurrent:__init(rho)
   parent.__init(self)
+   self.rho = rho --the maximum number of time steps to BPTT
+
   self.fastBackward = true
   self.copyInputs = true
@@ -14,6 +16,10 @@ function AbstractRecurrent:__init(start, input, feedback, transfer, rho, merge)
   self.gradParametersAccumulated = false
   self.step = 1
+   -- stores internal states of Modules at different time-steps
+   self.recurrentOutputs = {}
+   self.recurrentGradInputs = {}
+
   self:reset()
end

diff --git a/LSTM.lua b/LSTM.lua
--- a/LSTM.lua
+++ b/LSTM.lua
@@ -9,8 +9,9 @@
------------------------------------------------------------------------
local LSTM, parent = torch.class('nn.LSTM', 'nn.AbstractRecurrent')

-function LSTM:__init(inputSize, outputSize)
-   parent.__init(self)
+function LSTM:__init(inputSize, outputSize, rho)
+   require 'dp'
+   parent.__init(self, rho or 999999999999)
   self.inputSize = inputSize
   self.outputSize = outputSize
   -- build the model
@@ -21,18 +22,19 @@ function LSTM:__init(inputSize, outputSize)
   self.startOutput = torch.Tensor()
   self.startCell = torch.Tensor()
   self.cells = {}
+   self.gradCells = {}
end

-------------------------- factory methods -----------------------------
function LSTM:buildGate()
-   -- Note : inputGate:forward expects an input table : {input, output, cell}
+   -- Note : gate expects an input table : {input, output(t-1), cell(t-1)}
   local gate = nn.Sequential()
   local input2gate = nn.Linear(self.inputSize, self.outputSize)
-   local cell2gate = nn.CMul(self.outputSize) -- diagonal cell to gate weight matrix
   local output2gate = nn.Linear(self.outputSize, self.outputSize)
-   output2gate:noBias() --TODO
+   local cell2gate = nn.CMul(self.outputSize) -- diagonal cell to gate weight matrix
+   --output2gate:noBias() --TODO
   local para = nn.ParallelTable()
-   para:add(input2gate):add(cell2gate):add(output2gate)
+   para:add(input2gate):add(output2gate):add(cell2gate)
   gate:add(para)
   gate:add(nn.CAddTable())
   gate:add(nn.Sigmoid())
@@ -40,11 +42,13 @@ function LSTM:buildGate()
end

function LSTM:buildInputGate()
-   return self:buildGate()
+   local gate = self:buildGate()
+   return gate
end

function LSTM:buildForgetGate()
-   return self:buildGate()
+   local gate = self:buildGate()
+   return gate
end

function LSTM:buildHidden()
@@ -52,15 +56,14 @@ function LSTM:buildHidden()
   local input2hidden = nn.Linear(self.inputSize, self.outputSize)
   local output2hidden = nn.Linear(self.outputSize, self.outputSize)
   local para = nn.ParallelTable()
-   output2hidden:noBias()
+   --output2hidden:noBias()
   para:add(input2hidden):add(output2hidden)
-   -- input is {input, output, cell}, but we only need {input, output}
+   -- input is {input, output(t-1), cell(t-1)}, but we only need {input, output(t-1)}
   local concat = nn.ConcatTable()
-   concat:add(nn.SelectTable(1):add(nn.SelectTable(2))
+   concat:add(nn.SelectTable(1)):add(nn.SelectTable(2))
   hidden:add(concat)
   hidden:add(para)
   hidden:add(nn.CAddTable())
-   hidden:add(nn.Tanh())
   return hidden
end
@@ -72,7 +75,7 @@ function LSTM:buildCell()
   -- forget = forgetGate{input, output(t-1), cell(t-1)} * cell(t-1)
   local forget = nn.Sequential()
   local concat = nn.ConcatTable()
-   concat:add(self.forgetGate):add(self.SelectTable(3))
+   concat:add(self.forgetGate):add(nn.SelectTable(3))
   forget:add(concat)
   forget:add(nn.CMulTable())
   -- input = inputGate{input, output(t-1), cell(t-1)} * hiddenLayer{input, output(t-1), cell(t-1)}
@@ -87,14 +90,17 @@ function LSTM:buildCell()
   concat3:add(forget):add(input)
   cell:add(concat3)
   cell:add(nn.CAddTable())
+   return cell
end

function LSTM:buildOutputGate()
-   return self:buildGate()
+   local gate = self:buildGate()
+   return gate
end

-- cell(t) = cellLayer{input, output(t-1), cell(t-1)}
--- output = outputGate{input, output(t-1), cell(t)}*tanh(cell(t))
+-- output(t) = outputGate{input, output(t-1), cell(t)}*tanh(cell(t))
+-- output of Model is table : {output(t), cell(t)}
function LSTM:buildModel()
   -- build components
   self.cellLayer = self:buildCell()
@@ -102,24 +108,24 @@ function LSTM:buildModel()
   -- assemble
   local concat = nn.ConcatTable()
   local concat2 = nn.ConcatTable()
-   concat2:add(nn.SelectTable(1):add(nn.SelectTable(2))
+   concat2:add(nn.SelectTable(1)):add(nn.SelectTable(2))
   concat:add(concat2):add(self.cellLayer)
   local model = nn.Sequential()
-   model:add(concat2)
-   -- output of concat2 is {{input, output}, cell(t)},
+   model:add(concat)
+   -- output of concat is {{input, output}, cell(t)},
   -- so flatten to {input, output, cell(t)}
   model:add(nn.FlattenTable())
   local cellAct = nn.Sequential()
-   cellAct:add(nn.Select(3))
+   cellAct:add(nn.SelectTable(3))
   cellAct:add(nn.Tanh())
   local concat3 = nn.ConcatTable()
   concat3:add(self.outputGate):add(cellAct)
-   -- we want the model to output : {output(t), cell(t)}
-   local concat4 = nn.ConcatTable()
   local output = nn.Sequential()
   output:add(concat3)
   output:add(nn.CMulTable())
-   concat4:add(output):add(nn.Identity())
+   -- we want the model to output : {output(t), cell(t)}
+   local concat4 = nn.ConcatTable()
+   concat4:add(output):add(nn.SelectTable(3))
   model:add(concat4)
   return model
end
@@ -131,9 +137,9 @@ function LSTM:updateOutput(input)
      prevOutput = self.startOutput
      prevCell = self.startCell
      if input:dim() == 2 then
-         self.startOutput:resize(input:size(1), self.outputSize)
+         self.startOutput:resize(input:size(1), self.outputSize):zero()
      else
-         self.startOutput:resize(self.outputSize)
+         self.startOutput:resize(self.outputSize):zero()
      end
      self.startCell:set(self.startOutput)
   else
@@ -154,11 +160,12 @@ function LSTM:updateOutput(input)
      self.recurrentOutputs[self.step] = recurrentOutputs
   end
   for i,modula in ipairs(modules) do
-      local output_ = recursiveResizeAs(recurrentOutputs[i], modula.output)
+      local output_ = self.recursiveResizeAs(recurrentOutputs[i], modula.output)
      modula.output = output_
   end
   -- the actual forward propagation
-   output, cell = self.recurrentModule:updateOutput{input, prevOutput, prevCell}
+   output = self.recurrentModule:updateOutput{input, prevOutput, prevCell}
+   output, cell = unpack(output)

   for i,modula in ipairs(modules) do
      recurrentOutputs[i] = modula.output
@@ -192,7 +199,8 @@ function LSTM:backwardThroughTime()
   local rho = math.min(self.rho, self.step-1)
   local stop = self.step - rho
   if self.fastBackward then
-      local gradInput, gradCell
+      local gradInput, gradPrevOutput
+      local gradCell = self.startCell
      for step=self.step-1,math.max(stop,1),-1 do
         -- set the output/gradOutput states of current Module
         local modules = self.recurrentModule:listModules()
@@ -208,20 +216,21 @@ function LSTM:backwardThroughTime()
            local output_ = recurrentOutputs[i]
            assert(output_, "backwardThroughTime should be preceded by updateOutput")
            modula.output = output_
-            modula.gradInput = recursiveCopy(recurrentGradInputs[i], gradInput)
+            modula.gradInput = self.recursiveCopy(recurrentGradInputs[i], gradInput)
         end

         -- backward propagate through this step
         local gradOutput = self.gradOutputs[step]
-         if gradInput then
-            recursiveAdd(gradOutput, gradInput)
+         if gradPrevOutput then
+            assert(gradPrevOutput:sum() ~= 0)
+            self.recursiveAdd(gradOutput, gradPrevOutput)
         end
+         self.gradCells[self.step] = gradCell
         local scale = self.scales[step]/rho
-         local inputTable = {input, self.cells[step-1], self.outputs[step-1]}
-         local gradOutputTable = {gradOutput, self.gradCells[step]}
-         local gradInputTable = self.recurrentModule:backward(inputTable, gradOutputTable, scale)
-         gradInput, gradCell = unpack(gradInputTable)
+         local inputTable = {self.inputs[step], self.cells[step-1] or self.startOutput, self.outputs[step-1] or self.startCell}
+         local gradInputTable = self.recurrentModule:backward(inputTable, {gradOutput, gradCell}, scale)
+         gradInput, gradPrevOutput, gradCell = unpack(gradInputTable)
         table.insert(self.gradInputs, 1, gradInput)

         for i,modula in ipairs(modules) do
@@ -239,7 +248,7 @@ end
function LSTM:updateGradInputThroughTime()
   assert(self.step > 1, "expecting at least one updateOutput")
   self.gradInputs = {}
-   local gradInput
+   local gradInput, gradPrevOutput
   local gradCell = self.startCell
   local rho = math.min(self.rho, self.step-1)
   local stop = self.step - rho
@@ -262,16 +271,15 @@ function LSTM:updateGradInputThroughTime()
      -- backward propagate through this step
      local gradOutput = self.gradOutputs[step]
-      if gradInput then
-         self.recursiveAdd(gradOutput, gradInput)
+      if gradPrevOutput then
+         self.recursiveAdd(gradOutput, gradPrevOutput)
      end
      self.gradCells[self.step] = gradCell
      local scale = self.scales[step]/rho
-      local inputTable = {self.inputs[step], self.cells[step-1], self.outputs[step-1]}
-      local gradOutputTable = {gradOutput, gradCell}
-      local gradInputTable = self.recurrentModule:backward(inputTable, gradOutputTable, scale)
-      gradInput, gradCell = unpack(gradInputTable)
+      local inputTable = {self.inputs[step], self.cells[step-1] or self.startOutput, self.outputs[step-1] or self.startCell}
+      local gradInputTable = self.recurrentModule:backward(inputTable, {gradOutput, gradCell}, scale)
+      gradInput, gradPrevOutput, gradCell = unpack(gradInputTable)
      table.insert(self.gradInputs, 1, gradInput)

      for i,modula in ipairs(modules) do

diff --git a/Recurrent.lua b/Recurrent.lua
index f908586..e80318a 100644
--- a/Recurrent.lua
+++ b/Recurrent.lua
@@ -17,7 +17,7 @@
local Recurrent, parent = torch.class('nn.Recurrent', 'nn.AbstractRecurrent')

function Recurrent:__init(start, input, feedback, transfer, rho, merge)
-   parent.__init(self)
+   parent.__init(self, rho or 5)

   local ts = torch.type(start)
   if ts == 'torch.LongTensor' or ts == 'number' then
@@ -29,7 +29,6 @@ function Recurrent:__init(start, input, feedback, transfer, rho, merge)
   self.feedbackModule = feedback
   self.transferModule = transfer or nn.Sigmoid()
   self.mergeModule = merge or nn.CAddTable()
-   self.rho = rho or 5

   self.modules = {self.startModule, self.inputModule, self.feedbackModule, self.transferModule, self.mergeModule}
@@ -38,8 +37,6 @@ function Recurrent:__init(start, input, feedback, transfer, rho, merge)
   self.initialOutputs = {}
   self.initialGradInputs = {}
-   self.recurrentOutputs = {}
-   self.recurrentGradInputs = {}
end

-- build module used for the first step (steps == 1)
@@ -121,7 +118,7 @@ function Recurrent:backwardThroughTime()
   local stop = self.step - rho
   if self.fastBackward then
      self.gradInputs = {}
-      local gradInput
+      local gradInput, gradPrevOutput
      for step=self.step-1,math.max(stop, 2),-1 do
         -- set the output/gradOutput states of current Module
         local modules = self.recurrentModule:listModules()
@@ -144,12 +141,12 @@ function Recurrent:backwardThroughTime()
         local input = self.inputs[step]
         local output = self.outputs[step-1]
         local gradOutput = self.gradOutputs[step]
-         if gradInput then
-            self.recursiveAdd(gradOutput, gradInput)
+         if gradPrevOutput then
+            self.recursiveAdd(gradOutput, gradPrevOutput)
         end
         local scale = self.scales[step]
-         gradInput = self.recurrentModule:backward({input, output}, gradOutput, scale/rho)[2]
+         gradInput, gradPrevOutput = unpack(self.recurrentModule:backward({input, output}, gradOutput, scale/rho))
         table.insert(self.gradInputs, 1, gradInput)

         for i,modula in ipairs(modules) do
@@ -168,8 +165,8 @@ function Recurrent:backwardThroughTime()
      -- backward propagate through first step
      local input = self.inputs[1]
      local gradOutput = self.gradOutputs[1]
-      if gradInput then
-         self.recursiveAdd(gradOutput, gradInput)
+      if gradPrevOutput then
+         self.recursiveAdd(gradOutput, gradPrevOutput)
      end
      local scale = self.scales[1]
      gradInput = self.initialModule:backward(input, gradOutput, scale/rho)
@@ -201,7 +198,7 @@ end
function Recurrent:updateGradInputThroughTime()
   assert(self.step > 1, "expecting at least one updateOutput")
   self.gradInputs = {}
-   local gradInput
+   local gradInput, gradPrevOutput
   local rho = math.min(self.rho, self.step-1)
   local stop = self.step - rho
   for step=self.step-1,math.max(stop,2),-1 do
@@ -225,14 +222,16 @@ function Recurrent:updateGradInputThroughTime()
      local input = self.inputs[step]
      local output = self.outputs[step-1]
      local gradOutput = self.gradOutputs[step]
-      if gradInput then
-         self.recursiveAdd(gradOutput, gradInput)
+      if gradPrevOutput then
+         self.recursiveAdd(gradOutput, gradPrevOutput)
      end
-      gradInput = self.recurrentModule:updateGradInput({input, output}, gradOutput)[2]
+
+      gradInput, gradPrevOutput = unpack(self.recurrentModule:updateGradInput({input, output}, gradOutput))
+      table.insert(self.gradInputs, 1, gradInput)
+
      for i,modula in ipairs(modules) do
         recurrentGradInputs[i] = modula.gradInput
      end
-      table.insert(self.gradInputs, 1, gradInput)
   end

   if stop <= 1 then
@@ -246,15 +245,15 @@ function Recurrent:updateGradInputThroughTime()
      -- backward propagate through first step
      local input = self.inputs[1]
      local gradOutput = self.gradOutputs[1]
-      if gradInput then
-         self.recursiveAdd(gradOutput, gradInput)
+      if gradPrevOutput then
+         self.recursiveAdd(gradOutput, gradPrevOutput)
      end
      gradInput = self.initialModule:updateGradInput(input, gradOutput)
+      table.insert(self.gradInputs, 1, gradInput)

      for i,modula in ipairs(modules) do
         self.initialGradInputs[i] = modula.gradInput
      end
-      table.insert(self.gradInputs, 1, gradInput)
   end

   return gradInput
@@ -286,7 +285,6 @@ function Recurrent:accGradParametersThroughTime()
         local scale = self.scales[step]
         self.recurrentModule:accGradParameters({input, output}, gradOutput, scale/rho)
-
      end

      if stop <= 1 then

diff --git a/test/test-all.lua b/test/test-all.lua
index 5abd4e2..fd72b6e 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -564,6 +564,76 @@ function nnxtest.Recurrent_TestTable()
   mlp:backwardThroughTime(learningRate)
end

+function nnxtest.LSTM()
+   local batchSize = math.random(1,2)
+   local inputSize = math.random(3,4)
+   local outputSize = math.random(5,6)
+   local nStep = 3
+   local input = {}
+   local gradOutput = {}
+   for step=1,nStep do
+      input[step] = torch.randn(batchSize, inputSize)
+      if step == nStep then
+         -- for the sake of keeping this unit test simple,
+         gradOutput[step] = torch.randn(batchSize, outputSize)
+      else
+         -- only the last step will get a gradient from the output
+         gradOutput[step] = torch.zeros(batchSize, outputSize)
+      end
+   end
+   local lstm = nn.LSTM(inputSize, outputSize)
+
+   -- we will use this to build an LSTM step by step (with shared params)
+   local lstmStep = lstm.recurrentModule:clone()
+
+   -- forward/backward through LSTM
+   local output = {}
+   lstm:zeroGradParameters()
+   for step=1,nStep do
+      output[step] = lstm:forward(input[step])
+      assert(torch.isTensor(input[step]))
+      lstm:backward(input[step], gradOutput[step], 1)
+   end
+   local gradInput = lstm:backwardThroughTime()
+
+   local mlp2 -- this one will simulate rho = nSteps
+   local inputs
+   for step=1,nStep do
+      -- iteratively build an LSTM out of non-recurrent components
+      local lstm = lstmStep:clone()
+      lstm:share(lstmStep, 'weight', 'gradWeight', 'bias', 'gradBias')
+      if step == 1 then
+         mlp2 = lstm
+      else
+         local rnn = nn.Sequential()
+         local para = nn.ParallelTable()
+         para:add(nn.Identity()):add(mlp2)
+         rnn:add(para)
+         rnn:add(nn.FlattenTable())
+         rnn:add(lstm)
+         mlp2 = rnn
+      end
+
+      -- prepare inputs for mlp2
+      if inputs then
+         inputs = {input[step], inputs}
+      else
+         inputs = {input[step], torch.zeros(batchSize, outputSize), torch.zeros(batchSize, outputSize)}
+      end
+   end
+   mlp2:add(nn.SelectTable(1)) --just output the output (not cell)
+
+   local output2 = mlp2:forward(inputs)
+
+   mlp2:zeroGradParameters()
+   local gradInput2 = mlp2:backward(inputs, gradOutput[nStep], 1/nStep)
+   mytester:assertTensorEq(gradInput2[2][2][1], gradInput, 0.00001, "LSTM gradInput error")
+   mytester:assertTensorEq(output[nStep], output2, 0.00001, "LSTM output error")
+end
+
function nnxtest.SpatialNormalization_Gaussian2D()
   local inputSize = math.random(11,20)
   local kersize = 9
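
For reference, the gate assembled by LSTM:buildGate() after this change takes the table {input, output(t-1), cell(t-1)} (note the reordered ParallelTable) and computes sigmoid(W_x*input + W_h*output(t-1) + w_c .* cell(t-1)), with nn.CMul providing the diagonal cell-to-gate weights. A minimal standalone sketch of the same assembly, with made-up sizes rather than anything from the commit:

    require 'nn'

    local inputSize, outputSize = 4, 5            -- illustrative sizes only
    local gate = nn.Sequential()
    local para = nn.ParallelTable()
    para:add(nn.Linear(inputSize, outputSize))    -- input2gate
    para:add(nn.Linear(outputSize, outputSize))   -- output2gate
    para:add(nn.CMul(outputSize))                 -- cell2gate: diagonal (per-unit) weights
    gate:add(para)                                -- expects {input, output(t-1), cell(t-1)}
    gate:add(nn.CAddTable())                      -- sum the three projections
    gate:add(nn.Sigmoid())                        -- squash to (0,1)

    local x = torch.randn(inputSize)
    local h = torch.randn(outputSize)
    local c = torch.randn(outputSize)
    print(gate:forward{x, h, c})                  -- outputSize gate activations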