Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options
context:
space:
mode:
authorAdam Lerer <alerer@fb.com>2015-10-01 19:30:25 +0300
committerAdam Lerer <alerer@fb.com>2015-10-01 19:30:25 +0300
commit119d9100a0390bc6d2fb33dd449a5d8630682aa8 (patch)
tree7ba7da67e4f5bf88e9c092f88496493538033e7d /gmodule.lua
parentc654b19a11004ffa1061ca689f1ca89fe527cbe7 (diff)
Integrate apply() and type() improvements from https://github.com/torch/nn/pull/303
Diffstat (limited to 'gmodule.lua')
-rw-r--r--gmodule.lua34
1 file changed, 21 insertions(+), 13 deletions(-)
diff --git a/gmodule.lua b/gmodule.lua
index 1d569b2..83e2c7f 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -118,6 +118,12 @@ function gModule:__init(inputs,outputs)
-- computation on the graph is done through topsort of forward and backward graphs
self.forwardnodes = self.fg:topsort()
self.backwardnodes = self.bg:topsort()
+ self.modules = {}
+ for _, node in ipairs(self.forwardnodes) do
+ if node.data.module then
+ table.insert(self.modules, node.data.module)
+ end
+ end
-- Checking for unused inputs or unused split() outputs.
for i,forwardNode in ipairs(self.forwardnodes) do
if forwardNode.data.nSplitOutputs and forwardNode.data.nSplitOutputs ~= #forwardNode.children then
@@ -138,14 +144,6 @@ function gModule:__init(inputs,outputs)
end
end
-function gModule:apply(func)
- for i,node in ipairs(self.forwardnodes) do
- if node.data.module then
- func(node.data.module)
- end
- end
-end
-
function gModule:map(gm, func)
for i,node in ipairs(self.forwardnodes) do
local gmnode = gm.forwardnodes[i]
@@ -179,11 +177,15 @@ function gModule:share(gm, ...)
end
function gModule:training()
- self:apply(function(module) module:training() end)
+ for _, m in ipairs(self.modules) do
+ m:training()
+ end
end
function gModule:evaluate()
- self:apply(function(module) module:evaluate() end)
+ for _, m in ipairs(self.modules) do
+ m:evaluate()
+ end
end
--[[ Recursively applies type(type_str) to any tensors in the argument. If the
@@ -201,7 +203,9 @@ local function recursiveType(param, type_str)
return param
end
-function gModule:type(type)
+function gModule:type(type, tensorCache)
+ tensorCache = tensorCache or {}
+
local function applyTypeToTable(table)
for key, value in pairs(table) do
table[key] = recursiveType(table[key], type)
@@ -214,7 +218,9 @@ function gModule:type(type)
if self.outnode then applyTypeToTable(self.outnode.data) end
-- Loop through modules and convert data
- self:apply(function(module) module:type(type) end)
+ for _, m in ipairs(self.modules) do
+ m:type(type, tensorCache)
+ end
for i,node in ipairs(self.backwardnodes) do
if node.data.gradOutputBuffer ~= nil then
@@ -226,7 +232,9 @@ function gModule:type(type)
end
function gModule:zeroGradParameters()
- self:apply(function(module) module:zeroGradParameters() end)
+ for _, m in ipairs(self.modules) do
+ m:zeroGradParameters()
+ end
end
function gModule:updateOutput(input)