diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-03-04 00:01:17 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-03-04 00:01:17 +0300 |
commit | db81121991c7b41d927f103425c500070fb9c708 (patch) | |
tree | 1011355ccbd797b82ebb9f6a08a80059d6f64567 | |
parent | cf016fe034125e67fb2a53efad2f8e85a6a4a3a0 (diff) | |
parent | 94199eb2fd540ffa4a51964ae006ae8090baf6c6 (diff) |
Merge pull request #104 from apaszke/master
Clear tensors in a whole graph on :clearState()
-rw-r--r-- | gmodule.lua | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/gmodule.lua b/gmodule.lua index 2556b15..98bc196 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -117,7 +117,7 @@ function gModule:__init(inputs,outputs) -- computation on the graph is done through topsort of forward and backward graphs self.forwardnodes = self.fg:topsort() self.backwardnodes = self.bg:topsort() - + -- iteratare over all nodes: check, tag and add to container for i,node in ipairs(self.forwardnodes) do -- check for unused inputs or unused split() outputs @@ -227,6 +227,18 @@ function gModule:updateOutput(input) return self:runForwardFunction('updateOutput',input) end +function gModule:clearState() + local ret = parent.clearState(self) + for _,node in ipairs(self.backwardnodes) do + node.data.gradOutput = nil + node.data.gradOutputBuffer = nil + end + for _,node in ipairs(self.forwardnodes) do + node.data.input = nil + end + return ret +end + function gModule:runForwardFunction(func,input) if type(func) == "string" then local func_name = func |