diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2016-02-25 01:33:37 +0300 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2016-03-03 21:33:29 +0300 |
commit | 94199eb2fd540ffa4a51964ae006ae8090baf6c6 (patch) | |
tree | 3818df1db9ea8054dd258dcff945fb90c375d0bc | |
parent | ccc9627a95972eca32915100ceddddcfe6e87f43 (diff) |
Clear tensors in a whole graph on :clearState()
-rw-r--r-- | gmodule.lua | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/gmodule.lua b/gmodule.lua index 2556b15..98bc196 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -117,7 +117,7 @@ function gModule:__init(inputs,outputs) -- computation on the graph is done through topsort of forward and backward graphs self.forwardnodes = self.fg:topsort() self.backwardnodes = self.bg:topsort() - + -- iteratare over all nodes: check, tag and add to container for i,node in ipairs(self.forwardnodes) do -- check for unused inputs or unused split() outputs @@ -227,6 +227,18 @@ function gModule:updateOutput(input) return self:runForwardFunction('updateOutput',input) end +function gModule:clearState() + local ret = parent.clearState(self) + for _,node in ipairs(self.backwardnodes) do + node.data.gradOutput = nil + node.data.gradOutputBuffer = nil + end + for _,node in ipairs(self.forwardnodes) do + node.data.input = nil + end + return ret +end + function gModule:runForwardFunction(func,input) if type(func) == "string" then local func_name = func |