Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2015-01-03 07:45:04 +0300
committersoumith <soumith@fb.com>2015-01-03 07:45:04 +0300
commit1efff4dc0b0bd396be618c352a9d6941a2f7b8b4 (patch)
tree412418951bfbf5174bff2fa99eb1f48f0b71ac1e /Concat.lua
parenta38407a57def785acc819066db70f1649da47f03 (diff)
refactoring all the common container code into nn.Container
Diffstat (limited to 'Concat.lua')
-rw-r--r--Concat.lua66
1 files changed, 2 insertions, 64 deletions
diff --git a/Concat.lua b/Concat.lua
index c94808d..b0436a5 100644
--- a/Concat.lua
+++ b/Concat.lua
@@ -1,21 +1,11 @@
-local Concat, parent = torch.class('nn.Concat', 'nn.Module')
+local Concat, parent = torch.class('nn.Concat', 'nn.Container')
function Concat:__init(dimension)
- parent.__init(self)
- self.modules = {}
+ parent.__init(self, dimension)
self.size = torch.LongStorage()
self.dimension = dimension
end
-function Concat:add(module)
- table.insert(self.modules, module)
- return self
-end
-
-function Concat:get(index)
- return self.modules[index]
-end
-
function Concat:updateOutput(input)
local outs = {}
for i=1,#self.modules do
@@ -83,58 +73,6 @@ function Concat:accUpdateGradParameters(input, gradOutput, lr)
end
end
-function Concat:zeroGradParameters()
- for _,module in ipairs(self.modules) do
- module:zeroGradParameters()
- end
-end
-
-function Concat:updateParameters(learningRate)
- for _,module in ipairs(self.modules) do
- module:updateParameters(learningRate)
- end
-end
-
-function Concat:training()
- for i=1,#self.modules do
- self.modules[i]:training()
- end
-end
-
-function Concat:evaluate()
- for i=1,#self.modules do
- self.modules[i]:evaluate()
- end
-end
-
-function Concat:share(mlp,...)
- for i=1,#self.modules do
- self.modules[i]:share(mlp.modules[i],...);
- end
-end
-
-function Concat:parameters()
- local function tinsert(to, from)
- if type(from) == 'table' then
- for i=1,#from do
- tinsert(to,from[i])
- end
- else
- table.insert(to,from)
- end
- end
- local w = {}
- local gw = {}
- for i=1,#self.modules do
- local mw,mgw = self.modules[i]:parameters()
- if mw then
- tinsert(w,mw)
- tinsert(gw,mgw)
- end
- end
- return w,gw
-end
-
function Concat:__tostring__()
local tab = ' '
local line = '\n'