diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-01-03 07:48:18 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-01-03 07:48:18 +0300 |
commit | 4e0a96d801060121521ccc46f7294aeb3b247965 (patch) | |
tree | 412418951bfbf5174bff2fa99eb1f48f0b71ac1e | |
parent | da0c4e81ddf757786a89073dab5d1b1d192216b5 (diff) | |
parent | 1efff4dc0b0bd396be618c352a9d6941a2f7b8b4 (diff) |
Merge pull request #132 from torch/container
refactoring all the common container code into nn.Container
-rw-r--r-- | Concat.lua | 66 | ||||
-rw-r--r-- | Container.lua | 80 | ||||
-rw-r--r-- | Parallel.lua | 41 | ||||
-rw-r--r-- | Sequential.lua | 72 | ||||
-rw-r--r-- | init.lua | 1 | ||||
-rw-r--r-- | test.lua | 2 |
6 files changed, 86 insertions, 176 deletions
@@ -1,21 +1,11 @@ -local Concat, parent = torch.class('nn.Concat', 'nn.Module') +local Concat, parent = torch.class('nn.Concat', 'nn.Container') function Concat:__init(dimension) - parent.__init(self) - self.modules = {} + parent.__init(self, dimension) self.size = torch.LongStorage() self.dimension = dimension end -function Concat:add(module) - table.insert(self.modules, module) - return self -end - -function Concat:get(index) - return self.modules[index] -end - function Concat:updateOutput(input) local outs = {} for i=1,#self.modules do @@ -83,58 +73,6 @@ function Concat:accUpdateGradParameters(input, gradOutput, lr) end end -function Concat:zeroGradParameters() - for _,module in ipairs(self.modules) do - module:zeroGradParameters() - end -end - -function Concat:updateParameters(learningRate) - for _,module in ipairs(self.modules) do - module:updateParameters(learningRate) - end -end - -function Concat:training() - for i=1,#self.modules do - self.modules[i]:training() - end -end - -function Concat:evaluate() - for i=1,#self.modules do - self.modules[i]:evaluate() - end -end - -function Concat:share(mlp,...) - for i=1,#self.modules do - self.modules[i]:share(mlp.modules[i],...); - end -end - -function Concat:parameters() - local function tinsert(to, from) - if type(from) == 'table' then - for i=1,#from do - tinsert(to,from[i]) - end - else - table.insert(to,from) - end - end - local w = {} - local gw = {} - for i=1,#self.modules do - local mw,mgw = self.modules[i]:parameters() - if mw then - tinsert(w,mw) - tinsert(gw,mgw) - end - end - return w,gw -end - function Concat:__tostring__() local tab = ' ' local line = '\n' diff --git a/Container.lua b/Container.lua new file mode 100644 index 0000000..125ab98 --- /dev/null +++ b/Container.lua @@ -0,0 +1,80 @@ +-- This is code common to container modules, which are collections of +-- smaller constituent modules like Parallel, Sequential, etc. +local Container, parent = + torch.class('nn.Container', 'nn.Module') + +function Container:__init(...) + parent.__init(self, ...) + self.modules = {} +end + +function Container:add(module) + table.insert(self.modules, module) + return self +end + +function Container:get(index) + return self.modules[index] +end + +function Container:size() + return #self.modules +end + +function Container:zeroGradParameters() + for i=1,#self.modules do + self.modules[i]:zeroGradParameters() + end +end + +function Container:updateParameters(learningRate) + for _,module in ipairs(self.modules) do + module:updateParameters(learningRate) + end +end + +function Container:training() + for i=1,#self.modules do + self.modules[i]:training() + end +end + +function Container:evaluate() + for i=1,#self.modules do + self.modules[i]:evaluate() + end +end + +function Container:share(mlp, ...) + for i=1,#self.modules do + self.modules[i]:share(mlp.modules[i], ...); + end +end + +function Container:reset(stdv) + for i=1,#self.modules do + self.modules[i]:reset(stdv) + end +end + +function Container:parameters() + local function tinsert(to, from) + if type(from) == 'table' then + for i=1,#from do + tinsert(to,from[i]) + end + else + table.insert(to,from) + end + end + local w = {} + local gw = {} + for i=1,#self.modules do + local mw,mgw = self.modules[i]:parameters() + if mw then + tinsert(w,mw) + tinsert(gw,mgw) + end + end + return w,gw +end diff --git a/Parallel.lua b/Parallel.lua index 3057ba2..ef42723 100644 --- a/Parallel.lua +++ b/Parallel.lua @@ -1,4 +1,4 @@ -local Parallel, parent = torch.class('nn.Parallel', 'nn.Module') +local Parallel, parent = torch.class('nn.Parallel', 'nn.Container') function Parallel:__init(inputDimension,outputDimension) parent.__init(self) @@ -8,15 +8,6 @@ function Parallel:__init(inputDimension,outputDimension) self.outputDimension = outputDimension end -function Parallel:add(module) - table.insert(self.modules, module) - return self -end - -function Parallel:get(index) - return self.modules[index] -end - function Parallel:updateOutput(input) local modules=input:size(self.inputDimension) @@ -99,36 +90,6 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr) end end -function Parallel:zeroGradParameters() - for _,module in ipairs(self.modules) do - module:zeroGradParameters() - end -end - -function Parallel:updateParameters(learningRate) - for _,module in ipairs(self.modules) do - module:updateParameters(learningRate) - end -end - -function Parallel:training() - for i=1,#self.modules do - self.modules[i]:training() - end -end - -function Parallel:evaluate() - for i=1,#self.modules do - self.modules[i]:evaluate() - end -end - -function Parallel:share(mlp,...) - for i=1,#self.modules do - self.modules[i]:share(mlp.modules[i],...); - end -end - function Parallel:parameters() local function tinsert(to, from) if type(from) == 'table' then diff --git a/Sequential.lua b/Sequential.lua index 97554b3..3288e6d 100644 --- a/Sequential.lua +++ b/Sequential.lua @@ -1,9 +1,4 @@ -local Sequential, parent = torch.class('nn.Sequential', 'nn.Module') - -function Sequential:__init() - parent.__init(self) - self.modules = {} -end +local Sequential, _ = torch.class('nn.Sequential', 'nn.Container') function Sequential:add(module) if #self.modules == 0 then @@ -24,14 +19,6 @@ function Sequential:insert(module, index) self.gradInput = self.modules[1].gradInput end -function Sequential:size() - return #self.modules -end - -function Sequential:get(index) - return self.modules[index] -end - function Sequential:updateOutput(input) local currentOutput = input for i=1,#self.modules do @@ -82,63 +69,6 @@ function Sequential:accUpdateGradParameters(input, gradOutput, lr) currentModule:accUpdateGradParameters(input, currentGradOutput, lr) end -function Sequential:zeroGradParameters() - for i=1,#self.modules do - self.modules[i]:zeroGradParameters() - end -end - -function Sequential:updateParameters(learningRate) - for i=1,#self.modules do - self.modules[i]:updateParameters(learningRate) - end -end - -function Sequential:training() - for i=1,#self.modules do - self.modules[i]:training() - end -end - -function Sequential:evaluate() - for i=1,#self.modules do - self.modules[i]:evaluate() - end -end - -function Sequential:share(mlp,...) - for i=1,#self.modules do - self.modules[i]:share(mlp.modules[i],...); - end -end - -function Sequential:reset(stdv) - for i=1,#self.modules do - self.modules[i]:reset(stdv) - end -end - -function Sequential:parameters() - local function tinsert(to, from) - if type(from) == 'table' then - for i=1,#from do - tinsert(to,from[i]) - end - else - table.insert(to,from) - end - end - local w = {} - local gw = {} - for i=1,#self.modules do - local mw,mgw = self.modules[i]:parameters() - if mw then - tinsert(w,mw) - tinsert(gw,mgw) - end - end - return w,gw -end function Sequential:__tostring__() local tab = ' ' @@ -4,6 +4,7 @@ require('libnn') include('ErrorMessages.lua') include('Module.lua') +include('Container.lua') include('Concat.lua') include('Parallel.lua') include('Sequential.lua') @@ -477,7 +477,7 @@ function nntest.WeightedEuclidean() local inj = math.random(13,5) local input = torch.Tensor(ini):zero() local module = nn.WeightedEuclidean(ini,inj) - + local err = jac.testJacobian(module,input) mytester:assertlt(err,precision, 'error on state ') |