diff options
author | Adam Lerer <alerer@fb.com> | 2015-10-01 19:30:25 +0300 |
---|---|---|
committer | Adam Lerer <alerer@fb.com> | 2015-10-01 19:30:25 +0300 |
commit | 119d9100a0390bc6d2fb33dd449a5d8630682aa8 (patch) | |
tree | 7ba7da67e4f5bf88e9c092f88496493538033e7d /gmodule.lua | |
parent | c654b19a11004ffa1061ca689f1ca89fe527cbe7 (diff) |
Integrate apply() and type() improvements from https://github.com/torch/nn/pull/303
Diffstat (limited to 'gmodule.lua')
-rw-r--r-- | gmodule.lua | 34 |
1 files changed, 21 insertions, 13 deletions
diff --git a/gmodule.lua b/gmodule.lua index 1d569b2..83e2c7f 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -118,6 +118,12 @@ function gModule:__init(inputs,outputs) -- computation on the graph is done through topsort of forward and backward graphs self.forwardnodes = self.fg:topsort() self.backwardnodes = self.bg:topsort() + self.modules = {} + for _, node in ipairs(self.forwardnodes) do + if node.data.module then + table.insert(self.modules, node.data.module) + end + end -- Checking for unused inputs or unused split() outputs. for i,forwardNode in ipairs(self.forwardnodes) do if forwardNode.data.nSplitOutputs and forwardNode.data.nSplitOutputs ~= #forwardNode.children then @@ -138,14 +144,6 @@ function gModule:__init(inputs,outputs) end end -function gModule:apply(func) - for i,node in ipairs(self.forwardnodes) do - if node.data.module then - func(node.data.module) - end - end -end - function gModule:map(gm, func) for i,node in ipairs(self.forwardnodes) do local gmnode = gm.forwardnodes[i] @@ -179,11 +177,15 @@ function gModule:share(gm, ...) end function gModule:training() - self:apply(function(module) module:training() end) + for _, m in ipairs(self.modules) do + m:training() + end end function gModule:evaluate() - self:apply(function(module) module:evaluate() end) + for _, m in ipairs(self.modules) do + m:evaluate() + end end --[[ Recursively applies type(type_str) to any tensors in the argument. If the @@ -201,7 +203,9 @@ local function recursiveType(param, type_str) return param end -function gModule:type(type) +function gModule:type(type, tensorCache) + tensorCache = tensorCache or {} + local function applyTypeToTable(table) for key, value in pairs(table) do table[key] = recursiveType(table[key], type) @@ -214,7 +218,9 @@ function gModule:type(type) if self.outnode then applyTypeToTable(self.outnode.data) end -- Loop through modules and convert data - self:apply(function(module) module:type(type) end) + for _, m in ipairs(self.modules) do + m:type(type, tensorCache) + end for i,node in ipairs(self.backwardnodes) do if node.data.gradOutputBuffer ~= nil then @@ -226,7 +232,9 @@ function gModule:type(type) end function gModule:zeroGradParameters() - self:apply(function(module) module:zeroGradParameters() end) + for _, m in ipairs(self.modules) do + m:zeroGradParameters() + end end function gModule:updateOutput(input) |