github.com/clementfarabet/lua---nnx.git
author     nicholas-leonard <nick@nikopia.org>    2015-03-16 00:09:23 +0300
committer  nicholas-leonard <nick@nikopia.org>    2015-03-16 00:09:23 +0300
commit     a54ac28e4261ee62645e4c80ea4db21b02923b6f (patch)
tree       1ff396c9e99c94f5926854eb6ce2a4a533cbf1cd
parent     bbf319f81ea9808b7ab051dce120383e8c56e7e0 (diff)
Recurrent and LSTM fixes
-rw-r--r--  AbstractRecurrent.lua  |  8
-rw-r--r--  LSTM.lua               | 90
-rw-r--r--  Recurrent.lua          | 36
-rw-r--r--  test/test-all.lua      | 70
4 files changed, 143 insertions, 61 deletions
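
This commit moves the BPTT horizon rho and the per-time-step state buffers (recurrentOutputs, recurrentGradInputs) up into nn.AbstractRecurrent, adds a rho argument and a gradCells table to nn.LSTM, fixes several mistakes in the LSTM gate/cell assembly, tracks the gradient w.r.t. the previous output (gradPrevOutput) separately from gradInput during backpropagation through time, and adds an nnxtest.LSTM test that checks the recurrent module against a manually unrolled clone. A minimal usage sketch of the changed nn.LSTM constructor, modelled on that test; the sizes and step count are illustrative assumptions, and note that the patched constructor calls require 'dp':

require 'nnx'

-- rho = maximum number of time-steps to backpropagate through (BPTT horizon)
local lstm = nn.LSTM(3, 5, 10)                 -- inputSize, outputSize, rho
lstm:zeroGradParameters()
for step = 1, 4 do
   local input  = torch.randn(2, 3)            -- batchSize x inputSize
   local output = lstm:forward(input)          -- batchSize x outputSize
   lstm:backward(input, torch.randn(2, 5), 1)  -- store gradOutput for this step
end
lstm:backwardThroughTime()                     -- truncated BPTT over the stored steps
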
diff --git a/AbstractRecurrent.lua b/AbstractRecurrent.lua
index b241aad..d11eb32 100644
--- a/AbstractRecurrent.lua
+++ b/AbstractRecurrent.lua
@@ -1,8 +1,10 @@
local AbstractRecurrent, parent = torch.class('nn.AbstractRecurrent', 'nn.Container')
-function AbstractRecurrent:__init(start, input, feedback, transfer, rho, merge)
+function AbstractRecurrent:__init(rho)
parent.__init(self)
+ self.rho = rho --the maximum number of time steps to BPTT
+
self.fastBackward = true
self.copyInputs = true
@@ -14,6 +16,10 @@ function AbstractRecurrent:__init(start, input, feedback, transfer, rho, merge)
self.gradParametersAccumulated = false
self.step = 1
+ -- stores internal states of Modules at different time-steps
+ self.recurrentOutputs = {}
+ self.recurrentGradInputs = {}
+
self:reset()
end
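
With this change the base class owns rho and the per-step buffers that concrete recurrent modules fill in. A sketch of the constructor contract a subclass now follows; the subclass name is hypothetical and shown only for illustration:

require 'nnx'

-- hypothetical subclass, illustrating the changed AbstractRecurrent constructor
local MyRecurrent, parent = torch.class('nn.MyRecurrent', 'nn.AbstractRecurrent')

function MyRecurrent:__init(rho)
   -- rho: maximum number of time-steps to backpropagate through
   parent.__init(self, rho or 5)
   -- inherited per-time-step buffers:
   --   self.recurrentOutputs[step]    : module outputs saved at each forward step
   --   self.recurrentGradInputs[step] : module gradInputs saved during BPTT
end
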
diff --git a/LSTM.lua b/LSTM.lua
index a138289..fdca379 100644
--- a/LSTM.lua
+++ b/LSTM.lua
@@ -9,8 +9,9 @@
------------------------------------------------------------------------
local LSTM, parent = torch.class('nn.LSTM', 'nn.AbstractRecurrent')
-function LSTM:__init(inputSize, outputSize)
- parent.__init(self)
+function LSTM:__init(inputSize, outputSize, rho)
+ require 'dp'
+ parent.__init(self, rho or 999999999999)
self.inputSize = inputSize
self.outputSize = outputSize
-- build the model
@@ -21,18 +22,19 @@ function LSTM:__init(inputSize, outputSize)
self.startOutput = torch.Tensor()
self.startCell = torch.Tensor()
self.cells = {}
+ self.gradCells = {}
end
-------------------------- factory methods -----------------------------
function LSTM:buildGate()
- -- Note : inputGate:forward expects an input table : {input, output, cell}
+ -- Note : gate expects an input table : {input, output(t-1), cell(t-1)}
local gate = nn.Sequential()
local input2gate = nn.Linear(self.inputSize, self.outputSize)
- local cell2gate = nn.CMul(self.outputSize) -- diagonal cell to gate weight matrix
local output2gate = nn.Linear(self.outputSize, self.outputSize)
- output2gate:noBias() --TODO
+ local cell2gate = nn.CMul(self.outputSize) -- diagonal cell to gate weight matrix
+ --output2gate:noBias() --TODO
local para = nn.ParallelTable()
- para:add(input2gate):add(cell2gate):add(output2gate)
+ para:add(input2gate):add(output2gate):add(cell2gate)
gate:add(para)
gate:add(nn.CAddTable())
gate:add(nn.Sigmoid())
@@ -40,11 +42,13 @@ function LSTM:buildGate()
end
function LSTM:buildInputGate()
- return self:buildGate()
+ local gate = self:buildGate()
+ return gate
end
function LSTM:buildForgetGate()
- return self:buildGate()
+ local gate = self:buildGate()
+ return gate
end
function LSTM:buildHidden()
@@ -52,15 +56,14 @@ function LSTM:buildHidden()
local input2hidden = nn.Linear(self.inputSize, self.outputSize)
local output2hidden = nn.Linear(self.outputSize, self.outputSize)
local para = nn.ParallelTable()
- output2hidden:noBias()
+ --output2hidden:noBias()
para:add(input2hidden):add(output2hidden)
- -- input is {input, output, cell}, but we only need {input, output}
+ -- input is {input, output(t-1), cell(t-1)}, but we only need {input, output(t-1)}
local concat = nn.ConcatTable()
- concat:add(nn.SelectTable(1):add(nn.SelectTable(2))
+ concat:add(nn.SelectTable(1)):add(nn.SelectTable(2))
hidden:add(concat)
hidden:add(para)
hidden:add(nn.CAddTable())
- hidden:add(nn.Tanh())
return hidden
end
@@ -72,7 +75,7 @@ function LSTM:buildCell()
-- forget = forgetGate{input, output(t-1), cell(t-1)} * cell(t-1)
local forget = nn.Sequential()
local concat = nn.ConcatTable()
- concat:add(self.forgetGate):add(self.SelectTable(3))
+ concat:add(self.forgetGate):add(nn.SelectTable(3))
forget:add(concat)
forget:add(nn.CMulTable())
-- input = inputGate{input, output(t-1), cell(t-1)} * hiddenLayer{input, output(t-1), cell(t-1)}
@@ -87,14 +90,17 @@ function LSTM:buildCell()
concat3:add(forget):add(input)
cell:add(concat3)
cell:add(nn.CAddTable())
+ return cell
end
function LSTM:buildOutputGate()
- return self:buildGate()
+ local gate = self:buildGate()
+ return gate
end
-- cell(t) = cellLayer{input, output(t-1), cell(t-1)}
--- output = outputGate{input, output(t-1), cell(t)}*tanh(cell(t))
+-- output(t) = outputGate{input, output(t-1), cell(t)}*tanh(cell(t))
+-- output of Model is table : {output(t), cell(t)}
function LSTM:buildModel()
-- build components
self.cellLayer = self:buildCell()
@@ -102,24 +108,24 @@ function LSTM:buildModel()
-- assemble
local concat = nn.ConcatTable()
local concat2 = nn.ConcatTable()
- concat2:add(nn.SelectTable(1):add(nn.SelectTable(2))
+ concat2:add(nn.SelectTable(1)):add(nn.SelectTable(2))
concat:add(concat2):add(self.cellLayer)
local model = nn.Sequential()
- model:add(concat2)
- -- output of concat2 is {{input, output}, cell(t)},
+ model:add(concat)
+ -- output of concat is {{input, output}, cell(t)},
-- so flatten to {input, output, cell(t)}
model:add(nn.FlattenTable())
local cellAct = nn.Sequential()
- cellAct:add(nn.Select(3))
+ cellAct:add(nn.SelectTable(3))
cellAct:add(nn.Tanh())
local concat3 = nn.ConcatTable()
concat3:add(self.outputGate):add(cellAct)
- -- we want the model to output : {output(t), cell(t)}
- local concat4 = nn.ConcatTable()
local output = nn.Sequential()
output:add(concat3)
output:add(nn.CMulTable())
- concat4:add(output):add(nn.Identity())
+ -- we want the model to output : {output(t), cell(t)}
+ local concat4 = nn.ConcatTable()
+ concat4:add(output):add(nn.SelectTable(3))
model:add(concat4)
return model
end
@@ -131,9 +137,9 @@ function LSTM:updateOutput(input)
prevOutput = self.startOutput
prevCell = self.startCell
if input:dim() == 2 then
- self.startOutput:resize(input:size(1), self.outputSize)
+ self.startOutput:resize(input:size(1), self.outputSize):zero()
else
- self.startOutput:resize(self.outputSize)
+ self.startOutput:resize(self.outputSize):zero()
end
self.startCell:set(self.startOutput)
else
@@ -154,11 +160,12 @@ function LSTM:updateOutput(input)
self.recurrentOutputs[self.step] = recurrentOutputs
end
for i,modula in ipairs(modules) do
- local output_ = recursiveResizeAs(recurrentOutputs[i], modula.output)
+ local output_ = self.recursiveResizeAs(recurrentOutputs[i], modula.output)
modula.output = output_
end
-- the actual forward propagation
- output, cell = self.recurrentModule:updateOutput{input, prevOutput, prevCell}
+ output = self.recurrentModule:updateOutput{input, prevOutput, prevCell}
+ output, cell = unpack(output)
for i,modula in ipairs(modules) do
recurrentOutputs[i] = modula.output
@@ -192,7 +199,8 @@ function LSTM:backwardThroughTime()
local rho = math.min(self.rho, self.step-1)
local stop = self.step - rho
if self.fastBackward then
- local gradInput, gradCell
+ local gradInput, gradPrevOutput
+ local gradCell = self.startCell
for step=self.step-1,math.max(stop,1),-1 do
-- set the output/gradOutput states of current Module
local modules = self.recurrentModule:listModules()
@@ -208,20 +216,21 @@ function LSTM:backwardThroughTime()
local output_ = recurrentOutputs[i]
assert(output_, "backwardThroughTime should be preceded by updateOutput")
modula.output = output_
- modula.gradInput = recursiveCopy(recurrentGradInputs[i], gradInput)
+ modula.gradInput = self.recursiveCopy(recurrentGradInputs[i], gradInput)
end
-- backward propagate through this step
local gradOutput = self.gradOutputs[step]
- if gradInput then
- recursiveAdd(gradOutput, gradInput)
+ if gradPrevOutput then
+ assert(gradPrevOutput:sum() ~= 0)
+ self.recursiveAdd(gradOutput, gradPrevOutput)
end
+ self.gradCells[self.step] = gradCell
local scale = self.scales[step]/rho
- local inputTable = {input, self.cells[step-1], self.outputs[step-1]}
- local gradOutputTable = {gradOutput, self.gradCells[step]}
- local gradInputTable = self.recurrentModule:backward(inputTable, gradOutputTable, scale)
- gradInput, gradCell = unpack(gradInputTable)
+ local inputTable = {self.inputs[step], self.cells[step-1] or self.startOutput, self.outputs[step-1] or self.startCell}
+ local gradInputTable = self.recurrentModule:backward(inputTable, {gradOutput, gradCell}, scale)
+ gradInput, gradPrevOutput, gradCell = unpack(gradInputTable)
table.insert(self.gradInputs, 1, gradInput)
for i,modula in ipairs(modules) do
@@ -239,7 +248,7 @@ end
function LSTM:updateGradInputThroughTime()
assert(self.step > 1, "expecting at least one updateOutput")
self.gradInputs = {}
- local gradInput
+ local gradInput, gradPrevOutput
local gradCell = self.startCell
local rho = math.min(self.rho, self.step-1)
local stop = self.step - rho
@@ -262,16 +271,15 @@ function LSTM:updateGradInputThroughTime()
-- backward propagate through this step
local gradOutput = self.gradOutputs[step]
- if gradInput then
- self.recursiveAdd(gradOutput, gradInput)
+ if gradPrevOutput then
+ self.recursiveAdd(gradOutput, gradPrevOutput)
end
self.gradCells[self.step] = gradCell
local scale = self.scales[step]/rho
- local inputTable = {self.inputs[step], self.cells[step-1], self.outputs[step-1]}
- local gradOutputTable = {gradOutput, gradCell}
- local gradInputTable = self.recurrentModule:backward(inputTable, gradOutputTable, scale)
- gradInput, gradCell = unpack(gradInputTable)
+ local inputTable = {self.inputs[step], self.cells[step-1] or self.startOutput, self.outputs[step-1] or self.startCell}
+ local gradInputTable = self.recurrentModule:backward(inputTable, {gradOutput, gradCell}, scale)
+ gradInput, gradPrevOutput, gradCell = unpack(gradInputTable)
table.insert(self.gradInputs, 1, gradInput)
for i,modula in ipairs(modules) do
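
After these fixes the module built by LSTM:buildModel() maps the table {input, output(t-1), cell(t-1)} to the table {output(t), cell(t)}, which is why updateOutput now unpacks a two-element result and why backward through it returns one gradient per input-table entry. A sketch of driving that inner module for a single step; the sizes are illustrative assumptions:

require 'nnx'

local inputSize, outputSize = 3, 5
local lstm = nn.LSTM(inputSize, outputSize)    -- rho defaults to a very large value
local step = lstm.recurrentModule

-- forward: {input, output(t-1), cell(t-1)} -> {output(t), cell(t)}
local inputTable = {torch.randn(inputSize), torch.zeros(outputSize), torch.zeros(outputSize)}
local output_t, cell_t = unpack(step:forward(inputTable))

-- backward: gradOutput is {gradOutput(t), gradCell(t)}; the result has one entry
-- per input, i.e. {gradInput, gradPrevOutput, gradPrevCell}
local gradTable = step:backward(inputTable, {torch.randn(outputSize), torch.zeros(outputSize)})
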
diff --git a/Recurrent.lua b/Recurrent.lua
index f908586..e80318a 100644
--- a/Recurrent.lua
+++ b/Recurrent.lua
@@ -17,7 +17,7 @@
local Recurrent, parent = torch.class('nn.Recurrent', 'nn.AbstractRecurrent')
function Recurrent:__init(start, input, feedback, transfer, rho, merge)
- parent.__init(self)
+ parent.__init(self, rho or 5)
local ts = torch.type(start)
if ts == 'torch.LongTensor' or ts == 'number' then
@@ -29,7 +29,6 @@ function Recurrent:__init(start, input, feedback, transfer, rho, merge)
self.feedbackModule = feedback
self.transferModule = transfer or nn.Sigmoid()
self.mergeModule = merge or nn.CAddTable()
- self.rho = rho or 5
self.modules = {self.startModule, self.inputModule, self.feedbackModule, self.transferModule, self.mergeModule}
@@ -38,8 +37,6 @@ function Recurrent:__init(start, input, feedback, transfer, rho, merge)
self.initialOutputs = {}
self.initialGradInputs = {}
- self.recurrentOutputs = {}
- self.recurrentGradInputs = {}
end
-- build module used for the first step (steps == 1)
@@ -121,7 +118,7 @@ function Recurrent:backwardThroughTime()
local stop = self.step - rho
if self.fastBackward then
self.gradInputs = {}
- local gradInput
+ local gradInput, gradPrevOutput
for step=self.step-1,math.max(stop, 2),-1 do
-- set the output/gradOutput states of current Module
local modules = self.recurrentModule:listModules()
@@ -144,12 +141,12 @@ function Recurrent:backwardThroughTime()
local input = self.inputs[step]
local output = self.outputs[step-1]
local gradOutput = self.gradOutputs[step]
- if gradInput then
- self.recursiveAdd(gradOutput, gradInput)
+ if gradPrevOutput then
+ self.recursiveAdd(gradOutput, gradPrevOutput)
end
local scale = self.scales[step]
- gradInput = self.recurrentModule:backward({input, output}, gradOutput, scale/rho)[2]
+ gradInput, gradPrevOutput = unpack(self.recurrentModule:backward({input, output}, gradOutput, scale/rho))
table.insert(self.gradInputs, 1, gradInput)
for i,modula in ipairs(modules) do
@@ -168,8 +165,8 @@ function Recurrent:backwardThroughTime()
-- backward propagate through first step
local input = self.inputs[1]
local gradOutput = self.gradOutputs[1]
- if gradInput then
- self.recursiveAdd(gradOutput, gradInput)
+ if gradPrevOutput then
+ self.recursiveAdd(gradOutput, gradPrevOutput)
end
local scale = self.scales[1]
gradInput = self.initialModule:backward(input, gradOutput, scale/rho)
@@ -201,7 +198,7 @@ end
function Recurrent:updateGradInputThroughTime()
assert(self.step > 1, "expecting at least one updateOutput")
self.gradInputs = {}
- local gradInput
+ local gradInput, gradPrevOutput
local rho = math.min(self.rho, self.step-1)
local stop = self.step - rho
for step=self.step-1,math.max(stop,2),-1 do
@@ -225,14 +222,16 @@ function Recurrent:updateGradInputThroughTime()
local input = self.inputs[step]
local output = self.outputs[step-1]
local gradOutput = self.gradOutputs[step]
- if gradInput then
- self.recursiveAdd(gradOutput, gradInput)
+ if gradPrevOutput then
+ self.recursiveAdd(gradOutput, gradPrevOutput)
end
- gradInput = self.recurrentModule:updateGradInput({input, output}, gradOutput)[2]
+
+ gradInput, gradPrevOutput = unpack(self.recurrentModule:updateGradInput({input, output}, gradOutput))
+ table.insert(self.gradInputs, 1, gradInput)
+
for i,modula in ipairs(modules) do
recurrentGradInputs[i] = modula.gradInput
end
- table.insert(self.gradInputs, 1, gradInput)
end
if stop <= 1 then
@@ -246,15 +245,15 @@ function Recurrent:updateGradInputThroughTime()
-- backward propagate through first step
local input = self.inputs[1]
local gradOutput = self.gradOutputs[1]
- if gradInput then
- self.recursiveAdd(gradOutput, gradInput)
+ if gradPrevOutput then
+ self.recursiveAdd(gradOutput, gradPrevOutput)
end
gradInput = self.initialModule:updateGradInput(input, gradOutput)
+ table.insert(self.gradInputs, 1, gradInput)
for i,modula in ipairs(modules) do
self.initialGradInputs[i] = modula.gradInput
end
- table.insert(self.gradInputs, 1, gradInput)
end
return gradInput
@@ -286,7 +285,6 @@ function Recurrent:accGradParametersThroughTime()
local scale = self.scales[step]
self.recurrentModule:accGradParameters({input, output}, gradOutput, scale/rho)
-
end
if stop <= 1 then
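
In Recurrent.lua the rho argument is now forwarded to the parent constructor, the per-step buffers come from AbstractRecurrent, and the BPTT routines accumulate the gradient w.r.t. the previous output (gradPrevOutput) into the stored gradOutput of the earlier step instead of reusing gradInput. A minimal construction sketch with the updated constructor; the layer sizes and module choices are illustrative assumptions:

require 'nnx'

local inputSize, hiddenSize, rho = 4, 7, 5
local r = nn.Recurrent(
   hiddenSize,                          -- start: size of the initial hidden state
   nn.Linear(inputSize, hiddenSize),    -- input module
   nn.Linear(hiddenSize, hiddenSize),   -- feedback module
   nn.Sigmoid(),                        -- transfer module
   rho                                  -- now passed on to AbstractRecurrent.__init
)
for step = 1, 3 do
   local input = torch.randn(inputSize)
   r:forward(input)
   r:backward(input, torch.randn(hiddenSize), 1)
end
r:backwardThroughTime()
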
diff --git a/test/test-all.lua b/test/test-all.lua
index 5abd4e2..fd72b6e 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -564,6 +564,76 @@ function nnxtest.Recurrent_TestTable()
mlp:backwardThroughTime(learningRate)
end
+function nnxtest.LSTM()
+ local batchSize = math.random(1,2)
+ local inputSize = math.random(3,4)
+ local outputSize = math.random(5,6)
+ local nStep = 3
+ local input = {}
+ local gradOutput = {}
+ for step=1,nStep do
+ input[step] = torch.randn(batchSize, inputSize)
+ if step == nStep then
+ -- for the sake of keeping this unit test simple,
+ gradOutput[step] = torch.randn(batchSize, outputSize)
+ else
+ -- only the last step will get a gradient from the output
+ gradOutput[step] = torch.zeros(batchSize, outputSize)
+ end
+ end
+ local lstm = nn.LSTM(inputSize, outputSize)
+
+ -- we will use this to build an LSTM step by step (with shared params)
+ local lstmStep = lstm.recurrentModule:clone()
+
+ -- forward/backward through LSTM
+ local output = {}
+ lstm:zeroGradParameters()
+ for step=1,nStep do
+ output[step] = lstm:forward(input[step])
+ assert(torch.isTensor(input[step]))
+ lstm:backward(input[step], gradOutput[step], 1)
+ end
+ local gradInput = lstm:backwardThroughTime()
+
+ local mlp2 -- this one will simulate rho = nSteps
+ local inputs
+ for step=1,nStep do
+ -- iteratively build an LSTM out of non-recurrent components
+ local lstm = lstmStep:clone()
+ lstm:share(lstmStep)
+ lstm:share(lstmStep, 'weight', 'gradWeight', 'bias', 'gradBias')
+ if step == 1 then
+ mlp2 = lstm
+ else
+ local rnn = nn.Sequential()
+ local para = nn.ParallelTable()
+ para:add(nn.Identity()):add(mlp2)
+ rnn:add(para)
+ rnn:add(nn.FlattenTable())
+ rnn:add(lstm)
+ mlp2 = rnn
+ end
+
+
+ -- prepare inputs for mlp2
+ if inputs then
+ inputs = {input[step], inputs}
+ else
+ inputs = {input[step], torch.zeros(batchSize, outputSize), torch.zeros(batchSize, outputSize)}
+ end
+ end
+ mlp2:add(nn.SelectTable(1)) --just output the output (not cell)
+
+ local output2 = mlp2:forward(inputs)
+
+
+ mlp2:zeroGradParameters()
+ local gradInput2 = mlp2:backward(inputs, gradOutput[nStep], 1/nStep)
+ mytester:assertTensorEq(gradInput2[2][2][1], gradInput, 0.00001, "LSTM gradInput error")
+ mytester:assertTensorEq(output[nStep], output2, 0.00001, "LSTM output error")
+end
+
function nnxtest.SpatialNormalization_Gaussian2D()
local inputSize = math.random(11,20)
local kersize = 9