github.com/torch/nn.git
author     Soumith Chintala <soumith@gmail.com>   2015-01-07 09:28:51 +0300
committer  Soumith Chintala <soumith@gmail.com>   2015-01-07 09:28:51 +0300
commit     675507d9a1ca9c8b854a45e388499bbffc0cda61 (patch)
tree       ade128400a5a753cc7086afc4fd9ad2e35888f87
parent     81d2c4215451b350404364dfc19ef5250fe6155b (diff)
parent     5b198168ebaa330e0530fe67f4e08f0b8c1114ba (diff)
Merge pull request #135 from nicholas-leonard/parallel
Parallel, Container & co.
-rw-r--r--  Concat.lua          2
-rw-r--r--  ConcatTable.lua    63
-rw-r--r--  Container.lua       3
-rw-r--r--  Parallel.lua       76
-rw-r--r--  ParallelTable.lua  70
-rw-r--r--  doc/containers.md  29
-rw-r--r--  test.lua           34

7 files changed, 94 insertions(+), 183 deletions(-)
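
At a glance: this merge moves the add/get/size bookkeeping and the parameter/training plumbing that ConcatTable, Parallel and ParallelTable each duplicated into the shared nn.Container base class, and reworks Parallel:updateOutput so each child's forward pass runs only once. A minimal sketch of the pattern, assuming only that nn.Container initializes self.modules in its constructor (as Container.lua in this tree does):

    require 'nn'

    -- Hedged sketch: a custom container now inherits module bookkeeping
    -- (add, get, size, parameters, training/evaluate) from nn.Container
    -- and only implements its own forward/backward logic.
    local MyContainer, parent = torch.class('nn.MyContainer', 'nn.Container')

    function MyContainer:__init()
       parent.__init(self)  -- sets up self.modules = {}
    end

    function MyContainer:updateOutput(input)
       -- trivially Sequential-like: chain the children in order
       local current = input
       for i = 1, #self.modules do
          current = self.modules[i]:updateOutput(current)
       end
       self.output = current
       return self.output
    end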
diff --git a/Concat.lua b/Concat.lua
index b0436a5..9943d56 100644
--- a/Concat.lua
+++ b/Concat.lua
@@ -1,7 +1,7 @@
local Concat, parent = torch.class('nn.Concat', 'nn.Container')
function Concat:__init(dimension)
- parent.__init(self, dimension)
+ parent.__init(self)
self.size = torch.LongStorage()
self.dimension = dimension
end
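
The one-line change above stops passing `dimension` to the parent constructor, where it was ignored anyway; Concat keeps the field itself. Caller-facing behavior is unchanged; a small illustration (layer sizes chosen arbitrarily):

    -- Two branches each map 5 inputs to 3 outputs; their outputs are
    -- concatenated along dimension 1, so the result has 6 elements.
    local mlp = nn.Concat(1)
    mlp:add(nn.Linear(5, 3))
    mlp:add(nn.Linear(5, 3))
    print(mlp:forward(torch.randn(5)):size(1))  -- 6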
diff --git a/ConcatTable.lua b/ConcatTable.lua
index 62a4636..706ee6a 100644
--- a/ConcatTable.lua
+++ b/ConcatTable.lua
@@ -1,4 +1,4 @@
-local ConcatTable, parent = torch.class('nn.ConcatTable', 'nn.Module')
+local ConcatTable, parent = torch.class('nn.ConcatTable', 'nn.Container')
function ConcatTable:__init()
parent.__init(self)
@@ -6,19 +6,6 @@ function ConcatTable:__init()
self.output = {}
end
-function ConcatTable:add(module)
- table.insert(self.modules, module)
- return self
-end
-
-function ConcatTable:get(index)
- return self.modules[index]
-end
-
-function ConcatTable:size()
- return #self.modules
-end
-
function ConcatTable:updateOutput(input)
for i=1,#self.modules do
self.output[i] = self.modules[i]:updateOutput(input)
@@ -99,52 +86,6 @@ function ConcatTable:zeroGradParameters()
end
end
-function ConcatTable:updateParameters(learningRate)
- for _,module in ipairs(self.modules) do
- module:updateParameters(learningRate)
- end
-end
-
-function ConcatTable:training()
- for i=1,#self.modules do
- self.modules[i]:training()
- end
-end
-
-function ConcatTable:evaluate()
- for i=1,#self.modules do
- self.modules[i]:evaluate()
- end
-end
-
-function ConcatTable:share(mlp,...)
- for i=1,#self.modules do
- self.modules[i]:share(mlp.modules[i],...);
- end
-end
-
-function ConcatTable:parameters()
- local function tinsert(to, from)
- if type(from) == 'table' then
- for i=1,#from do
- tinsert(to,from[i])
- end
- else
- table.insert(to,from)
- end
- end
- local w = {}
- local gw = {}
- for i=1,#self.modules do
- local mw,mgw = self.modules[i]:parameters()
- if mw then
- tinsert(w,mw)
- tinsert(gw,mgw)
- end
- end
- return w,gw
-end
-
function ConcatTable:type(type)
parent.type(self, type)
if torch.type(self.gradInput) == 'table' then
@@ -161,7 +102,7 @@ function ConcatTable:__tostring__()
local ext = ' | '
local extlast = ' '
local last = ' ... -> '
- local str = 'nn.ConcatTable'
+ local str = torch.type(self)
str = str .. ' {' .. line .. tab .. 'input'
for i=1,#self.modules do
if i == #self.modules then
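
With the duplicated methods gone, ConcatTable keeps only its table-specific forward/backward logic; add, get, size, parameters and the training/evaluate switches are inherited. Behavior at the call site is unchanged, e.g. (sizes chosen arbitrarily):

    -- ConcatTable feeds the same input to every child and returns a table.
    local ct = nn.ConcatTable()
    ct:add(nn.Linear(5, 2))
    ct:add(nn.Linear(5, 4))
    local out = ct:forward(torch.randn(5))
    print(#out, out[1]:size(1), out[2]:size(1))  -- 2  2  4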
diff --git a/Container.lua b/Container.lua
index 125ab98..484a3be 100644
--- a/Container.lua
+++ b/Container.lua
@@ -1,7 +1,6 @@
-- This is code common to container modules, which are collections of
-- smaller constituent modules like Parallel, Sequential, etc.
-local Container, parent =
- torch.class('nn.Container', 'nn.Module')
+local Container, parent = torch.class('nn.Container', 'nn.Module')
function Container:__init(...)
parent.__init(self, ...)
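
Container.lua itself only gets a cosmetic change in this diff, but it is where the methods deleted from the subclasses now live. For orientation, the shared definitions match the removed subclass code and the documentation added below:

    -- Shared bookkeeping, formerly copy-pasted in each container subclass:
    function Container:add(module)
       table.insert(self.modules, module)
       return self
    end

    function Container:get(index)
       return self.modules[index]
    end

    function Container:size()
       return #self.modules
    end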
diff --git a/Parallel.lua b/Parallel.lua
index ef42723..e40c16c 100644
--- a/Parallel.lua
+++ b/Parallel.lua
@@ -9,28 +9,29 @@ function Parallel:__init(inputDimension,outputDimension)
end
function Parallel:updateOutput(input)
-
- local modules=input:size(self.inputDimension)
+ local nModule=input:size(self.inputDimension)
+ local outputs = {}
- for i=1,modules do
- local currentOutput =
- self.modules[i]:updateOutput(input:select(self.inputDimension,i))
+ for i=1,nModule do
+ local currentInput = input:select(self.inputDimension,i)
+ local currentOutput = self.modules[i]:updateOutput(currentInput)
+ table.insert(outputs, currentOutput)
+ local outputSize = currentOutput:size(self.outputDimension)
if i == 1 then
self.size:resize(currentOutput:dim()):copy(currentOutput:size())
else
- self.size[self.outputDimension] = self.size[self.outputDimension]
- + currentOutput:size(self.outputDimension)
+ self.size[self.outputDimension] = self.size[self.outputDimension] + outputSize
end
+
end
self.output:resize(self.size)
local offset = 1
- for i=1,modules do
- local currentOutput = self.modules[i]:updateOutput(input:select(self.inputDimension,i))
-
- self.output:narrow(self.outputDimension, offset,
- currentOutput:size(self.outputDimension)):copy(currentOutput)
+ for i=1,nModule do
+ local currentOutput = outputs[i]
+ local outputSize = currentOutput:size(self.outputDimension)
+ self.output:narrow(self.outputDimension, offset, outputSize):copy(currentOutput)
offset = offset + currentOutput:size(self.outputDimension)
end
return self.output
@@ -42,15 +43,16 @@ function Parallel:updateGradInput(input, gradOutput)
local offset = 1
for i=1,nModule do
- local module=self.modules[i];
+ local module=self.modules[i]
+ local currentInput = input:select(self.inputDimension,i)
local currentOutput = module.output
- local currentGradInput =
- module:updateGradInput(input:select(self.inputDimension,i),
- gradOutput:narrow(self.outputDimension,
- offset, currentOutput:size(self.outputDimension)))
+ local outputSize = currentOutput:size(self.outputDimension)
+ local currentGradOutput = gradOutput:narrow(self.outputDimension, offset, outputSize)
+
+ local currentGradInput = module:updateGradInput(currentInput, currentGradOutput)
self.gradInput:select(self.inputDimension,i):copy(currentGradInput)
- offset = offset + currentOutput:size(self.outputDimension)
+ offset = offset + outputSize
end
return self.gradInput
end
@@ -60,16 +62,17 @@ function Parallel:accGradParameters(input, gradOutput, scale)
local offset = 1
for i=1,nModule do
- local module = self.modules[i];
+ local module = self.modules[i]
local currentOutput = module.output
+ local outputSize = currentOutput:size(self.outputDimension)
+
module:accGradParameters(
input:select(self.inputDimension,i),
- gradOutput:narrow(
- self.outputDimension, offset,
- currentOutput:size(self.outputDimension)),
- scale)
+ gradOutput:narrow(self.outputDimension, offset,outputSize),
+ scale
+ )
- offset = offset + currentOutput:size(self.outputDimension)
+ offset = offset + outputSize
end
end
@@ -81,6 +84,7 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr)
local module = self.modules[i];
local currentOutput = module.output
module:accUpdateGradParameters(
+
input:select(self.inputDimension,i),
gradOutput:narrow(self.outputDimension, offset,
currentOutput:size(self.outputDimension)),
@@ -89,28 +93,6 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr)
offset = offset + currentOutput:size(self.outputDimension)
end
end
-
-function Parallel:parameters()
- local function tinsert(to, from)
- if type(from) == 'table' then
- for i=1,#from do
- tinsert(to,from[i])
- end
- else
- table.insert(to,from)
- end
- end
- local w = {}
- local gw = {}
- for i=1,#self.modules do
- local mw,mgw = self.modules[i]:parameters()
- if mw then
- tinsert(w,mw)
- tinsert(gw,mgw)
- end
- end
- return w,gw
-end
function Parallel:__tostring__()
local tab = ' '
@@ -119,7 +101,7 @@ function Parallel:__tostring__()
local ext = ' | '
local extlast = ' '
local last = ' ... -> '
- local str = 'nn.Parallel'
+ local str = torch.type(self)
str = str .. ' {' .. line .. tab .. 'input'
for i=1,#self.modules do
if i == #self.modules then
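
The rewritten updateOutput caches each child's result in `outputs` instead of calling forward twice per child (once for sizing, once for the copy), halving the forward work. The semantics stay the same: child i sees slice i of the input along inputDimension, and the results are concatenated along outputDimension, as in the new test added at the bottom of this merge:

    -- Slice a 3x4x5 input along dim 1; each 4x5 slice is viewed as 4x5x1,
    -- and the three results are concatenated along dim 3 into 4x5x3.
    local m = nn.Parallel(1, 3)
    m:add(nn.View(4, 5, 1))
    m:add(nn.View(4, 5, 1))
    m:add(nn.View(4, 5, 1))
    local output = m:forward(torch.randn(3, 4, 5))
    print(output:size())  -- 4x5x3, i.e. input:transpose(1,3):transpose(1,2)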
diff --git a/ParallelTable.lua b/ParallelTable.lua
index 255a7bd..89bfc83 100644
--- a/ParallelTable.lua
+++ b/ParallelTable.lua
@@ -1,4 +1,4 @@
-local ParallelTable, parent = torch.class('nn.ParallelTable', 'nn.Module')
+local ParallelTable, parent = torch.class('nn.ParallelTable', 'nn.Container')
function ParallelTable:__init()
parent.__init(self)
@@ -7,19 +7,6 @@ function ParallelTable:__init()
self.gradInput = {}
end
-function ParallelTable:add(module)
- table.insert(self.modules, module)
- return self
-end
-
-function ParallelTable:get(index)
- return self.modules[index]
-end
-
-function ParallelTable:size()
- return #self.modules
-end
-
function ParallelTable:updateOutput(input)
for i=1,#self.modules do
self.output[i] = self.modules[i]:updateOutput(input[i])
@@ -27,7 +14,6 @@ function ParallelTable:updateOutput(input)
return self.output
end
-
function ParallelTable:updateGradInput(input, gradOutput)
for i,module in ipairs(self.modules) do
self.gradInput[i]= module:updateGradInput(input[i], gradOutput[i])
@@ -49,58 +35,6 @@ function ParallelTable:accUpdateGradParameters(input, gradOutput, lr)
end
end
-function ParallelTable:zeroGradParameters()
- for _,module in ipairs(self.modules) do
- module:zeroGradParameters()
- end
-end
-
-function ParallelTable:updateParameters(learningRate)
- for _,module in ipairs(self.modules) do
- module:updateParameters(learningRate)
- end
-end
-
-function ParallelTable:training()
- for i=1,#self.modules do
- self.modules[i]:training()
- end
-end
-
-function ParallelTable:evaluate()
- for i=1,#self.modules do
- self.modules[i]:evaluate()
- end
-end
-
-function ParallelTable:share(mlp,...)
- for i=1,#self.modules do
- self.modules[i]:share(mlp.modules[i],...);
- end
-end
-
-function ParallelTable:parameters()
- local function tinsert(to, from)
- if type(from) == 'table' then
- for i=1,#from do
- tinsert(to,from[i])
- end
- else
- table.insert(to,from)
- end
- end
- local w = {}
- local gw = {}
- for i=1,#self.modules do
- local mw,mgw = self.modules[i]:parameters()
- if mw then
- tinsert(w,mw)
- tinsert(gw,mgw)
- end
- end
- return w,gw
-end
-
function ParallelTable:__tostring__()
local tab = ' '
local line = '\n'
@@ -108,7 +42,7 @@ function ParallelTable:__tostring__()
local ext = ' | '
local extlast = ' '
local last = ' ... -> '
- local str = 'nn.ParallelTable'
+ local str = torch.type(self)
str = str .. ' {' .. line .. tab .. 'input'
for i=1,#self.modules do
if i == #self.modules then
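
ParallelTable gets the same treatment as ConcatTable: the duplicated bookkeeping is deleted and inherited from nn.Container instead. Its own contract is unchanged; child i is applied to entry i of the input table (sizes chosen arbitrarily):

    local p = nn.ParallelTable()
    p:add(nn.Linear(5, 2))
    p:add(nn.Linear(5, 3))
    local out = p:forward({torch.randn(5), torch.randn(5)})
    print(out[1]:size(1), out[2]:size(1))  -- 2  3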
diff --git a/doc/containers.md b/doc/containers.md
index f529267..81d9e46 100644
--- a/doc/containers.md
+++ b/doc/containers.md
@@ -1,13 +1,34 @@
<a name="nn.Containers"/>
# Containers #
Complex neural networks are easily built using container classes:
- * [Sequential](#nn.Sequential) : plugs layers in a feed-forward fully connected manner ;
- * [Parallel](#nn.Parallel) : applies its `ith` child module to the `ith` slice of the input Tensor ;
- * [Concat](#nn.Concat) : concatenates in one layer several modules along dimension `dim` ;
- * [DepthConcat](#nn.DepthConcat) : like Concat, but adds zero-padding when non-`dim` sizes don't match;
+ * [Container](#nn.Container) : abstract class inherited by containers ;
+ * [Sequential](#nn.Sequential) : plugs layers in a feed-forward fully connected manner ;
+ * [Parallel](#nn.Parallel) : applies its `ith` child module to the `ith` slice of the input Tensor ;
+ * [Concat](#nn.Concat) : concatenates in one layer several modules along dimension `dim` ;
+ * [DepthConcat](#nn.DepthConcat) : like Concat, but adds zero-padding when non-`dim` sizes don't match;
See also the [Table Containers](#nn.TableContainers) for manipulating tables of [Tensors](https://github.com/torch/torch7/blob/master/doc/tensor.md).
+<a name="nn.Container"/>
+## Container ##
+
+This is an abstract [Module](module.md#nn.Module) class which declares methods defined in all containers.
+It reimplements many of the Module methods such that calls are propagated to the
+contained modules. For example, a call to [zeroGradParameters](module.md#nn.Module.zeroGradParameters)
+will be propagated to all contained modules.
+
+<a name="nn.Container.add"/>
+### add(module) ###
+Adds the given `module` to the container. The order is important.
+
+<a name="nn.Container.get"/>
+### get(index) ###
+Returns the contained module at index `index`.
+
+<a name="nn.Container.size"/>
+### size() ###
+Returns the number of contained modules.
+
<a name="nn.Sequential"/>
## Sequential ##
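
The three methods documented above work the same on every container subclass; a short usage sketch:

    local seq = nn.Sequential()
    seq:add(nn.Linear(10, 10))
    seq:add(nn.Tanh())
    print(seq:size())  -- 2
    print(seq:get(1))  -- nn.Linear(10 -> 10)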
diff --git a/test.lua b/test.lua
index a541a1b..7c4229a 100644
--- a/test.lua
+++ b/test.lua
@@ -2462,6 +2462,40 @@ function nntest.SpatialUpSamplingNearest()
end
end
+function nntest.Parallel()
+ local input = torch.randn(3, 4, 5)
+ local m = nn.Parallel(1,3)
+ m:add(nn.View(4,5,1))
+ m:add(nn.View(4,5,1))
+ m:add(nn.View(4,5,1))
+
+ local output = m:forward(input)
+ local output2 = input:transpose(1,3):transpose(1,2)
+ mytester:assertTensorEq(output2, output, 0.000001, 'Parallel forward err')
+
+ local gradInput = m:backward(input, output2)
+ mytester:assertTensorEq(gradInput, input, 0.000001, 'Parallel backward err')
+end
+
+function nntest.ParallelTable()
+ local input = torch.randn(3, 4, 5)
+ local p = nn.ParallelTable()
+ p:add(nn.View(4,5,1))
+ p:add(nn.View(4,5,1))
+ p:add(nn.View(4,5,1))
+ local m = nn.Sequential()
+ m:add(nn.SplitTable(1))
+ m:add(p)
+ m:add(nn.JoinTable(3))
+
+ local output = m:forward(input)
+ local output2 = input:transpose(1,3):transpose(1,2)
+ mytester:assertTensorEq(output2, output, 0.000001, 'ParallelTable forward err')
+
+ local gradInput = m:backward(input, output2)
+ mytester:assertTensorEq(gradInput, input, 0.000001, 'ParallelTable backward err')
+end
+
function nntest.ConcatTable()
-- Test tensor input
local input = torch.rand(5, 5, 5)