diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-01-07 09:28:51 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-01-07 09:28:51 +0300 |
commit | 675507d9a1ca9c8b854a45e388499bbffc0cda61 (patch) | |
tree | ade128400a5a753cc7086afc4fd9ad2e35888f87 | |
parent | 81d2c4215451b350404364dfc19ef5250fe6155b (diff) | |
parent | 5b198168ebaa330e0530fe67f4e08f0b8c1114ba (diff) |
Merge pull request #135 from nicholas-leonard/parallel
Parallel, Container & cie
-rw-r--r-- | Concat.lua | 2 | ||||
-rw-r--r-- | ConcatTable.lua | 63 | ||||
-rw-r--r-- | Container.lua | 3 | ||||
-rw-r--r-- | Parallel.lua | 76 | ||||
-rw-r--r-- | ParallelTable.lua | 70 | ||||
-rw-r--r-- | doc/containers.md | 29 | ||||
-rw-r--r-- | test.lua | 34 |
7 files changed, 94 insertions, 183 deletions
@@ -1,7 +1,7 @@ local Concat, parent = torch.class('nn.Concat', 'nn.Container') function Concat:__init(dimension) - parent.__init(self, dimension) + parent.__init(self) self.size = torch.LongStorage() self.dimension = dimension end diff --git a/ConcatTable.lua b/ConcatTable.lua index 62a4636..706ee6a 100644 --- a/ConcatTable.lua +++ b/ConcatTable.lua @@ -1,4 +1,4 @@ -local ConcatTable, parent = torch.class('nn.ConcatTable', 'nn.Module') +local ConcatTable, parent = torch.class('nn.ConcatTable', 'nn.Container') function ConcatTable:__init() parent.__init(self) @@ -6,19 +6,6 @@ function ConcatTable:__init() self.output = {} end -function ConcatTable:add(module) - table.insert(self.modules, module) - return self -end - -function ConcatTable:get(index) - return self.modules[index] -end - -function ConcatTable:size() - return #self.modules -end - function ConcatTable:updateOutput(input) for i=1,#self.modules do self.output[i] = self.modules[i]:updateOutput(input) @@ -99,52 +86,6 @@ function ConcatTable:zeroGradParameters() end end -function ConcatTable:updateParameters(learningRate) - for _,module in ipairs(self.modules) do - module:updateParameters(learningRate) - end -end - -function ConcatTable:training() - for i=1,#self.modules do - self.modules[i]:training() - end -end - -function ConcatTable:evaluate() - for i=1,#self.modules do - self.modules[i]:evaluate() - end -end - -function ConcatTable:share(mlp,...) - for i=1,#self.modules do - self.modules[i]:share(mlp.modules[i],...); - end -end - -function ConcatTable:parameters() - local function tinsert(to, from) - if type(from) == 'table' then - for i=1,#from do - tinsert(to,from[i]) - end - else - table.insert(to,from) - end - end - local w = {} - local gw = {} - for i=1,#self.modules do - local mw,mgw = self.modules[i]:parameters() - if mw then - tinsert(w,mw) - tinsert(gw,mgw) - end - end - return w,gw -end - function ConcatTable:type(type) parent.type(self, type) if torch.type(self.gradInput) == 'table' then @@ -161,7 +102,7 @@ function ConcatTable:__tostring__() local ext = ' | ' local extlast = ' ' local last = ' ... -> ' - local str = 'nn.ConcatTable' + local str = torch.type(self) str = str .. ' {' .. line .. tab .. 'input' for i=1,#self.modules do if i == self.modules then diff --git a/Container.lua b/Container.lua index 125ab98..484a3be 100644 --- a/Container.lua +++ b/Container.lua @@ -1,7 +1,6 @@ -- This is code common to container modules, which are collections of -- smaller constituent modules like Parallel, Sequential, etc. -local Container, parent = - torch.class('nn.Container', 'nn.Module') +local Container, parent = torch.class('nn.Container', 'nn.Module') function Container:__init(...) parent.__init(self, ...) diff --git a/Parallel.lua b/Parallel.lua index ef42723..e40c16c 100644 --- a/Parallel.lua +++ b/Parallel.lua @@ -9,28 +9,29 @@ function Parallel:__init(inputDimension,outputDimension) end function Parallel:updateOutput(input) - - local modules=input:size(self.inputDimension) + local nModule=input:size(self.inputDimension) + local outputs = {} - for i=1,modules do - local currentOutput = - self.modules[i]:updateOutput(input:select(self.inputDimension,i)) + for i=1,nModule do + local currentInput = input:select(self.inputDimension,i) + local currentOutput = self.modules[i]:updateOutput(currentInput) + table.insert(outputs, currentOutput) + local outputSize = currentOutput:size(self.outputDimension) if i == 1 then self.size:resize(currentOutput:dim()):copy(currentOutput:size()) else - self.size[self.outputDimension] = self.size[self.outputDimension] - + currentOutput:size(self.outputDimension) + self.size[self.outputDimension] = self.size[self.outputDimension] + outputSize end + end self.output:resize(self.size) local offset = 1 - for i=1,modules do - local currentOutput = self.modules[i]:updateOutput(input:select(self.inputDimension,i)) - - self.output:narrow(self.outputDimension, offset, - currentOutput:size(self.outputDimension)):copy(currentOutput) + for i=1,nModule do + local currentOutput = outputs[i] + local outputSize = currentOutput:size(self.outputDimension) + self.output:narrow(self.outputDimension, offset, outputSize):copy(currentOutput) offset = offset + currentOutput:size(self.outputDimension) end return self.output @@ -42,15 +43,16 @@ function Parallel:updateGradInput(input, gradOutput) local offset = 1 for i=1,nModule do - local module=self.modules[i]; + local module=self.modules[i] + local currentInput = input:select(self.inputDimension,i) local currentOutput = module.output - local currentGradInput = - module:updateGradInput(input:select(self.inputDimension,i), - gradOutput:narrow(self.outputDimension, - offset, currentOutput:size(self.outputDimension))) + local outputSize = currentOutput:size(self.outputDimension) + local currentGradOutput = gradOutput:narrow(self.outputDimension, offset, outputSize) + + local currentGradInput = module:updateGradInput(currentInput, currentGradOutput) self.gradInput:select(self.inputDimension,i):copy(currentGradInput) - offset = offset + currentOutput:size(self.outputDimension) + offset = offset + outputSize end return self.gradInput end @@ -60,16 +62,17 @@ function Parallel:accGradParameters(input, gradOutput, scale) local offset = 1 for i=1,nModule do - local module = self.modules[i]; + local module = self.modules[i] local currentOutput = module.output + local outputSize = currentOutput:size(self.outputDimension) + module:accGradParameters( input:select(self.inputDimension,i), - gradOutput:narrow( - self.outputDimension, offset, - currentOutput:size(self.outputDimension)), - scale) + gradOutput:narrow(self.outputDimension, offset,outputSize), + scale + ) - offset = offset + currentOutput:size(self.outputDimension) + offset = offset + outputSize end end @@ -81,6 +84,7 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr) local module = self.modules[i]; local currentOutput = module.output module:accUpdateGradParameters( + input:select(self.inputDimension,i), gradOutput:narrow(self.outputDimension, offset, currentOutput:size(self.outputDimension)), @@ -89,28 +93,6 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr) offset = offset + currentOutput:size(self.outputDimension) end end - -function Parallel:parameters() - local function tinsert(to, from) - if type(from) == 'table' then - for i=1,#from do - tinsert(to,from[i]) - end - else - table.insert(to,from) - end - end - local w = {} - local gw = {} - for i=1,#self.modules do - local mw,mgw = self.modules[i]:parameters() - if mw then - tinsert(w,mw) - tinsert(gw,mgw) - end - end - return w,gw -end function Parallel:__tostring__() local tab = ' ' @@ -119,7 +101,7 @@ function Parallel:__tostring__() local ext = ' | ' local extlast = ' ' local last = ' ... -> ' - local str = 'nn.Parallel' + local str = torch.type(self) str = str .. ' {' .. line .. tab .. 'input' for i=1,#self.modules do if i == self.modules then diff --git a/ParallelTable.lua b/ParallelTable.lua index 255a7bd..89bfc83 100644 --- a/ParallelTable.lua +++ b/ParallelTable.lua @@ -1,4 +1,4 @@ -local ParallelTable, parent = torch.class('nn.ParallelTable', 'nn.Module') +local ParallelTable, parent = torch.class('nn.ParallelTable', 'nn.Container') function ParallelTable:__init() parent.__init(self) @@ -7,19 +7,6 @@ function ParallelTable:__init() self.gradInput = {} end -function ParallelTable:add(module) - table.insert(self.modules, module) - return self -end - -function ParallelTable:get(index) - return self.modules[index] -end - -function ParallelTable:size() - return #self.modules -end - function ParallelTable:updateOutput(input) for i=1,#self.modules do self.output[i] = self.modules[i]:updateOutput(input[i]) @@ -27,7 +14,6 @@ function ParallelTable:updateOutput(input) return self.output end - function ParallelTable:updateGradInput(input, gradOutput) for i,module in ipairs(self.modules) do self.gradInput[i]= module:updateGradInput(input[i], gradOutput[i]) @@ -49,58 +35,6 @@ function ParallelTable:accUpdateGradParameters(input, gradOutput, lr) end end -function ParallelTable:zeroGradParameters() - for _,module in ipairs(self.modules) do - module:zeroGradParameters() - end -end - -function ParallelTable:updateParameters(learningRate) - for _,module in ipairs(self.modules) do - module:updateParameters(learningRate) - end -end - -function ParallelTable:training() - for i=1,#self.modules do - self.modules[i]:training() - end -end - -function ParallelTable:evaluate() - for i=1,#self.modules do - self.modules[i]:evaluate() - end -end - -function ParallelTable:share(mlp,...) - for i=1,#self.modules do - self.modules[i]:share(mlp.modules[i],...); - end -end - -function ParallelTable:parameters() - local function tinsert(to, from) - if type(from) == 'table' then - for i=1,#from do - tinsert(to,from[i]) - end - else - table.insert(to,from) - end - end - local w = {} - local gw = {} - for i=1,#self.modules do - local mw,mgw = self.modules[i]:parameters() - if mw then - tinsert(w,mw) - tinsert(gw,mgw) - end - end - return w,gw -end - function ParallelTable:__tostring__() local tab = ' ' local line = '\n' @@ -108,7 +42,7 @@ function ParallelTable:__tostring__() local ext = ' | ' local extlast = ' ' local last = ' ... -> ' - local str = 'nn.ParallelTable' + local str = torch.type(self) str = str .. ' {' .. line .. tab .. 'input' for i=1,#self.modules do if i == self.modules then diff --git a/doc/containers.md b/doc/containers.md index f529267..81d9e46 100644 --- a/doc/containers.md +++ b/doc/containers.md @@ -1,13 +1,34 @@ <a name="nn.Containers"/> # Containers # Complex neural networks are easily built using container classes: - * [Sequential](#nn.Sequential) : plugs layers in a feed-forward fully connected manner ; - * [Parallel](#nn.Parallel) : applies its `ith` child module to the `ith` slice of the input Tensor ; - * [Concat](#nn.Concat) : concatenates in one layer several modules along dimension `dim` ; - * [DepthConcat](#nn.DepthConcat) : like Concat, but adds zero-padding when non-`dim` sizes don't match; + * [Container](#nn.Container) : abstract class inherited by containers ; + * [Sequential](#nn.Sequential) : plugs layers in a feed-forward fully connected manner ; + * [Parallel](#nn.Parallel) : applies its `ith` child module to the `ith` slice of the input Tensor ; + * [Concat](#nn.Concat) : concatenates in one layer several modules along dimension `dim` ; + * [DepthConcat](#nn.DepthConcat) : like Concat, but adds zero-padding when non-`dim` sizes don't match; See also the [Table Containers](#nn.TableContainers) for manipulating tables of [Tensors](https://github.com/torch/torch7/blob/master/doc/tensor.md). +<a name="nn.Container"/> +## Container ## + +This is an abstract [Module](module.md#nn.Module) class which declares methods defined in all containers. +It reimplements many of the Module methods such that calls are propagated to the +contained modules. For example, a call to [zeroGradParameters](module.md#nn.Module.zeroGradParameters) +will be propagated to all contained modules. + +<a name="nn.Container.add"/> +### add(module) ### +Adds the given `module` to the container. The order is important + +<a name="nn.Container.get"/> +### get(index) ### +Returns the contained modules at index `index`. + +<a name="nn.Container.size"/> +### size() ### +Returns the number of contained modules. + <a name="nn.Sequential"/> ## Sequential ## @@ -2462,6 +2462,40 @@ function nntest.SpatialUpSamplingNearest() end end +function nntest.Parallel() + local input = torch.randn(3, 4, 5) + local m = nn.Parallel(1,3) + m:add(nn.View(4,5,1)) + m:add(nn.View(4,5,1)) + m:add(nn.View(4,5,1)) + + local output = m:forward(input) + local output2 = input:transpose(1,3):transpose(1,2) + mytester:assertTensorEq(output2, output, 0.000001, 'Parallel forward err') + + local gradInput = m:backward(input, output2) + mytester:assertTensorEq(gradInput, input, 0.000001, 'Parallel backward err') +end + +function nntest.ParallelTable() + local input = torch.randn(3, 4, 5) + local p = nn.ParallelTable() + p:add(nn.View(4,5,1)) + p:add(nn.View(4,5,1)) + p:add(nn.View(4,5,1)) + m = nn.Sequential() + m:add(nn.SplitTable(1)) + m:add(p) + m:add(nn.JoinTable(3)) + + local output = m:forward(input) + local output2 = input:transpose(1,3):transpose(1,2) + mytester:assertTensorEq(output2, output, 0.000001, 'ParallelTable forward err') + + local gradInput = m:backward(input, output2) + mytester:assertTensorEq(gradInput, input, 0.000001, 'ParallelTable backward err') +end + function nntest.ConcatTable() -- Test tensor input local input = torch.rand(5, 5, 5) |