Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2015-01-03 07:45:04 +0300
committersoumith <soumith@fb.com>2015-01-03 07:45:04 +0300
commit1efff4dc0b0bd396be618c352a9d6941a2f7b8b4 (patch)
tree412418951bfbf5174bff2fa99eb1f48f0b71ac1e /Sequential.lua
parenta38407a57def785acc819066db70f1649da47f03 (diff)
refactoring all the common container code into nn.Container
Diffstat (limited to 'Sequential.lua')
-rw-r--r--Sequential.lua72
1 files changed, 1 insertions, 71 deletions
diff --git a/Sequential.lua b/Sequential.lua
index 97554b3..3288e6d 100644
--- a/Sequential.lua
+++ b/Sequential.lua
@@ -1,9 +1,4 @@
-local Sequential, parent = torch.class('nn.Sequential', 'nn.Module')
-
-function Sequential:__init()
- parent.__init(self)
- self.modules = {}
-end
+local Sequential, _ = torch.class('nn.Sequential', 'nn.Container')
function Sequential:add(module)
if #self.modules == 0 then
@@ -24,14 +19,6 @@ function Sequential:insert(module, index)
self.gradInput = self.modules[1].gradInput
end
-function Sequential:size()
- return #self.modules
-end
-
-function Sequential:get(index)
- return self.modules[index]
-end
-
function Sequential:updateOutput(input)
local currentOutput = input
for i=1,#self.modules do
@@ -82,63 +69,6 @@ function Sequential:accUpdateGradParameters(input, gradOutput, lr)
currentModule:accUpdateGradParameters(input, currentGradOutput, lr)
end
-function Sequential:zeroGradParameters()
- for i=1,#self.modules do
- self.modules[i]:zeroGradParameters()
- end
-end
-
-function Sequential:updateParameters(learningRate)
- for i=1,#self.modules do
- self.modules[i]:updateParameters(learningRate)
- end
-end
-
-function Sequential:training()
- for i=1,#self.modules do
- self.modules[i]:training()
- end
-end
-
-function Sequential:evaluate()
- for i=1,#self.modules do
- self.modules[i]:evaluate()
- end
-end
-
-function Sequential:share(mlp,...)
- for i=1,#self.modules do
- self.modules[i]:share(mlp.modules[i],...);
- end
-end
-
-function Sequential:reset(stdv)
- for i=1,#self.modules do
- self.modules[i]:reset(stdv)
- end
-end
-
-function Sequential:parameters()
- local function tinsert(to, from)
- if type(from) == 'table' then
- for i=1,#from do
- tinsert(to,from[i])
- end
- else
- table.insert(to,from)
- end
- end
- local w = {}
- local gw = {}
- for i=1,#self.modules do
- local mw,mgw = self.modules[i]:parameters()
- if mw then
- tinsert(w,mw)
- tinsert(gw,mgw)
- end
- end
- return w,gw
-end
function Sequential:__tostring__()
local tab = ' '