github.com/clementfarabet/lua---nnx.git
author     Nicholas Léonard <nick@nikopia.org>  2015-03-14 05:41:48 +0300
committer  Nicholas Léonard <nick@nikopia.org>  2015-03-14 05:41:48 +0300
commit     e0c591c87f574e3ce4148048e65d82678859dc93 (patch)
tree       96aef713d7e2ae3360edd2d3680503a3f3130cee
parent     5360b32a9a24625838ec0788de6790611cd2231e (diff)
parent     c638f3aa83172ae757ecacffae09c08cb4a0a55f (diff)
Merge pull request #31 from nicholas-leonard/Repeater
Repeater, RepeaterCriterion and Sequencer
-rw-r--r--   Recurrent.lua           13
-rw-r--r--   Repeater.lua            62
-rw-r--r--   RepeaterCriterion.lua   47
-rw-r--r--   Sequencer.lua          184
-rw-r--r--   init.lua                 3
5 files changed, 306 insertions(+), 3 deletions(-)
diff --git a/Recurrent.lua b/Recurrent.lua
index fc961bd..859ce19 100644
--- a/Recurrent.lua
+++ b/Recurrent.lua
@@ -47,6 +47,7 @@ function Recurrent:__init(start, input, feedback, transfer, rho, merge)
self.inputs = {}
self.outputs = {}
self.gradOutputs = {}
+ self.gradInputs = {}
self.scales = {}
self.gradParametersAccumulated = false
@@ -217,6 +218,7 @@ function Recurrent:backwardThroughTime()
local rho = math.min(self.rho, self.step-1)
local stop = self.step - rho
if self.fastBackward then
+ self.gradInputs = {}
local gradInput
for step=self.step-1,math.max(stop, 2),-1 do
-- set the output/gradOutput states of current Module
@@ -233,7 +235,7 @@ function Recurrent:backwardThroughTime()
local output_ = recurrentOutputs[i]
assert(output_, "backwardThroughTime should be preceded by updateOutput")
modula.output = output_
- modula.gradInput = recursiveCopy(recurrentGradInputs[i], gradInput)
+ modula.gradInput = recursiveResizeAs(recurrentGradInputs[i], gradInput)
end
-- backward propagate through this step
@@ -246,6 +248,7 @@ function Recurrent:backwardThroughTime()
local scale = self.scales[step]
gradInput = self.recurrentModule:backward({input, output}, gradOutput, scale/rho)[2]
+ table.insert(self.gradInputs, 1, gradInput)
for i,modula in ipairs(modules) do
recurrentGradInputs[i] = modula.gradInput
@@ -257,7 +260,7 @@ function Recurrent:backwardThroughTime()
local modules = self.initialModule:listModules()
for i,modula in ipairs(modules) do
modula.output = self.initialOutputs[i]
- modula.gradInput = recursiveCopy(self.initialGradInputs[i], modula.gradInput)
+ modula.gradInput = recursiveResizeAs(self.initialGradInputs[i], modula.gradInput)
end
-- backward propagate through first step
@@ -268,6 +271,7 @@ function Recurrent:backwardThroughTime()
end
local scale = self.scales[1]
gradInput = self.initialModule:backward(input, gradOutput, scale/rho)
+ table.insert(self.gradInputs, 1, gradInput)
for i,modula in ipairs(modules) do
self.initialGradInputs[i] = modula.gradInput
@@ -300,6 +304,7 @@ end
function Recurrent:updateGradInputThroughTime()
assert(self.step > 1, "expecting at least one updateOutput")
+ self.gradInputs = {}
local gradInput
local rho = math.min(self.rho, self.step-1)
local stop = self.step - rho
@@ -317,7 +322,7 @@ function Recurrent:updateGradInputThroughTime()
local output_ = recurrentOutputs[i]
assert(output_, "updateGradInputThroughTime should be preceded by updateOutput")
modula.output = output_
- modula.gradInput = recursiveCopy(recurrentGradInputs[i], gradInput)
+ modula.gradInput = recursiveResizeAs(recurrentGradInputs[i], gradInput)
end
-- backward propagate through this step
@@ -331,6 +336,7 @@ function Recurrent:updateGradInputThroughTime()
for i,modula in ipairs(modules) do
recurrentGradInputs[i] = modula.gradInput
end
+ table.insert(self.gradInputs, 1, gradInput)
end
if stop <= 1 then
@@ -352,6 +358,7 @@ function Recurrent:updateGradInputThroughTime()
for i,modula in ipairs(modules) do
self.initialGradInputs[i] = modula.gradInput
end
+ table.insert(self.gradInputs, 1, gradInput)
end
return gradInput
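The new self.gradInputs bookkeeping above collects one gradient entry per back-propagated time step; the Repeater and Sequencer modules introduced below simply return this table as their own gradInput. A minimal sketch of that calling pattern follows (the nn.Recurrent constructor arguments are illustrative and not part of this patch; only the step-wise calls mirror what the new modules do):

   -- illustrative: collect per-step gradients via the new gradInputs table
   require 'nnx'
   local inputSize, hiddenSize, rho = 5, 7, 3
   local rnn = nn.Recurrent(hiddenSize, nn.Linear(inputSize, hiddenSize),
                            nn.Linear(hiddenSize, hiddenSize), nn.Sigmoid(), rho)
   local inputs, gradOutputs = {}, {}
   for step = 1, rho do
      inputs[step] = torch.randn(inputSize)
      rnn:updateOutput(inputs[step])
      gradOutputs[step] = torch.randn(hiddenSize)
   end
   -- record each step's gradOutput, then back-propagate through time,
   -- mirroring what the new Repeater/Sequencer do internally
   for step = 1, rho do
      rnn.step = step + 1
      rnn:updateGradInput(inputs[step], gradOutputs[step])
   end
   rnn:updateGradInputThroughTime()
   print(#rnn.gradInputs)   -- one entry per back-propagated step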
diff --git a/Repeater.lua b/Repeater.lua
new file mode 100644
index 0000000..f9c9f11
--- /dev/null
+++ b/Repeater.lua
@@ -0,0 +1,62 @@
+------------------------------------------------------------------------
+--[[ Repeater ]]--
+-- Encapsulates an AbstractRecurrent instance (rnn) which is repeatedly
+-- presented with the same input for nStep time steps.
+-- The output is a table of nStep outputs of the rnn.
+------------------------------------------------------------------------
+local Repeater, parent = torch.class("nn.Repeater", "nn.Container")
+
+function Repeater:__init(nStep, rnn)
+ parent.__init(self)
+ self.nStep = nStep
+ self.rnn = rnn
+ assert(rnn.backwardThroughTime, "expecting AbstractRecurrent instance for arg 2")
+ self.modules[1] = rnn
+ self.output = {}
+end
+
+function Repeater:updateOutput(input)
+ self.rnn:forget()
+ for step=1,self.nStep do
+ self.output[step] = self.rnn:updateOutput(input)
+ end
+ return self.output
+end
+
+function Repeater:updateGradInput(input, gradOutput)
+ assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
+ assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
+ assert(#gradOutput == self.nStep, "gradOutput should have nStep elements")
+ for step=1,self.nStep do
+ self.rnn.step = step + 1
+ self.rnn:updateGradInput(input, gradOutput[step])
+ end
+ -- back-propagate through time (BPTT)
+ self.rnn:updateGradInputThroughTime()
+ self.gradInput = self.rnn.gradInputs
+ return self.gradInput
+end
+
+function Repeater:accGradParameters(input, gradOutput, scale)
+ assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
+ assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
+ assert(#gradOutput == self.nStep, "gradOutput should have nStep elements")
+ for step=1,self.nStep do
+ self.rnn.step = step + 1
+ self.rnn:accGradParameters(input, gradOutput[step], scale)
+ end
+ -- back-propagate through time (BPTT)
+ self.rnn:accGradParametersThroughTime()
+end
+
+function Repeater:accUpdateGradParameters(input, gradOutput, lr)
+ assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
+ assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
+ assert(#gradOutput == self.nStep, "gradOutput should have nStep elements")
+ for step=1,self.nStep do
+ self.rnn.step = step + 1
+ self.rnn:accGradParameters(input, gradOutput[step], 1)
+ end
+ -- back-propagate through time (BPTT)
+ self.rnn:accUpdateGradParametersThroughTime(lr)
+end
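For context, a brief usage sketch of the new Repeater: the same input tensor is presented for nStep steps and the output is a table of nStep outputs. The nn.Recurrent constructor arguments are again illustrative; only the Repeater calls come from this patch.

   require 'nnx'
   local inputSize, hiddenSize, nStep = 5, 7, 3
   local rnn = nn.Recurrent(hiddenSize, nn.Linear(inputSize, hiddenSize),
                            nn.Linear(hiddenSize, hiddenSize), nn.Sigmoid(), nStep)
   local repeater = nn.Repeater(nStep, rnn)
   local input = torch.randn(inputSize)
   local outputs = repeater:forward(input)   -- table of nStep output tensors
   local gradOutputs = {}
   for step = 1, nStep do
      gradOutputs[step] = torch.randn(hiddenSize)
   end
   repeater:backward(input, gradOutputs)     -- BPTT through the wrapped rnn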
diff --git a/RepeaterCriterion.lua b/RepeaterCriterion.lua
new file mode 100644
index 0000000..7a64ad3
--- /dev/null
+++ b/RepeaterCriterion.lua
@@ -0,0 +1,47 @@
+------------------------------------------------------------------------
+--[[ RepeaterCriterion ]]--
+-- Applies a criterion to each of the inputs in a Table using the
+-- same target (the target is repeated).
+-- Useful for nn.Repeater and nn.Sequencer.
+------------------------------------------------------------------------
+local RepeaterCriterion, parent = torch.class("nn.RepeaterCriterion", "nn.Criterion")
+
+function RepeaterCriterion:__init(criterion)
+ parent.__init(self)
+ self.criterion = criterion
+ self.gradInput = {}
+end
+
+function RepeaterCriterion:forward(inputTable, target)
+ self.output = 0
+ for i,input in ipairs(inputTable) do
+ self.output = self.output + self.criterion:forward(input, target)
+ end
+ return self.output
+end
+
+function RepeaterCriterion:backward(inputTable, target)
+ for i,input in ipairs(inputTable) do
+ local gradInput = self.criterion:backward(input, target)
+ self.gradInput[i] = self.gradInput[i] or gradInput.new()
+ self.gradInput[i]:resizeAs(gradInput):copy(gradInput)
+ end
+ return self.gradInput
+end
+
+local function recursiveType(param, type_str)
+ if torch.type(param) == 'table' then
+ for i = 1, #param do
+ param[i] = recursiveType(param[i], type_str)
+ end
+ else
+ if torch.typename(param) and
+ torch.typename(param):find('torch%..+Tensor') then
+ param = param:type(type_str)
+ end
+ end
+ return param
+end
+
+function RepeaterCriterion:type(type)
+ self.gradInput = recursiveType(self.gradInput, type)
+ self.criterion:type(type)
+ return self
+end
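A usage sketch for RepeaterCriterion, again illustrative: the wrapped criterion (nn.MSECriterion here, chosen only for the example) is applied to every element of an output table against a single repeated target.

   require 'nnx'
   local nStep, outputSize = 3, 7
   local criterion = nn.RepeaterCriterion(nn.MSECriterion())
   local outputs = {}
   for step = 1, nStep do
      outputs[step] = torch.randn(outputSize)   -- e.g. the table produced by nn.Repeater
   end
   local target = torch.randn(outputSize)
   local err = criterion:forward(outputs, target)           -- sum of the per-step losses
   local gradOutputs = criterion:backward(outputs, target)  -- table of nStep gradient tensors
   print(err, #gradOutputs)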
diff --git a/Sequencer.lua b/Sequencer.lua
new file mode 100644
index 0000000..f8c258f
--- /dev/null
+++ b/Sequencer.lua
@@ -0,0 +1,184 @@
+------------------------------------------------------------------------
+--[[ Sequencer ]]--
+-- Encapsulates a Module.
+-- Input is a sequence (a table) of tensors.
+-- Output is a sequence (a table) of tensors of the same length.
+-- Applies the module to each element in the sequence.
+-- Works with both recurrent (AbstractRecurrent) and non-recurrent modules.
+-- All sequences in a batch must have the same length, but that length
+-- may vary from one batch to the next.
+------------------------------------------------------------------------
+local Sequencer, parent = torch.class("nn.Sequencer", "nn.Container")
+
+function Sequencer:__init(module)
+ parent.__init(self)
+ self.module = module
+ self.isRecurrent = module.backwardThroughTime ~= nil
+ self.modules[1] = module
+ self.sequenceOutputs = {}
+ self.sequenceGradInputs = {}
+ self.output = {}
+ self.step = 1
+end
+
+local function recursiveResizeAs(t1,t2)
+ if torch.type(t2) == 'table' then
+ t1 = (torch.type(t1) == 'table') and t1 or {t1}
+ for key,_ in pairs(t2) do
+ t1[key], t2[key] = recursiveResizeAs(t1[key], t2[key])
+ end
+ elseif torch.isTensor(t2) then
+ t1 = t1 or t2.new()
+ t1:resizeAs(t2)
+ else
+ error("expecting nested tensors or tables. Got "..
+ torch.type(t1).." and "..torch.type(t2).." instead")
+ end
+ return t1, t2
+end
+
+
+function Sequencer:updateOutput(inputTable)
+ assert(torch.type(inputTable) == 'table', "expecting input table")
+ self.output = {}
+ if self.isRecurrent then
+ self.module:forget()
+ for step, input in ipairs(inputTable) do
+ self.output[step] = self.module:updateOutput(input)
+ end
+ else
+ for step, input in ipairs(inputTable) do
+ -- set output states for this step
+ local modules = self.module:listModules()
+ local sequenceOutputs = self.sequenceOutputs[step]
+ if not sequenceOutputs then
+ sequenceOutputs = {}
+ self.sequenceOutputs[step] = sequenceOutputs
+ end
+ for i,modula in ipairs(modules) do
+ local output_ = recursiveResizeAs(sequenceOutputs[i], modula.output)
+ modula.output = output_
+ end
+
+ -- forward propagate this step
+ self.output[step] = self.module:updateOutput(input)
+
+ -- save output state of this step
+ for i,modula in ipairs(modules) do
+ sequenceOutputs[i] = modula.output
+ end
+ end
+ end
+ return self.output
+end
+
+function Sequencer:updateGradInput(inputTable, gradOutputTable)
+ self.gradInput = {}
+ if self.isRecurrent then
+ assert(torch.type(gradOutputTable) == 'table', "expecting gradOutput table")
+ assert(#gradOutputTable == #inputTable, "gradOutput should have as many elements as input")
+ for step, input in ipairs(inputTable) do
+ self.module.step = step + 1
+ self.module:updateGradInput(input, gradOutputTable[step])
+ end
+ -- back-propagate through time (BPTT)
+ self.module:updateGradInputThroughTime()
+ assert(self.module.gradInputs, "recurrent module did not fill gradInputs")
+ for step=1,#inputTable do
+ self.gradInput[step] = self.module.gradInputs[step]
+ end
+ assert(#self.gradInput == #inputTable, "missing gradInputs")
+ else
+ for step, input in ipairs(inputTable) do
+ -- set the output/gradOutput states for this step
+ local modules = self.module:listModules()
+ local sequenceOutputs = self.sequenceOutputs[step]
+ local sequenceGradInputs = self.sequenceGradInputs[step]
+ if not sequenceGradInputs then
+ sequenceGradInputs = {}
+ self.sequenceGradInputs[step] = sequenceGradInputs
+ end
+ for i,modula in ipairs(modules) do
+ local output, gradInput = modula.output, modula.gradInput
+ local output_ = sequenceOutputs[i]
+ assert(output_, "Sequencer:updateGradInput should be preceded by updateOutput")
+ modula.output = output_
+ modula.gradInput = recursiveResizeAs(sequenceGradInputs[i], gradInput)
+ end
+
+ -- backward propagate this step
+ self.gradInput[step] = self.module:updateGradInput(input, gradOutputTable[step])
+
+ -- save the output/gradOutput states of this step
+ for i,modula in ipairs(modules) do
+ sequenceGradInputs[i] = modula.gradInput
+ end
+ end
+ end
+ return self.gradInput
+end
+
+function Sequencer:accGradParameters(inputTable, gradOutputTable, scale)
+ if self.isRecurrent then
+ assert(torch.type(gradOutputTable) == 'table', "expecting gradOutput table")
+ assert(#gradOutputTable == #inputTable, "gradOutput should have as many elements as input")
+ for step, input in ipairs(inputTable) do
+ self.module.step = step + 1
+ self.module:accGradParameters(input, gradOutputTable[step], scale)
+ end
+ -- back-propagate through time (BPTT)
+ self.module:accGradParametersThroughTime()
+ else
+ for step, input in ipairs(inputTable) do
+ -- set the output/gradOutput states for this step
+ local modules = self.module:listModules()
+ local sequenceOutputs = self.sequenceOutputs[step]
+ local sequenceGradInputs = self.sequenceGradInputs[step]
+ if not sequenceGradInputs then
+ sequenceGradInputs = {}
+ self.sequenceGradInputs[step] = sequenceGradInputs
+ end
+ for i,modula in ipairs(modules) do
+ local output, gradInput = modula.output, modula.gradInput
+ local output_ = sequenceOutputs[i]
+ modula.output = output_
+ modula.gradInput = recursiveResizeAs(sequenceGradInputs[i], gradInput)
+ end
+
+ -- accumulate parameters for this step
+ self.module:accGradParameters(input, gradOutputTable[step], scale)
+ end
+ end
+end
+
+function Sequencer:accUpdateGradParameters(inputTable, gradOutputTable, lr)
+ if self.isRecurrent then
+ assert(torch.type(gradOutputTable) == 'table', "expecting gradOutput table")
+ assert(#gradOutputTable == #inputTable, "gradOutput should have as many elements as input")
+ for step, input in ipairs(inputTable) do
+ self.module.step = step + 1
+ self.module:accGradParameters(input, gradOutputTable[step], 1)
+ end
+ -- back-propagate through time (BPTT)
+ self.module:accUpdateGradParametersThroughTime(lr)
+ else
+ for step, input in ipairs(inputTable) do
+ -- set the output/gradOutput states for this step
+ local modules = self.module:listModules()
+ local sequenceOutputs = self.sequenceOutputs[step]
+ local sequenceGradInputs = self.sequenceGradInputs[step]
+ if not sequenceGradInputs then
+ sequenceGradInputs = {}
+ self.sequenceGradInputs[step] = sequenceGradInputs
+ end
+ for i,modula in ipairs(modules) do
+ local output, gradInput = modula.output, modula.gradInput
+ local output_ = sequenceOutputs[i]
+ modula.output = output_
+ modula.gradInput = recursiveResizeAs(sequenceGradInputs[i], gradInput)
+ end
+
+ -- accumulate parameters for this step
+ self.module:accUpdateGradParameters(input, gradOutputTable[step], lr)
+ end
+ end
+end
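A sketch of the new Sequencer wrapping a plain (non-recurrent) module, which exercises the per-step state-swapping branch above; an AbstractRecurrent instance could be dropped in place of the nn.Linear without changing the calling code. Sizes and module choice are illustrative.

   require 'nnx'
   local inputSize, outputSize, seqLen = 5, 4, 3
   local seq = nn.Sequencer(nn.Linear(inputSize, outputSize))
   local inputs, gradOutputs = {}, {}
   for step = 1, seqLen do
      inputs[step] = torch.randn(inputSize)
      gradOutputs[step] = torch.randn(outputSize)
   end
   local outputs = seq:forward(inputs)               -- table of seqLen output tensors
   local gradInputs = seq:backward(inputs, gradOutputs)
   print(#outputs, #gradInputs)                      -- both equal seqLen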
diff --git a/init.lua b/init.lua
index d474525..3d6c5df 100644
--- a/init.lua
+++ b/init.lua
@@ -76,6 +76,8 @@ torch.include('nnx', 'MultiSoftMax.lua')
torch.include('nnx', 'Balance.lua')
torch.include('nnx', 'NarrowLookupTable.lua')
torch.include('nnx', 'Recurrent.lua')
+torch.include('nnx', 'Repeater.lua')
+torch.include('nnx', 'Sequencer.lua')
torch.include('nnx', 'PushTable.lua')
torch.include('nnx', 'PullTable.lua')
torch.include('nnx', 'Padding.lua')
@@ -85,6 +87,7 @@ torch.include('nnx', 'SuperCriterion.lua')
torch.include('nnx', 'DistNLLCriterion.lua')
torch.include('nnx', 'DistMarginCriterion.lua')
torch.include('nnx', 'TreeNLLCriterion.lua')
+torch.include('nnx', 'RepeaterCriterion.lua')
-- datasets:
torch.include('nnx', 'DataSet.lua')
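With the torch.include additions above, the new classes are registered when the package is loaded; a quick illustrative check:

   require 'nnx'
   print(nn.Repeater, nn.Sequencer, nn.RepeaterCriterion)  -- all three should be defined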