
github.com/torch/nn.git
author     Soumith Chintala <soumith@gmail.com>  2015-01-03 07:48:18 +0300
committer  Soumith Chintala <soumith@gmail.com>  2015-01-03 07:48:18 +0300
commit     4e0a96d801060121521ccc46f7294aeb3b247965 (patch)
tree       412418951bfbf5174bff2fa99eb1f48f0b71ac1e
parent     da0c4e81ddf757786a89073dab5d1b1d192216b5 (diff)
parent     1efff4dc0b0bd396be618c352a9d6941a2f7b8b4 (diff)
Merge pull request #132 from torch/container
refactoring all the common container code into nn.Container
-rw-r--r--  Concat.lua      |  66
-rw-r--r--  Container.lua   |  80
-rw-r--r--  Parallel.lua    |  41
-rw-r--r--  Sequential.lua  |  72
-rw-r--r--  init.lua        |   1
-rw-r--r--  test.lua        |   2

6 files changed, 86 insertions(+), 176 deletions(-)
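The merge is behavior-preserving: methods that each container previously duplicated (add, get, size, zeroGradParameters, updateParameters, training, evaluate, share, parameters) now live once on nn.Container and are inherited. A minimal sketch of user code, assuming only the standard torch/nn API, that runs identically before and after this merge:

require 'nn'

-- Build a network; add() now resolves to nn.Container.add.
local mlp = nn.Sequential()
mlp:add(nn.Linear(10, 25))
mlp:add(nn.Tanh())
mlp:add(nn.Linear(25, 1))

-- One manual SGD step; both parameter calls are inherited from nn.Container.
mlp:zeroGradParameters()
local x = torch.randn(10)
mlp:forward(x)
mlp:backward(x, torch.Tensor{1})
mlp:updateParameters(0.01)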
diff --git a/Concat.lua b/Concat.lua
index c94808d..b0436a5 100644
--- a/Concat.lua
+++ b/Concat.lua
@@ -1,21 +1,11 @@
-local Concat, parent = torch.class('nn.Concat', 'nn.Module')
+local Concat, parent = torch.class('nn.Concat', 'nn.Container')
 
 function Concat:__init(dimension)
-   parent.__init(self)
-   self.modules = {}
+   parent.__init(self, dimension)
    self.size = torch.LongStorage()
    self.dimension = dimension
 end
 
-function Concat:add(module)
-   table.insert(self.modules, module)
-   return self
-end
-
-function Concat:get(index)
-   return self.modules[index]
-end
-
 function Concat:updateOutput(input)
    local outs = {}
    for i=1,#self.modules do
@@ -83,58 +73,6 @@ function Concat:accUpdateGradParameters(input, gradOutput, lr)
    end
 end
 
-function Concat:zeroGradParameters()
-   for _,module in ipairs(self.modules) do
-      module:zeroGradParameters()
-   end
-end
-
-function Concat:updateParameters(learningRate)
-   for _,module in ipairs(self.modules) do
-      module:updateParameters(learningRate)
-   end
-end
-
-function Concat:training()
-   for i=1,#self.modules do
-      self.modules[i]:training()
-   end
-end
-
-function Concat:evaluate()
-   for i=1,#self.modules do
-      self.modules[i]:evaluate()
-   end
-end
-
-function Concat:share(mlp,...)
-   for i=1,#self.modules do
-      self.modules[i]:share(mlp.modules[i],...);
-   end
-end
-
-function Concat:parameters()
-   local function tinsert(to, from)
-      if type(from) == 'table' then
-         for i=1,#from do
-            tinsert(to,from[i])
-         end
-      else
-         table.insert(to,from)
-      end
-   end
-   local w = {}
-   local gw = {}
-   for i=1,#self.modules do
-      local mw,mgw = self.modules[i]:parameters()
-      if mw then
-         tinsert(w,mw)
-         tinsert(gw,mgw)
-      end
-   end
-   return w,gw
-end
-
 function Concat:__tostring__()
    local tab = '  '
    local line = '\n'
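For orientation, a small usage sketch (standard nn semantics, unchanged by this commit): nn.Concat runs every branch on the same input and concatenates the outputs along the constructor dimension; add() and get() are no longer defined in this file but resolve through nn.Container.

local branches = nn.Concat(1)
branches:add(nn.Linear(10, 5))   -- nn.Container.add
branches:add(nn.Linear(10, 5))
print(branches:get(2))           -- nn.Container.get
local y = branches:forward(torch.randn(10))
print(y:size(1))                 -- 10: two size-5 outputs, joined on dim 1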
diff --git a/Container.lua b/Container.lua
new file mode 100644
index 0000000..125ab98
--- /dev/null
+++ b/Container.lua
@@ -0,0 +1,80 @@
+-- This is code common to container modules, which are collections of
+-- smaller constituent modules like Parallel, Sequential, etc.
+local Container, parent =
+   torch.class('nn.Container', 'nn.Module')
+
+function Container:__init(...)
+   parent.__init(self, ...)
+   self.modules = {}
+end
+
+function Container:add(module)
+   table.insert(self.modules, module)
+   return self
+end
+
+function Container:get(index)
+   return self.modules[index]
+end
+
+function Container:size()
+   return #self.modules
+end
+
+function Container:zeroGradParameters()
+   for i=1,#self.modules do
+      self.modules[i]:zeroGradParameters()
+   end
+end
+
+function Container:updateParameters(learningRate)
+   for _,module in ipairs(self.modules) do
+      module:updateParameters(learningRate)
+   end
+end
+
+function Container:training()
+   for i=1,#self.modules do
+      self.modules[i]:training()
+   end
+end
+
+function Container:evaluate()
+   for i=1,#self.modules do
+      self.modules[i]:evaluate()
+   end
+end
+
+function Container:share(mlp, ...)
+   for i=1,#self.modules do
+      self.modules[i]:share(mlp.modules[i], ...);
+   end
+end
+
+function Container:reset(stdv)
+   for i=1,#self.modules do
+      self.modules[i]:reset(stdv)
+   end
+end
+
+function Container:parameters()
+   local function tinsert(to, from)
+      if type(from) == 'table' then
+         for i=1,#from do
+            tinsert(to,from[i])
+         end
+      else
+         table.insert(to,from)
+      end
+   end
+   local w = {}
+   local gw = {}
+   for i=1,#self.modules do
+      local mw,mgw = self.modules[i]:parameters()
+      if mw then
+         tinsert(w,mw)
+         tinsert(gw,mgw)
+      end
+   end
+   return w,gw
+end
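Beyond removing duplication, the base class makes new containers cheap to write. A hypothetical sketch (nn.Ensemble is not part of this commit or of nn) that inherits all of the bookkeeping above and only defines the forward pass:

-- Averages the outputs of its child modules over the same input.
local Ensemble = torch.class('nn.Ensemble', 'nn.Container')

function Ensemble:updateOutput(input)
   self.output = self.modules[1]:updateOutput(input):clone()
   for i = 2, #self.modules do
      self.output:add(self.modules[i]:updateOutput(input))
   end
   self.output:div(#self.modules)
   return self.output
end
-- add/get/size/zeroGradParameters/parameters come from nn.Container; a
-- trainable version would also define updateGradInput and accGradParameters.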
diff --git a/Parallel.lua b/Parallel.lua
index 3057ba2..ef42723 100644
--- a/Parallel.lua
+++ b/Parallel.lua
@@ -1,4 +1,4 @@
-local Parallel, parent = torch.class('nn.Parallel', 'nn.Module')
+local Parallel, parent = torch.class('nn.Parallel', 'nn.Container')
 
 function Parallel:__init(inputDimension,outputDimension)
    parent.__init(self)
@@ -8,15 +8,6 @@ function Parallel:__init(inputDimension,outputDimension)
    self.outputDimension = outputDimension
 end
 
-function Parallel:add(module)
-   table.insert(self.modules, module)
-   return self
-end
-
-function Parallel:get(index)
-   return self.modules[index]
-end
-
 function Parallel:updateOutput(input)
    local modules=input:size(self.inputDimension)
 
@@ -99,36 +90,6 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr)
    end
 end
 
-function Parallel:zeroGradParameters()
-   for _,module in ipairs(self.modules) do
-      module:zeroGradParameters()
-   end
-end
-
-function Parallel:updateParameters(learningRate)
-   for _,module in ipairs(self.modules) do
-      module:updateParameters(learningRate)
-   end
-end
-
-function Parallel:training()
-   for i=1,#self.modules do
-      self.modules[i]:training()
-   end
-end
-
-function Parallel:evaluate()
-   for i=1,#self.modules do
-      self.modules[i]:evaluate()
-   end
-end
-
-function Parallel:share(mlp,...)
-   for i=1,#self.modules do
-      self.modules[i]:share(mlp.modules[i],...);
-   end
-end
-
 function Parallel:parameters()
    local function tinsert(to, from)
       if type(from) == 'table' then
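A brief usage sketch of nn.Parallel (standard semantics, untouched here): the input is sliced along inputDimension, slice i is fed to module i, and the outputs are joined along outputDimension. Note that, as the trailing context above shows, Parallel keeps its own parameters() override in this commit rather than using the inherited nn.Container.parameters.

local p = nn.Parallel(1, 1)   -- slice rows, join results along dim 1
p:add(nn.Linear(10, 3))       -- consumes input[1]
p:add(nn.Linear(10, 3))       -- consumes input[2]
local y = p:forward(torch.randn(2, 10))
print(y:size(1))              -- 6: two 3-dim outputs concatenated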
diff --git a/Sequential.lua b/Sequential.lua
index 97554b3..3288e6d 100644
--- a/Sequential.lua
+++ b/Sequential.lua
@@ -1,9 +1,4 @@
-local Sequential, parent = torch.class('nn.Sequential', 'nn.Module')
-
-function Sequential:__init()
-   parent.__init(self)
-   self.modules = {}
-end
+local Sequential, _ = torch.class('nn.Sequential', 'nn.Container')
 
 function Sequential:add(module)
    if #self.modules == 0 then
@@ -24,14 +19,6 @@ function Sequential:insert(module, index)
    self.gradInput = self.modules[1].gradInput
 end
 
-function Sequential:size()
-   return #self.modules
-end
-
-function Sequential:get(index)
-   return self.modules[index]
-end
-
 function Sequential:updateOutput(input)
    local currentOutput = input
    for i=1,#self.modules do
@@ -82,63 +69,6 @@ function Sequential:accUpdateGradParameters(input, gradOutput, lr)
    currentModule:accUpdateGradParameters(input, currentGradOutput, lr)
 end
 
-function Sequential:zeroGradParameters()
-   for i=1,#self.modules do
-      self.modules[i]:zeroGradParameters()
-   end
-end
-
-function Sequential:updateParameters(learningRate)
-   for i=1,#self.modules do
-      self.modules[i]:updateParameters(learningRate)
-   end
-end
-
-function Sequential:training()
-   for i=1,#self.modules do
-      self.modules[i]:training()
-   end
-end
-
-function Sequential:evaluate()
-   for i=1,#self.modules do
-      self.modules[i]:evaluate()
-   end
-end
-
-function Sequential:share(mlp,...)
-   for i=1,#self.modules do
-      self.modules[i]:share(mlp.modules[i],...);
-   end
-end
-
-function Sequential:reset(stdv)
-   for i=1,#self.modules do
-      self.modules[i]:reset(stdv)
-   end
-end
-
-function Sequential:parameters()
-   local function tinsert(to, from)
-      if type(from) == 'table' then
-         for i=1,#from do
-            tinsert(to,from[i])
-         end
-      else
-         table.insert(to,from)
-      end
-   end
-   local w = {}
-   local gw = {}
-   for i=1,#self.modules do
-      local mw,mgw = self.modules[i]:parameters()
-      if mw then
-         tinsert(w,mw)
-         tinsert(gw,mgw)
-      end
-   end
-   return w,gw
-end
 
 function Sequential:__tostring__()
    local tab = '  '
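A quick check of the inherited interface on nn.Sequential (standard nn API only; insert() stays in this file because it is Sequential-specific):

local mlp = nn.Sequential()
mlp:add(nn.Linear(4, 8))
mlp:add(nn.Tanh())
mlp:insert(nn.Linear(8, 8), 3)   -- Sequential-specific, defined above
print(mlp:size())                -- 3, via nn.Container.size
print(mlp:get(1))                -- nn.Linear(4 -> 8), via nn.Container.get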
diff --git a/init.lua b/init.lua
index 0ab72a8..c556321 100644
--- a/init.lua
+++ b/init.lua
@@ -4,6 +4,7 @@ require('libnn')
 
 include('ErrorMessages.lua')
 include('Module.lua')
+include('Container.lua')
 include('Concat.lua')
 include('Parallel.lua')
 include('Sequential.lua')
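The position of the new include matters: torch.class resolves the parent class by name at declaration time, so Container.lua must load after Module.lua (its own parent) and before every file that subclasses nn.Container. A minimal reproduction of the constraint, with hypothetical class names:

require 'nn'
local MyBase = torch.class('nn.MyBase', 'nn.Module')        -- parent first
local MyDerived = torch.class('nn.MyDerived', 'nn.MyBase')  -- then subclass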
diff --git a/test.lua b/test.lua
index 56298e0..3cf6a58 100644
--- a/test.lua
+++ b/test.lua
@@ -477,7 +477,7 @@ function nntest.WeightedEuclidean()
    local inj = math.random(13,5)
    local input = torch.Tensor(ini):zero()
    local module = nn.WeightedEuclidean(ini,inj)
-   
+
    local err = jac.testJacobian(module,input)
    mytester:assertlt(err,precision, 'error on state ')
 