Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-03-04 00:01:17 +0300
committerSoumith Chintala <soumith@gmail.com>2016-03-04 00:01:17 +0300
commitdb81121991c7b41d927f103425c500070fb9c708 (patch)
tree1011355ccbd797b82ebb9f6a08a80059d6f64567
parentcf016fe034125e67fb2a53efad2f8e85a6a4a3a0 (diff)
parent94199eb2fd540ffa4a51964ae006ae8090baf6c6 (diff)
Merge pull request #104 from apaszke/master
Clear tensors in a whole graph on :clearState()
-rw-r--r--gmodule.lua14
1 files changed, 13 insertions, 1 deletions
diff --git a/gmodule.lua b/gmodule.lua
index 2556b15..98bc196 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -117,7 +117,7 @@ function gModule:__init(inputs,outputs)
-- computation on the graph is done through topsort of forward and backward graphs
self.forwardnodes = self.fg:topsort()
self.backwardnodes = self.bg:topsort()
-
+
-- iteratare over all nodes: check, tag and add to container
for i,node in ipairs(self.forwardnodes) do
-- check for unused inputs or unused split() outputs
@@ -227,6 +227,18 @@ function gModule:updateOutput(input)
return self:runForwardFunction('updateOutput',input)
end
+function gModule:clearState()
+ local ret = parent.clearState(self)
+ for _,node in ipairs(self.backwardnodes) do
+ node.data.gradOutput = nil
+ node.data.gradOutputBuffer = nil
+ end
+ for _,node in ipairs(self.forwardnodes) do
+ node.data.input = nil
+ end
+ return ret
+end
+
function gModule:runForwardFunction(func,input)
if type(func) == "string" then
local func_name = func