diff options
author | soumith <soumith@fb.com> | 2015-01-03 07:45:04 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2015-01-03 07:45:04 +0300 |
commit | 1efff4dc0b0bd396be618c352a9d6941a2f7b8b4 (patch) | |
tree | 412418951bfbf5174bff2fa99eb1f48f0b71ac1e /Sequential.lua | |
parent | a38407a57def785acc819066db70f1649da47f03 (diff) |
refactoring all the common container code into nn.Container
Diffstat (limited to 'Sequential.lua')
-rw-r--r-- | Sequential.lua | 72 |
1 files changed, 1 insertions, 71 deletions
diff --git a/Sequential.lua b/Sequential.lua index 97554b3..3288e6d 100644 --- a/Sequential.lua +++ b/Sequential.lua @@ -1,9 +1,4 @@ -local Sequential, parent = torch.class('nn.Sequential', 'nn.Module') - -function Sequential:__init() - parent.__init(self) - self.modules = {} -end +local Sequential, _ = torch.class('nn.Sequential', 'nn.Container') function Sequential:add(module) if #self.modules == 0 then @@ -24,14 +19,6 @@ function Sequential:insert(module, index) self.gradInput = self.modules[1].gradInput end -function Sequential:size() - return #self.modules -end - -function Sequential:get(index) - return self.modules[index] -end - function Sequential:updateOutput(input) local currentOutput = input for i=1,#self.modules do @@ -82,63 +69,6 @@ function Sequential:accUpdateGradParameters(input, gradOutput, lr) currentModule:accUpdateGradParameters(input, currentGradOutput, lr) end -function Sequential:zeroGradParameters() - for i=1,#self.modules do - self.modules[i]:zeroGradParameters() - end -end - -function Sequential:updateParameters(learningRate) - for i=1,#self.modules do - self.modules[i]:updateParameters(learningRate) - end -end - -function Sequential:training() - for i=1,#self.modules do - self.modules[i]:training() - end -end - -function Sequential:evaluate() - for i=1,#self.modules do - self.modules[i]:evaluate() - end -end - -function Sequential:share(mlp,...) - for i=1,#self.modules do - self.modules[i]:share(mlp.modules[i],...); - end -end - -function Sequential:reset(stdv) - for i=1,#self.modules do - self.modules[i]:reset(stdv) - end -end - -function Sequential:parameters() - local function tinsert(to, from) - if type(from) == 'table' then - for i=1,#from do - tinsert(to,from[i]) - end - else - table.insert(to,from) - end - end - local w = {} - local gw = {} - for i=1,#self.modules do - local mw,mgw = self.modules[i]:parameters() - if mw then - tinsert(w,mw) - tinsert(gw,mgw) - end - end - return w,gw -end function Sequential:__tostring__() local tab = ' ' |