diff options
author | soumith <soumith@fb.com> | 2015-01-03 07:45:04 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2015-01-03 07:45:04 +0300 |
commit | 1efff4dc0b0bd396be618c352a9d6941a2f7b8b4 (patch) | |
tree | 412418951bfbf5174bff2fa99eb1f48f0b71ac1e /Concat.lua | |
parent | a38407a57def785acc819066db70f1649da47f03 (diff) |
refactoring all the common container code into nn.Container
Diffstat (limited to 'Concat.lua')
-rw-r--r-- | Concat.lua | 66 |
1 files changed, 2 insertions, 64 deletions
@@ -1,21 +1,11 @@ -local Concat, parent = torch.class('nn.Concat', 'nn.Module') +local Concat, parent = torch.class('nn.Concat', 'nn.Container') function Concat:__init(dimension) - parent.__init(self) - self.modules = {} + parent.__init(self, dimension) self.size = torch.LongStorage() self.dimension = dimension end -function Concat:add(module) - table.insert(self.modules, module) - return self -end - -function Concat:get(index) - return self.modules[index] -end - function Concat:updateOutput(input) local outs = {} for i=1,#self.modules do @@ -83,58 +73,6 @@ function Concat:accUpdateGradParameters(input, gradOutput, lr) end end -function Concat:zeroGradParameters() - for _,module in ipairs(self.modules) do - module:zeroGradParameters() - end -end - -function Concat:updateParameters(learningRate) - for _,module in ipairs(self.modules) do - module:updateParameters(learningRate) - end -end - -function Concat:training() - for i=1,#self.modules do - self.modules[i]:training() - end -end - -function Concat:evaluate() - for i=1,#self.modules do - self.modules[i]:evaluate() - end -end - -function Concat:share(mlp,...) - for i=1,#self.modules do - self.modules[i]:share(mlp.modules[i],...); - end -end - -function Concat:parameters() - local function tinsert(to, from) - if type(from) == 'table' then - for i=1,#from do - tinsert(to,from[i]) - end - else - table.insert(to,from) - end - end - local w = {} - local gw = {} - for i=1,#self.modules do - local mw,mgw = self.modules[i]:parameters() - if mw then - tinsert(w,mw) - tinsert(gw,mgw) - end - end - return w,gw -end - function Concat:__tostring__() local tab = ' ' local line = '\n' |