From 94199eb2fd540ffa4a51964ae006ae8090baf6c6 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 24 Feb 2016 22:33:37 +0000 Subject: Clear tensors in a whole graph on :clearState() --- gmodule.lua | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/gmodule.lua b/gmodule.lua index 2556b15..98bc196 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -117,7 +117,7 @@ function gModule:__init(inputs,outputs) -- computation on the graph is done through topsort of forward and backward graphs self.forwardnodes = self.fg:topsort() self.backwardnodes = self.bg:topsort() - + -- iteratare over all nodes: check, tag and add to container for i,node in ipairs(self.forwardnodes) do -- check for unused inputs or unused split() outputs @@ -227,6 +227,18 @@ function gModule:updateOutput(input) return self:runForwardFunction('updateOutput',input) end +function gModule:clearState() + local ret = parent.clearState(self) + for _,node in ipairs(self.backwardnodes) do + node.data.gradOutput = nil + node.data.gradOutputBuffer = nil + end + for _,node in ipairs(self.forwardnodes) do + node.data.input = nil + end + return ret +end + function gModule:runForwardFunction(func,input) if type(func) == "string" then local func_name = func -- cgit v1.2.3