author    | Nicholas Léonard <nick@nikopia.org> | 2015-03-14 05:41:48 +0300
committer | Nicholas Léonard <nick@nikopia.org> | 2015-03-14 05:41:48 +0300
commit    | e0c591c87f574e3ce4148048e65d82678859dc93
tree      | 96aef713d7e2ae3360edd2d3680503a3f3130cee
parent    | 5360b32a9a24625838ec0788de6790611cd2231e
parent    | c638f3aa83172ae757ecacffae09c08cb4a0a55f
Merge pull request #31 from nicholas-leonard/Repeater
Repeater, RepeaterCriterion and Sequencer
-rw-r--r-- | Recurrent.lua         |  13
-rw-r--r-- | Repeater.lua          |  62
-rw-r--r-- | RepeaterCriterion.lua |  47
-rw-r--r-- | Sequencer.lua         | 184
-rw-r--r-- | init.lua              |   3
5 files changed, 306 insertions, 3 deletions
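
The Recurrent.lua hunks below make back-propagation through time (BPTT) record one gradInput per time step in self.gradInputs (entries are inserted at the front, so index 1 holds the earliest step); the new Repeater and Sequencer wrappers read this table to return per-step gradients. A minimal sketch of that access pattern, mirroring what Repeater does internally (the sizes and the nn.Add start module are illustrative assumptions, not part of this commit):

    require 'nnx'

    -- constructor arguments follow the Recurrent:__init signature in the first hunk below
    local r = nn.Recurrent(
       nn.Add(20),         -- start: assumed here to transform the input layer's output at step 1
       nn.Linear(10, 20),  -- input: applied to the input at every step
       nn.Linear(20, 20),  -- feedback: applied to the previous output
       nn.Sigmoid(),       -- transfer
       5                   -- rho: maximum number of steps to back-propagate through
    )

    -- forward a sequence, one call per time step
    local inputs = {torch.randn(10), torch.randn(10), torch.randn(10)}
    for step = 1, 3 do
       r:updateOutput(inputs[step])
    end

    -- declare a gradOutput for each step, then run BPTT once
    for step = 1, 3 do
       r.step = step + 1
       r:updateGradInput(inputs[step], torch.randn(20))
    end
    r:updateGradInputThroughTime()

    print(#r.gradInputs)  -- 3: one gradInput per back-propagated step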
diff --git a/Recurrent.lua b/Recurrent.lua
index fc961bd..859ce19 100644
--- a/Recurrent.lua
+++ b/Recurrent.lua
@@ -47,6 +47,7 @@ function Recurrent:__init(start, input, feedback, transfer, rho, merge)
    self.inputs = {}
    self.outputs = {}
    self.gradOutputs = {}
+   self.gradInputs = {}
    self.scales = {}
 
    self.gradParametersAccumulated = false
@@ -217,6 +218,7 @@ function Recurrent:backwardThroughTime()
    local rho = math.min(self.rho, self.step-1)
    local stop = self.step - rho
    if self.fastBackward then
+      self.gradInputs = {}
       local gradInput
       for step=self.step-1,math.max(stop, 2),-1 do
          -- set the output/gradOutput states of current Module
@@ -233,7 +235,7 @@ function Recurrent:backwardThroughTime()
             local output_ = recurrentOutputs[i]
             assert(output_, "backwardThroughTime should be preceded by updateOutput")
             modula.output = output_
-            modula.gradInput = recursiveCopy(recurrentGradInputs[i], gradInput)
+            modula.gradInput = recursiveResizeAs(recurrentGradInputs[i], gradInput)
          end
 
          -- backward propagate through this step
@@ -246,6 +248,7 @@ function Recurrent:backwardThroughTime()
          local scale = self.scales[step]
          gradInput = self.recurrentModule:backward({input, output}, gradOutput, scale/rho)[2]
+         table.insert(self.gradInputs, 1, gradInput)
 
          for i,modula in ipairs(modules) do
            recurrentGradInputs[i] = modula.gradInput
@@ -257,7 +260,7 @@ function Recurrent:backwardThroughTime()
       local modules = self.initialModule:listModules()
       for i,modula in ipairs(modules) do
          modula.output = self.initialOutputs[i]
-         modula.gradInput = recursiveCopy(self.initialGradInputs[i], modula.gradInput)
+         modula.gradInput = recursiveResizeAs(self.initialGradInputs[i], modula.gradInput)
       end
 
       -- backward propagate through first step
@@ -268,6 +271,7 @@ function Recurrent:backwardThroughTime()
       end
       local scale = self.scales[1]
       gradInput = self.initialModule:backward(input, gradOutput, scale/rho)
+      table.insert(self.gradInputs, 1, gradInput)
 
       for i,modula in ipairs(modules) do
          self.initialGradInputs[i] = modula.gradInput
@@ -300,6 +304,7 @@ end
 
 function Recurrent:updateGradInputThroughTime()
    assert(self.step > 1, "expecting at least one updateOutput")
+   self.gradInputs = {}
    local gradInput
    local rho = math.min(self.rho, self.step-1)
    local stop = self.step - rho
@@ -317,7 +322,7 @@ function Recurrent:updateGradInputThroughTime()
          local output_ = recurrentOutputs[i]
          assert(output_, "updateGradInputThroughTime should be preceded by updateOutput")
          modula.output = output_
-         modula.gradInput = recursiveCopy(recurrentGradInputs[i], gradInput)
+         modula.gradInput = recursiveResizeAs(recurrentGradInputs[i], gradInput)
       end
 
       -- backward propagate through this step
@@ -331,6 +336,7 @@ function Recurrent:updateGradInputThroughTime()
       for i,modula in ipairs(modules) do
          recurrentGradInputs[i] = modula.gradInput
       end
+      table.insert(self.gradInputs, 1, gradInput)
    end
 
    if stop <= 1 then
@@ -352,6 +358,7 @@ function Recurrent:updateGradInputThroughTime()
       for i,modula in ipairs(modules) do
         self.initialGradInputs[i] = modula.gradInput
       end
+      table.insert(self.gradInputs, 1, gradInput)
    end
 
    return gradInput
diff --git a/Repeater.lua b/Repeater.lua
new file mode 100644
index 0000000..f9c9f11
--- /dev/null
+++ b/Repeater.lua
@@ -0,0 +1,62 @@
+------------------------------------------------------------------------
+--[[ Repeater ]]--
+-- Encapsulates an AbstractRecurrent instance (rnn) which is repeatedly
+-- presented with the same input for nStep time steps.
+-- The output is a table of nStep outputs of the rnn.
+------------------------------------------------------------------------
+local Repeater, parent = torch.class("nn.Repeater", "nn.Container")
+
+function Repeater:__init(nStep, rnn)
+   parent.__init(self)
+   self.nStep = nStep
+   self.rnn = rnn
+   assert(rnn.backwardThroughTime, "expecting AbstractRecurrent instance for arg 2")
+   self.modules[1] = rnn
+   self.output = {}
+end
+
+function Repeater:updateOutput(input)
+   self.rnn:forget()
+   for step=1,self.nStep do
+      self.output[step] = self.rnn:updateOutput(input)
+   end
+   return self.output
+end
+
+function Repeater:updateGradInput(input, gradOutput)
+   assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
+   assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
+   assert(#gradOutput == self.nStep, "gradOutput should have nStep elements")
+   for step=1,self.nStep do
+      self.rnn.step = step + 1
+      self.rnn:updateGradInput(input, gradOutput[step])
+   end
+   -- back-propagate through time (BPTT)
+   self.rnn:updateGradInputThroughTime()
+   self.gradInput = self.rnn.gradInputs
+   return self.gradInput
+end
+
+function Repeater:accGradParameters(input, gradOutput, scale)
+   assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
+   assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
+   assert(#gradOutput == self.nStep, "gradOutput should have nStep elements")
+   for step=1,self.nStep do
+      self.rnn.step = step + 1
+      self.rnn:accGradParameters(input, gradOutput[step], scale)
+   end
+   -- back-propagate through time (BPTT)
+   self.rnn:accGradParametersThroughTime()
+end
+
+function Repeater:accUpdateGradParameters(input, gradOutput, lr)
+   assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
+   assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
+   assert(#gradOutput == self.nStep, "gradOutput should have nStep elements")
+   for step=1,self.nStep do
+      self.rnn.step = step + 1
+      self.rnn:accGradParameters(input, gradOutput[step], 1)
+   end
+   -- back-propagate through time (BPTT)
+   self.rnn:accUpdateGradParametersThroughTime(lr)
+end
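
A usage sketch for the new nn.Repeater: the same input is presented nStep times and the output is a table of nStep outputs. The nn.Recurrent arguments and sizes are illustrative assumptions, not part of this commit:

    require 'nnx'

    local rho = 5
    local r = nn.Recurrent(nn.Add(20), nn.Linear(10, 20), nn.Linear(20, 20), nn.Sigmoid(), rho)
    local rep = nn.Repeater(rho, r)

    local input = torch.randn(10)
    local outputs = rep:forward(input)  -- table of rho tensors of size 20

    -- gradOutput must be a table with one tensor per step
    local gradOutputs = {}
    for step = 1, rho do
       gradOutputs[step] = torch.randn(20)
    end
    local gradInputs = rep:backward(input, gradOutputs)  -- table of per-step gradInputs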
diff --git a/RepeaterCriterion.lua b/RepeaterCriterion.lua
new file mode 100644
index 0000000..7a64ad3
--- /dev/null
+++ b/RepeaterCriterion.lua
@@ -0,0 +1,47 @@
+------------------------------------------------------------------------
+--[[ RepeaterCriterion ]]--
+-- Applies a criterion to each of the inputs in a Table using the
+-- same target (the target is repeated).
+-- Useful for nn.Repeater and nn.Sequencer.
+------------------------------------------------------------------------
+local RepeaterCriterion, parent = torch.class("nn.RepeaterCriterion", "nn.Criterion")
+
+function RepeaterCriterion:__init(criterion)
+   parent.__init(self)
+   self.criterion = criterion
+   self.gradInput = {}
+end
+
+function RepeaterCriterion:forward(inputTable, target)
+   for i,input in ipairs(inputTable) do
+      self.output = self.output + self.criterion:forward(input, target)
+   end
+end
+
+function RepeaterCriterion:backward(inputTable, target)
+   for i,input in ipairs(inputTable) do
+      local gradInput = self.criterion:backward(input, target)
+      self.gradInput[i] = self.gradInput[i] or gradInput.new()
+      self.gradInput[i]:resizeAs(gradInput):copy(gradInput)
+   end
+   return self.gradInput
+end
+
+local function recursiveType(param, type_str)
+   if torch.type(param) == 'table' then
+      for i = 1, #param do
+         param[i] = recursiveType(param[i], type_str)
+      end
+   else
+      if torch.typename(param) and
+         torch.typename(param):find('torch%..+Tensor') then
+         param = param:type(type_str)
+      end
+   end
+   return param
+end
+
+function RepeaterCriterion:type(type)
+   self.gradInput = recursiveType(self.gradInput, type)
+   return self.criterion:type(type)
+end
diff --git a/Sequencer.lua b/Sequencer.lua
new file mode 100644
index 0000000..f8c258f
--- /dev/null
+++ b/Sequencer.lua
@@ -0,0 +1,184 @@
+------------------------------------------------------------------------
+--[[ Sequencer ]]--
+-- Encapsulates a Module.
+-- Input is a sequence (a table) of tensors.
+-- Output is a sequence (a table) of tensors of the same length.
+-- Applies the module to each element in the sequence.
+-- Handles both recurrent modules and non-recurrent modules.
+-- The sequences in a batch must have the same size.
+-- But the sequence length of each batch can vary.
+------------------------------------------------------------------------
+local Sequencer, parent = torch.class("nn.Sequencer", "nn.Container")
+
+function Sequencer:__init(module)
+   parent.__init(self)
+   self.module = module
+   self.isRecurrent = module.backwardThroughTime ~= nil
+   self.modules[1] = module
+   self.sequenceOutputs, self.sequenceGradInputs = {}, {}
+   self.output = {}
+   self.step = 1
+end
+
+local function recursiveResizeAs(t1,t2)
+   if torch.type(t2) == 'table' then
+      t1 = (torch.type(t1) == 'table') and t1 or {t1}
+      for key,_ in pairs(t2) do
+         t1[key], t2[key] = recursiveResizeAs(t1[key], t2[key])
+      end
+   elseif torch.isTensor(t2) then
+      t1 = t1 or t2.new()
+      t1:resizeAs(t2)
+   else
+      error("expecting nested tensors or tables. Got "..
+            torch.type(t1).." and "..torch.type(t2).." instead")
+   end
+   return t1, t2
+end
+
+
+function Sequencer:updateOutput(inputTable)
+   assert(torch.type(inputTable) == 'table', "expecting input table")
+   self.output = {}
+   if self.isRecurrent then
+      self.module:forget()
+      for step, input in ipairs(inputTable) do
+         self.output[step] = self.module:updateOutput(input)
+      end
+   else
+      for step, input in ipairs(inputTable) do
+         -- set output states for this step
+         local modules = self.module:listModules()
+         local sequenceOutputs = self.sequenceOutputs[step]
+         if not sequenceOutputs then
+            sequenceOutputs = {}
+            self.sequenceOutputs[step] = sequenceOutputs
+         end
+         for i,modula in ipairs(modules) do
+            local output_ = recursiveResizeAs(sequenceOutputs[i], modula.output)
+            modula.output = output_
+         end
+
+         -- forward propagate this step
+         self.output[step] = self.module:updateOutput(input)
+
+         -- save output state of this step
+         for i,modula in ipairs(modules) do
+            sequenceOutputs[i] = modula.output
+         end
+      end
+   end
+   return self.output
+end
+
+function Sequencer:updateGradInput(inputTable, gradOutputTable)
+   self.gradInput = {}
+   if self.isRecurrent then
+      assert(torch.type(gradOutputTable) == 'table', "expecting gradOutput table")
+      assert(#gradOutputTable == #inputTable, "gradOutput should have as many elements as input")
+      for step, input in ipairs(inputTable) do
+         self.module.step = step + 1
+         self.module:updateGradInput(input, gradOutputTable[step])
+      end
+      -- back-propagate through time (BPTT)
+      self.module:updateGradInputThroughTime()
+      assert(self.module.gradInputs, "recurrent module did not fill gradInputs")
+      for step=1,#inputTable do
+         self.gradInput[step] = self.module.gradInputs[step]
+      end
+      assert(#self.gradInput == #inputTable, "missing gradInputs")
+   else
+      for step, input in ipairs(inputTable) do
+         -- set the output/gradOutput states for this step
+         local modules = self.module:listModules()
+         local sequenceOutputs = self.sequenceOutputs[step]
+         local sequenceGradInputs = self.sequenceGradInputs[step]
+         if not sequenceGradInputs then
+            sequenceGradInputs = {}
+            self.sequenceGradInputs[step] = sequenceGradInputs
+         end
+         for i,modula in ipairs(modules) do
+            local output, gradInput = modula.output, modula.gradInput
+            local output_ = sequenceOutputs[i]
+            assert(output_, "updateGradInputThroughTime should be preceded by updateOutput")
+            modula.output = output_
+            modula.gradInput = recursiveResizeAs(sequenceGradInputs[i], gradInput)
+         end
+
+         -- backward propagate this step
+         self.gradInput[step] = self.module:updateGradInput(input, gradOutputTable[step])
+
+         -- save the output/gradOutput states of this step
+         for i,modula in ipairs(modules) do
+            sequenceGradInputs[i] = modula.gradInput
+         end
+      end
+   end
+   return self.gradInput
+end
+
+function Sequencer:accGradParameters(inputTable, gradOutputTable, scale)
+   if self.isRecurrent then
+      assert(torch.type(gradOutputTable) == 'table', "expecting gradOutput table")
+      assert(#gradOutputTable == #inputTable, "gradOutput should have as many elements as input")
+      for step, input in ipairs(inputTable) do
+         self.module.step = step + 1
+         self.module:accGradParameters(input, gradOutputTable[step], scale)
+      end
+      -- back-propagate through time (BPTT)
+      self.module:accGradParametersThroughTime()
+   else
+      for step, input in ipairs(inputTable) do
+         -- set the output/gradOutput states for this step
+         local modules = self.module:listModules()
+         local sequenceOutputs = self.sequenceOutputs[step]
+         local sequenceGradInputs = self.sequenceGradInputs[step]
+         if not sequenceGradInputs then
+            sequenceGradInputs = {}
+            self.sequenceGradInputs[step] = sequenceGradInputs
+         end
+         for i,modula in ipairs(modules) do
+            local output, gradInput = modula.output, modula.gradInput
+            local output_ = sequenceOutputs[i]
+            modula.output = output_
+            modula.gradInput = recursiveResizeAs(sequenceGradInputs[i], gradInput)
+         end
+
+         -- accumulate parameters for this step
+         self.module:accGradParameters(input, gradOutputTable[step], scale)
+      end
+   end
+end
+
+function Sequencer:accUpdateGradParameters(inputTable, gradOutputTable, lr)
+   if self.isRecurrent then
+      assert(torch.type(gradOutputTable) == 'table', "expecting gradOutput table")
+      assert(#gradOutputTable == #inputTable, "gradOutput should have as many elements as input")
+      for step, input in ipairs(inputTable) do
+         self.module.step = step + 1
+         self.module:accGradParameters(input, gradOutputTable[step], 1)
+      end
+      -- back-propagate through time (BPTT)
+      self.module:accUpdateGradParametersThroughTime(lr)
+   else
+      for step, input in ipairs(inputTable) do
+         -- set the output/gradOutput states for this step
+         local modules = self.module:listModules()
+         local sequenceOutputs = self.sequenceOutputs[step]
+         local sequenceGradInputs = self.sequenceGradInputs[step]
+         if not sequenceGradInputs then
+            sequenceGradInputs = {}
+            self.sequenceGradInputs[step] = sequenceGradInputs
+         end
+         for i,modula in ipairs(modules) do
+            local output, gradInput = modula.output, modula.gradInput
+            local output_ = sequenceOutputs[i]
+            modula.output = output_
+            modula.gradInput = recursiveResizeAs(sequenceGradInputs[i], gradInput)
+         end
+
+         -- accumulate parameters for this step
+         self.module:accUpdateGradParameters(input, gradOutputTable[step], lr)
+      end
+   end
+end
diff --git a/init.lua b/init.lua
--- a/init.lua
+++ b/init.lua
@@ -76,6 +76,8 @@ torch.include('nnx', 'MultiSoftMax.lua')
 torch.include('nnx', 'Balance.lua')
 torch.include('nnx', 'NarrowLookupTable.lua')
 torch.include('nnx', 'Recurrent.lua')
+torch.include('nnx', 'Repeater.lua')
+torch.include('nnx', 'Sequencer.lua')
 torch.include('nnx', 'PushTable.lua')
 torch.include('nnx', 'PullTable.lua')
 torch.include('nnx', 'Padding.lua')
@@ -85,6 +87,7 @@ torch.include('nnx', 'SuperCriterion.lua')
 torch.include('nnx', 'DistNLLCriterion.lua')
 torch.include('nnx', 'DistMarginCriterion.lua')
 torch.include('nnx', 'TreeNLLCriterion.lua')
+torch.include('nnx', 'RepeaterCriterion.lua')
 
 -- datasets:
 torch.include('nnx', 'DataSet.lua')
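
To close, a sketch of nn.Sequencer in both of its modes: wrapping a recurrent module, where backward dispatches to BPTT, and wrapping an ordinary module, where per-step output/gradInput buffers are swapped in around each call. Sizes and module choices are illustrative assumptions:

    require 'nnx'

    -- recurrent mode: the wrapped module exposes backwardThroughTime
    local r = nn.Recurrent(nn.Add(20), nn.Linear(10, 20), nn.Linear(20, 20), nn.Sigmoid(), 5)
    local seq = nn.Sequencer(r)

    local inputs = {torch.randn(10), torch.randn(10), torch.randn(10)}
    local outputs = seq:forward(inputs)  -- table of 3 tensors of size 20

    local gradOutputs = {}
    for step = 1, #inputs do
       gradOutputs[step] = torch.randn(20)
    end
    local gradInputs = seq:backward(inputs, gradOutputs)  -- table of 3 gradInputs

    -- non-recurrent mode: the same module is applied independently at each step
    local seq2 = nn.Sequencer(nn.Linear(10, 20))
    local outputs2 = seq2:forward(inputs)
    local gradInputs2 = seq2:backward(inputs, gradOutputs)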