author | Clement Farabet <cfarabet@twitter.com> | 2015-09-04 23:31:01 +0300
---|---|---
committer | Clement Farabet <cfarabet@twitter.com> | 2015-09-04 23:31:01 +0300
commit | ba7f3ec6ffe7e60e5e07b2886178aec54e6305e5 (patch) |
tree | 23aaf8f288fdce23b7361a1e7252254a690a6fb9 |
parent | 72f74d39257a344a2c3237c83f8a828b916817e4 (diff) |
Whitespace cleanup.
-rw-r--r-- | gmodule.lua | 624
-rw-r--r-- | graphinspecting.lua | 212
-rw-r--r-- | init.lua | 44
-rw-r--r-- | nesting.lua | 83
-rw-r--r-- | node.lua | 235
-rw-r--r-- | simple_print.lua | 197
-rw-r--r-- | test/speed.lua | 169
-rw-r--r-- | test/test_ModuleFromCriterion.lua | 68
-rw-r--r-- | test/test_nngraph.lua | 696
-rw-r--r-- | test/test_old.lua | 412
-rw-r--r-- | utils.lua | 14
11 files changed, 1370 insertions, 1384 deletions
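The per-file hunks below are indentation-only; behavior is unchanged. For orientation, here is a minimal editorial sketch (arbitrary layer sizes, mirroring `test.test_example1` further down) of the graph-building API that gmodule.lua, node.lua, and init.lua implement:

```lua
require 'nngraph'

-- Module.__call__ (init.lua) turns `module(node)` into graph wiring;
-- nn.gModule (gmodule.lua) then topsorts the graph for forward/backward.
local x   = nn.Identity()()                  -- input node
local h   = nn.Tanh()(nn.Linear(20, 10)(x))  -- hidden layer wired onto x
local y   = nn.Linear(10, 1)(h)              -- output node
local net = nn.gModule({x}, {y})

local out  = net:forward(torch.rand(20))                  -- walks forwardnodes
local grad = net:backward(torch.rand(20), torch.rand(1))  -- walks backwardnodes
```

On the backward pass, `getTotalGradOutput` sums the per-consumer gradients of any node whose output is used more than once, before `updateGradInput` runs on that node.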
diff --git a/gmodule.lua b/gmodule.lua index 6489f61..79f3033 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -6,21 +6,21 @@ local istable = utils.istable local istorchclass = utils.istorchclass local function getTotalGradOutput(node) - local gradOutput = node.data.gradOutput - assert(istable(gradOutput), "expecting gradients to sum") - if #gradOutput > 1 then - node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1]) - local gobuff = node.data.gradOutputBuffer - nesting.resizeNestedAs(gobuff, gradOutput[1]) - nesting.fillNested(gobuff, 0) - for i=1,#gradOutput do - nesting.addNestedTo(gobuff, gradOutput[i]) - end - gradOutput = gobuff - else - gradOutput = gradOutput[1] - end - return gradOutput + local gradOutput = node.data.gradOutput + assert(istable(gradOutput), "expecting gradients to sum") + if #gradOutput > 1 then + node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1]) + local gobuff = node.data.gradOutputBuffer + nesting.resizeNestedAs(gobuff, gradOutput[1]) + nesting.fillNested(gobuff, 0) + for i=1,#gradOutput do + nesting.addNestedTo(gobuff, gradOutput[i]) + end + gradOutput = gobuff + else + gradOutput = gradOutput[1] + end + return gradOutput end -- The gModule allows to have a general non-cyclic graph of of modules. @@ -43,84 +43,84 @@ end local gModule, parent = torch.class('nn.gModule','nn.Module') function gModule:__init(inputs,outputs) - parent.__init(self) - -- the graph is defined backwards, we have the output modules as input here - -- we will define a dummy output node that connects all output modules - -- into itself. This will be the output for the forward graph and - -- input point for the backward graph - local outnode = nngraph.Node({input={}}) - for i,n in ipairs(outputs) do - if torch.typename(n) ~= 'nngraph.Node' then - error(string.format('what is this in the outputs[%s]? %s', - i, tostring(n))) - end - outnode:add(n,true) - end - for i,n in ipairs(inputs) do - if torch.typename(n) ~= 'nngraph.Node' then - error(string.format('what is this in the inputs[%s]? %s', - i, tostring(n))) - end - end - -- We add also a dummy input node. - -- The input node will be split to feed the passed input nodes. - local innode = nngraph.Node({input={}}) - assert(#inputs > 0, "no inputs are not supported") - if #inputs == 1 then - inputs[1]:add(innode,true) - else - local splits = {innode:split(#inputs)} - for i = 1, #inputs do - assert(#inputs[i].children == 0, "an input should have no inputs") - end - for i = 1, #inputs do - inputs[i]:add(splits[i],true) - end - end - - -- the backward graph (bg) is for gradients - -- the forward graph (fg) is for function evaluation - self.bg = outnode:graph() - self.fg = self.bg:reverse() - - -- the complete graph is constructed - -- now regenerate the graphs with the additional nodes - assert(#self.fg:roots() == 1, "expecting only one start") - self.innode = self.fg:roots()[1] - assert(self.innode.data == innode.data, "expecting the forward innode") - self.outnode = outnode - self.verbose = false - self.nInputs = #inputs - - -- computation on the graph is done through topsort of forward and backward graphs - self.forwardnodes = self.fg:topsort() - self.backwardnodes = self.bg:topsort() - -- Checking for unused inputs or unused split() outputs. 
- for i,forwardNode in ipairs(self.forwardnodes) do - if forwardNode.data.nSplitOutputs and forwardNode.data.nSplitOutputs ~= #forwardNode.children then - local nUnused = forwardNode.data.nSplitOutputs - #forwardNode.children - error(string.format("%s of split(%s) outputs are unused", nUnused, - forwardNode.data.nSplitOutputs)) - end - end - -- Adding data.forwardNodeId for nicer node:label() output. - for i,forwardNode in ipairs(self.forwardnodes) do - forwardNode.data.forwardNodeId = forwardNode.id - end - - self.output = nil - self.gradInput = nil - if #self.outnode.children > 1 then - self.output = self.outnode.data.input - end + parent.__init(self) + -- the graph is defined backwards, we have the output modules as input here + -- we will define a dummy output node that connects all output modules + -- into itself. This will be the output for the forward graph and + -- input point for the backward graph + local outnode = nngraph.Node({input={}}) + for i,n in ipairs(outputs) do + if torch.typename(n) ~= 'nngraph.Node' then + error(string.format('what is this in the outputs[%s]? %s', + i, tostring(n))) + end + outnode:add(n,true) + end + for i,n in ipairs(inputs) do + if torch.typename(n) ~= 'nngraph.Node' then + error(string.format('what is this in the inputs[%s]? %s', + i, tostring(n))) + end + end + -- We add also a dummy input node. + -- The input node will be split to feed the passed input nodes. + local innode = nngraph.Node({input={}}) + assert(#inputs > 0, "no inputs are not supported") + if #inputs == 1 then + inputs[1]:add(innode,true) + else + local splits = {innode:split(#inputs)} + for i = 1, #inputs do + assert(#inputs[i].children == 0, "an input should have no inputs") + end + for i = 1, #inputs do + inputs[i]:add(splits[i],true) + end + end + + -- the backward graph (bg) is for gradients + -- the forward graph (fg) is for function evaluation + self.bg = outnode:graph() + self.fg = self.bg:reverse() + + -- the complete graph is constructed + -- now regenerate the graphs with the additional nodes + assert(#self.fg:roots() == 1, "expecting only one start") + self.innode = self.fg:roots()[1] + assert(self.innode.data == innode.data, "expecting the forward innode") + self.outnode = outnode + self.verbose = false + self.nInputs = #inputs + + -- computation on the graph is done through topsort of forward and backward graphs + self.forwardnodes = self.fg:topsort() + self.backwardnodes = self.bg:topsort() + -- Checking for unused inputs or unused split() outputs. + for i,forwardNode in ipairs(self.forwardnodes) do + if forwardNode.data.nSplitOutputs and forwardNode.data.nSplitOutputs ~= #forwardNode.children then + local nUnused = forwardNode.data.nSplitOutputs - #forwardNode.children + error(string.format("%s of split(%s) outputs are unused", nUnused, + forwardNode.data.nSplitOutputs)) + end + end + -- Adding data.forwardNodeId for nicer node:label() output. + for i,forwardNode in ipairs(self.forwardnodes) do + forwardNode.data.forwardNodeId = forwardNode.id + end + + self.output = nil + self.gradInput = nil + if #self.outnode.children > 1 then + self.output = self.outnode.data.input + end end function gModule:apply(func) - for i,node in ipairs(self.forwardnodes) do - if node.data.module then - func(node.data.module) - end - end + for i,node in ipairs(self.forwardnodes) do + if node.data.module then + func(node.data.module) + end + end end function gModule:map(gm, func) @@ -149,266 +149,266 @@ end function gModule:share(gm, ...) 
local args = {...} self:map(gm, - function(subnet1, subnet2) - subnet1:share(subnet2, unpack(args)) + function(subnet1, subnet2) + subnet1:share(subnet2, unpack(args)) end) return self end function gModule:training() - self:apply(function(module) module:training() end) + self:apply(function(module) module:training() end) end function gModule:evaluate() - self:apply(function(module) module:evaluate() end) + self:apply(function(module) module:evaluate() end) end --[[ Recursively applies type(type_str) to any tensors in the argument. If the argument is a tensor, type(type_str) is applied; if the argument is an array, this function recurses into it. ]] local function recursiveType(param, type_str) - if torch.type(param) == 'table' then - for i = 1, #param do - param[i] = recursiveType(param[i], type_str) - end - elseif torch.typename(param) and - torch.typename(param):find('torch%..+Tensor') then - param = param:type(type_str) - end - return param + if torch.type(param) == 'table' then + for i = 1, #param do + param[i] = recursiveType(param[i], type_str) + end + elseif torch.typename(param) and + torch.typename(param):find('torch%..+Tensor') then + param = param:type(type_str) + end + return param end function gModule:type(type) - local function applyTypeToTable(table) - for key, value in pairs(table) do - table[key] = recursiveType(table[key], type) - end - end - - -- Convert any stored data in self, and in the in and out nodes - applyTypeToTable(self) - if self.innode then applyTypeToTable(self.innode.data) end - if self.outnode then applyTypeToTable(self.outnode.data) end - - -- Loop through modules and convert data - self:apply(function(module) module:type(type) end) - - for i,node in ipairs(self.backwardnodes) do - if node.data.gradOutputBuffer ~= nil then - node.data.gradOutputBuffer = node.data.gradOutputBuffer:type(type) - end - end - - return self + local function applyTypeToTable(table) + for key, value in pairs(table) do + table[key] = recursiveType(table[key], type) + end + end + + -- Convert any stored data in self, and in the in and out nodes + applyTypeToTable(self) + if self.innode then applyTypeToTable(self.innode.data) end + if self.outnode then applyTypeToTable(self.outnode.data) end + + -- Loop through modules and convert data + self:apply(function(module) module:type(type) end) + + for i,node in ipairs(self.backwardnodes) do + if node.data.gradOutputBuffer ~= nil then + node.data.gradOutputBuffer = node.data.gradOutputBuffer:type(type) + end + end + + return self end function gModule:zeroGradParameters() - self:apply(function(module) module:zeroGradParameters() end) + self:apply(function(module) module:zeroGradParameters() end) end function gModule:updateOutput(input) - return self:runForwardFunction('updateOutput',input) + return self:runForwardFunction('updateOutput',input) end function gModule:runForwardFunction(func,input) - if type(func) == "string" then - local func_name = func - func = function(module,input) return module[func_name](module,input) end - end - -- For backward compatibility, we allow self.nInputs to be missing. - local nInputs = self.nInputs or #self.innode.children - -- We see the input as a list of inputs. 
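-- (Editorial note, not in the original diff) The normalization below means a
-- single-input gModule accepts a bare tensor, while multi-input gModules
-- require a table, e.g. net:forward{x1, x2}; anything else raises
-- "expecting %s inputs".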
- if nInputs <= 1 then - input={input} - elseif type(input) ~= "table" then - error(string.format("expecting %s inputs", nInputs)) - end - local function neteval(node) - local function propagate(node,x) - for i,child in ipairs(node.children) do - child.data.input = child.data.input or {} - local mapindex = child.data.mapindex[node.data] - assert(not child.data.input[mapindex], "each input should have one source") - child.data.input[mapindex] = x - end - end - if node.data.selectindex then - assert(not node.data.module, "the selectindex-handling nodes should have no module") - local input = node.data.input - assert(#input == 1, "only the splitted node should be the input") - assert(istable(input[1]), "the input for a split should be a table") - input = input[1][node.data.selectindex] - propagate(node,input) - else - local input = node.data.input - if #input == 1 then - input = input[1] - end - -- forward through this node - -- If no module is present, the node behaves like nn.Identity. - local output - if not node.data.module then - output = input - else - output = func(node.data.module,input) - end - if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #output then - error(string.format("split(%s) cannot split %s outputs", - node.data.nSplitOutputs, - #output)) - end - -- propagate the output to children - propagate(node,output) - end - if self.verbose then - print(' V : ' .. node:label()) - end - end - - local innode = self.innode - if #input ~= nInputs then - error(string.format('Got %s inputs instead of %s', #input, nInputs)) - end - -- first clear the input states - for _,node in ipairs(self.forwardnodes) do - local input = node.data.input - while input and #input>0 do - table.remove(input) - end - end - -- Set the starting input. - -- We do copy instead of modifying the passed input. - innode.data.input = innode.data.input or {} - for i, item in ipairs(input) do - innode.data.input[i] = item - end - - -- the run forward - for i,node in ipairs(self.forwardnodes) do - neteval(node) - end - - self.output = self.outnode.data.input - if #self.outnode.children == 1 then - self.output = self.output[1] - end - return self.output + if type(func) == "string" then + local func_name = func + func = function(module,input) return module[func_name](module,input) end + end + -- For backward compatibility, we allow self.nInputs to be missing. + local nInputs = self.nInputs or #self.innode.children + -- We see the input as a list of inputs. + if nInputs <= 1 then + input={input} + elseif type(input) ~= "table" then + error(string.format("expecting %s inputs", nInputs)) + end + local function neteval(node) + local function propagate(node,x) + for i,child in ipairs(node.children) do + child.data.input = child.data.input or {} + local mapindex = child.data.mapindex[node.data] + assert(not child.data.input[mapindex], "each input should have one source") + child.data.input[mapindex] = x + end + end + if node.data.selectindex then + assert(not node.data.module, "the selectindex-handling nodes should have no module") + local input = node.data.input + assert(#input == 1, "only the splitted node should be the input") + assert(istable(input[1]), "the input for a split should be a table") + input = input[1][node.data.selectindex] + propagate(node,input) + else + local input = node.data.input + if #input == 1 then + input = input[1] + end + -- forward through this node + -- If no module is present, the node behaves like nn.Identity. 
+ local output + if not node.data.module then + output = input + else + output = func(node.data.module,input) + end + if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #output then + error(string.format("split(%s) cannot split %s outputs", + node.data.nSplitOutputs, + #output)) + end + -- propagate the output to children + propagate(node,output) + end + if self.verbose then + print(' V : ' .. node:label()) + end + end + + local innode = self.innode + if #input ~= nInputs then + error(string.format('Got %s inputs instead of %s', #input, nInputs)) + end + -- first clear the input states + for _,node in ipairs(self.forwardnodes) do + local input = node.data.input + while input and #input>0 do + table.remove(input) + end + end + -- Set the starting input. + -- We do copy instead of modifying the passed input. + innode.data.input = innode.data.input or {} + for i, item in ipairs(input) do + innode.data.input[i] = item + end + + -- the run forward + for i,node in ipairs(self.forwardnodes) do + neteval(node) + end + + self.output = self.outnode.data.input + if #self.outnode.children == 1 then + self.output = self.output[1] + end + return self.output end function gModule:updateGradInput(input,gradOutput) - local function neteval(node) - if node.data.selectindex then - assert(not node.data.module, "the selectindex-handling nodes should have no module") - assert(#node.children == 1, "only the splitted node should be the input") - local child = node.children[1] - local go = getTotalGradOutput(node) - child.data.gradOutput = child.data.gradOutput or {} - assert(#child.data.gradOutput <= 1, "the splitted node should be used only once") - -- The data.gradOutput holds the to-be-summed gradients. - child.data.gradOutput[1] = child.data.gradOutput[1] or {} - assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet") - child.data.gradOutput[1][node.data.selectindex] = go - else - local gradOutput = getTotalGradOutput(node) - -- updateGradInput through this node - -- If no module is present, the node behaves like nn.Identity. - local gradInput - if not node.data.module then - gradInput = gradOutput - else - local input = node.data.input - if #input == 1 then - input = input[1] - end - local module = node.data.module - gradInput = module:updateGradInput(input,gradOutput) - end - -- propagate the output to children - for i,child in ipairs(node.children) do - child.data.gradOutput = child.data.gradOutput or {} - local mapindex = node.data.mapindex[child.data] - local gi - if #node.children == 1 then - gi = gradInput - else - gi = gradInput[mapindex] - end - table.insert(child.data.gradOutput,gi) - end - end - if self.verbose then - print(' V : ' .. node:label()) - end - end - local outnode = self.outnode - if #outnode.children > 1 and #gradOutput ~= #outnode.children then - error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children)) - end - for _,node in ipairs(self.backwardnodes) do - local gradOutput = node.data.gradOutput - while gradOutput and #gradOutput >0 do - table.remove(gradOutput) - end - end - -- Set the starting gradOutput. 
- outnode.data.gradOutput = outnode.data.gradOutput or {} - outnode.data.gradOutput[1] = gradOutput - - for i,node in ipairs(self.backwardnodes) do - neteval(node) - end - - assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once") - self.gradInput = self.innode.data.gradOutput[1] - return self.gradInput + local function neteval(node) + if node.data.selectindex then + assert(not node.data.module, "the selectindex-handling nodes should have no module") + assert(#node.children == 1, "only the splitted node should be the input") + local child = node.children[1] + local go = getTotalGradOutput(node) + child.data.gradOutput = child.data.gradOutput or {} + assert(#child.data.gradOutput <= 1, "the splitted node should be used only once") + -- The data.gradOutput holds the to-be-summed gradients. + child.data.gradOutput[1] = child.data.gradOutput[1] or {} + assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet") + child.data.gradOutput[1][node.data.selectindex] = go + else + local gradOutput = getTotalGradOutput(node) + -- updateGradInput through this node + -- If no module is present, the node behaves like nn.Identity. + local gradInput + if not node.data.module then + gradInput = gradOutput + else + local input = node.data.input + if #input == 1 then + input = input[1] + end + local module = node.data.module + gradInput = module:updateGradInput(input,gradOutput) + end + -- propagate the output to children + for i,child in ipairs(node.children) do + child.data.gradOutput = child.data.gradOutput or {} + local mapindex = node.data.mapindex[child.data] + local gi + if #node.children == 1 then + gi = gradInput + else + gi = gradInput[mapindex] + end + table.insert(child.data.gradOutput,gi) + end + end + if self.verbose then + print(' V : ' .. node:label()) + end + end + local outnode = self.outnode + if #outnode.children > 1 and #gradOutput ~= #outnode.children then + error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children)) + end + for _,node in ipairs(self.backwardnodes) do + local gradOutput = node.data.gradOutput + while gradOutput and #gradOutput >0 do + table.remove(gradOutput) + end + end + -- Set the starting gradOutput. + outnode.data.gradOutput = outnode.data.gradOutput or {} + outnode.data.gradOutput[1] = gradOutput + + for i,node in ipairs(self.backwardnodes) do + neteval(node) + end + + assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once") + self.gradInput = self.innode.data.gradOutput[1] + return self.gradInput end function gModule:accGradParameters(input,gradOutput,lr) - local function neteval(node) - if node.data.module then - local module = node.data.module - local gradOutput = node.data.gradOutput[1] - if #node.data.gradOutput > 1 then - gradOutput = node.data.gradOutputBuffer - end - local input = node.data.input - if #input == 1 then - input = input[1] - end - -- accGradParameters through this node - module:accGradParameters(input,gradOutput,lr) - end - if self.verbose then - print(' V : ' .. 
node:label()) - end - end - local outnode = self.outnode - if #outnode.children > 1 and #gradOutput ~= #outnode.children then - error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children)) - end - for i,node in ipairs(self.backwardnodes) do - neteval(node) - end + local function neteval(node) + if node.data.module then + local module = node.data.module + local gradOutput = node.data.gradOutput[1] + if #node.data.gradOutput > 1 then + gradOutput = node.data.gradOutputBuffer + end + local input = node.data.input + if #input == 1 then + input = input[1] + end + -- accGradParameters through this node + module:accGradParameters(input,gradOutput,lr) + end + if self.verbose then + print(' V : ' .. node:label()) + end + end + local outnode = self.outnode + if #outnode.children > 1 and #gradOutput ~= #outnode.children then + error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children)) + end + for i,node in ipairs(self.backwardnodes) do + neteval(node) + end end function gModule:parameters() - local p,gp = {},{} - for _,node in ipairs(self.forwardnodes) do - if node.data.module then - local mp,mgp = node.data.module:parameters() - if mp and mgp then - for i = 1,#mp do - table.insert(p,mp[i]) - table.insert(gp,mgp[i]) - end - end - end - end - return p,gp + local p,gp = {},{} + for _,node in ipairs(self.forwardnodes) do + if node.data.module then + local mp,mgp = node.data.module:parameters() + if mp and mgp then + for i = 1,#mp do + table.insert(p,mp[i]) + table.insert(gp,mgp[i]) + end + end + end + end + return p,gp end function gModule:__tostring__() - return self.name or torch.type(self) + return self.name or torch.type(self) end diff --git a/graphinspecting.lua b/graphinspecting.lua index 7c8d462..0ccb168 100644 --- a/graphinspecting.lua +++ b/graphinspecting.lua @@ -2,98 +2,98 @@ -- The findCurrentNode() depends on the names of the -- local variables in the nngraph.gModule source code. local function findCurrentNode() - for level = 2, math.huge do - local info = debug.getinfo(level, "n") - if info == nil then - return nil - end - - local funcName = info.name - if funcName == "neteval" then - local varName, node = debug.getlocal(level, 1) - if varName == "node" then - return node - end - end - end + for level = 2, math.huge do + local info = debug.getinfo(level, "n") + if info == nil then + return nil + end + + local funcName = info.name + if funcName == "neteval" then + local varName, node = debug.getlocal(level, 1) + if varName == "node" then + return node + end + end + end end -- Runs the func and calls onError(failedNode, ...) on an error. -- The stack trace is inspected to find the failedNode. local function runChecked(func, onError, ...) - -- The current node needs to be searched-for, before unrolling the stack. - local failedNode - local function errorHandler(message) - -- The stack traceback is added only if not already present. - if not string.find(message, 'stack traceback:\n', 1, true) then - message = debug.traceback(message, 2) - end - failedNode = findCurrentNode() - return message - end - - local ok, result = xpcall(func, errorHandler) - if ok then - return result - end - - onError(failedNode, ...) - -- Passing the level 0 avoids adding an additional error position info - -- to the message. - error(result, 0) + -- The current node needs to be searched-for, before unrolling the stack. + local failedNode + local function errorHandler(message) + -- The stack traceback is added only if not already present. 
+ if not string.find(message, 'stack traceback:\n', 1, true) then + message = debug.traceback(message, 2) + end + failedNode = findCurrentNode() + return message + end + + local ok, result = xpcall(func, errorHandler) + if ok then + return result + end + + onError(failedNode, ...) + -- Passing the level 0 avoids adding an additional error position info + -- to the message. + error(result, 0) end local function customToDot(graph, title, failedNode) - local str = graph:todot(title) - if not failedNode then - return str - end - - local failedNodeId = nil - for i, node in ipairs(graph.nodes) do - if node.data == failedNode.data then - failedNodeId = node.id - break - end - end - - if failedNodeId ~= nil then - -- The closing '}' is removed. - -- And red fillcolor is specified for the failedNode. - str = string.gsub(str, '}%s*$', '') - str = str .. string.format('n%s[style=filled, fillcolor=red];\n}', - failedNodeId) - end - return str + local str = graph:todot(title) + if not failedNode then + return str + end + + local failedNodeId = nil + for i, node in ipairs(graph.nodes) do + if node.data == failedNode.data then + failedNodeId = node.id + break + end + end + + if failedNodeId ~= nil then + -- The closing '}' is removed. + -- And red fillcolor is specified for the failedNode. + str = string.gsub(str, '}%s*$', '') + str = str .. string.format('n%s[style=filled, fillcolor=red];\n}', + failedNodeId) + end + return str end local function saveSvg(svgPathPrefix, dotStr) - io.stderr:write(string.format("saving %s.svg\n", svgPathPrefix)) - local dotPath = svgPathPrefix .. '.dot' - local dotFile = io.open(dotPath, 'w') - dotFile:write(dotStr) - dotFile:close() - - local svgPath = svgPathPrefix .. '.svg' - local cmd = string.format('dot -Tsvg -o %s %s', svgPath, dotPath) - os.execute(cmd) + io.stderr:write(string.format("saving %s.svg\n", svgPathPrefix)) + local dotPath = svgPathPrefix .. '.dot' + local dotFile = io.open(dotPath, 'w') + dotFile:write(dotStr) + dotFile:close() + + local svgPath = svgPathPrefix .. '.svg' + local cmd = string.format('dot -Tsvg -o %s %s', svgPath, dotPath) + os.execute(cmd) end local function onError(failedNode, gmodule) - local nInputs = gmodule.nInputs or #gmodule.innode.children - local svgPathPrefix = gmodule.name or string.format( - 'nngraph_%sin_%sout', nInputs, #gmodule.outnode.children) - if paths.filep(svgPathPrefix .. '.svg') then - svgPathPrefix = svgPathPrefix .. '_' .. paths.basename(os.tmpname()) - end - local dotStr = customToDot(gmodule.fg, svgPathPrefix, failedNode) - saveSvg(svgPathPrefix, dotStr) + local nInputs = gmodule.nInputs or #gmodule.innode.children + local svgPathPrefix = gmodule.name or string.format( + 'nngraph_%sin_%sout', nInputs, #gmodule.outnode.children) + if paths.filep(svgPathPrefix .. '.svg') then + svgPathPrefix = svgPathPrefix .. '_' .. paths.basename(os.tmpname()) + end + local dotStr = customToDot(gmodule.fg, svgPathPrefix, failedNode) + saveSvg(svgPathPrefix, dotStr) end local origFuncs = { - runForwardFunction = nn.gModule.runForwardFunction, - updateGradInput = nn.gModule.updateGradInput, - accGradParameters = nn.gModule.accGradParameters, + runForwardFunction = nn.gModule.runForwardFunction, + updateGradInput = nn.gModule.updateGradInput, + accGradParameters = nn.gModule.accGradParameters, } -- When debug is enabled, @@ -101,42 +101,42 @@ local origFuncs = { -- if an exception occurs in a graph execution. -- The problematic node will be marked by red color. 
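-- (Editorial sketch, not in the original diff) Typical use of the hook below:
--   nngraph.setDebug(true)
--   local ok, err = pcall(function() return net:forward(badInput) end)
-- On failure, onError() above writes <name>.dot/<name>.svg (name defaults to
-- 'nngraph_<nIn>in_<nOut>out') with the failing node filled red.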
function nngraph.setDebug(enable) - if not enable then - -- When debug is disabled, - -- the origFuncs are restored on nn.gModule. - for funcName, origFunc in pairs(origFuncs) do - nn.gModule[funcName] = origFunc - end - return - end - - for funcName, origFunc in pairs(origFuncs) do - nn.gModule[funcName] = function(...) - local args = {...} - local gmodule = args[1] - return runChecked(function() - return origFunc(unpack(args)) - end, onError, gmodule) - end - end + if not enable then + -- When debug is disabled, + -- the origFuncs are restored on nn.gModule. + for funcName, origFunc in pairs(origFuncs) do + nn.gModule[funcName] = origFunc + end + return + end + + for funcName, origFunc in pairs(origFuncs) do + nn.gModule[funcName] = function(...) + local args = {...} + local gmodule = args[1] + return runChecked(function() + return origFunc(unpack(args)) + end, onError, gmodule) + end + end end -- Sets node.data.annotations.name for the found nodes. -- The local variables at the given stack level are inspected. -- The default stack level is 1 (the function that called annotateNodes()). function nngraph.annotateNodes(stackLevel) - stackLevel = stackLevel or 1 - for index = 1, math.huge do - local varName, varValue = debug.getlocal(stackLevel + 1, index) - if not varName then - break - end - if torch.typename(varValue) == "nngraph.Node" then - -- An explicit name is preserved. - if not varValue.data.annotations.name then - varValue:annotate({name = varName}) - end - end - end + stackLevel = stackLevel or 1 + for index = 1, math.huge do + local varName, varValue = debug.getlocal(stackLevel + 1, index) + if not varName then + break + end + if torch.typename(varValue) == "nngraph.Node" then + -- An explicit name is preserved. + if not varValue.data.annotations.name then + varValue:annotate({name = varName}) + end + end + end end @@ -1,4 +1,3 @@ - require 'nn' require 'graph' @@ -18,33 +17,32 @@ local istorchclass = utils.istorchclass -- simpler todot functions nngraph.simple_print = paths.dofile('simple_print.lua') - -- Modify the __call function to hack into nn.Module local Module = torch.getmetatable('nn.Module') function Module:__call__(...) - local nArgs = select("#", ...) - assert(nArgs <= 1, 'Use {input1, input2} to pass multiple inputs.') - - local input = ... - if nArgs == 1 and input == nil then - error('what is this in the input? nil') - end - if not istable(input) then - input = {input} - end - local mnode = nngraph.Node({module=self}) - - for i,dnode in ipairs(input) do - if torch.typename(dnode) ~= 'nngraph.Node' then - error('what is this in the input? ' .. tostring(dnode)) - end - mnode:add(dnode,true) - end - - return mnode + local nArgs = select("#", ...) + assert(nArgs <= 1, 'Use {input1, input2} to pass multiple inputs.') + + local input = ... + if nArgs == 1 and input == nil then + error('what is this in the input? nil') + end + if not istable(input) then + input = {input} + end + local mnode = nngraph.Node({module=self}) + + for i,dnode in ipairs(input) do + if torch.typename(dnode) ~= 'nngraph.Node' then + error('what is this in the input? ' .. tostring(dnode)) + end + mnode:add(dnode,true) + end + + return mnode end local Criterion = torch.getmetatable('nn.Criterion') function Criterion:__call__(...) - return nn.ModuleFromCriterion(self)(...) + return nn.ModuleFromCriterion(self)(...) 
end diff --git a/nesting.lua b/nesting.lua index ab63e70..0a61b36 100644 --- a/nesting.lua +++ b/nesting.lua @@ -4,66 +4,65 @@ local nesting = {} local utils = paths.dofile('utils.lua') local istensor = utils.istensor - -- Creates a clone of a tensor or of a table with tensors. function nesting.cloneNested(obj) - if istensor(obj) then - return obj:clone() - end + if istensor(obj) then + return obj:clone() + end - local result = {} - for key, child in pairs(obj) do - result[key] = nesting.cloneNested(child) - end - return result + local result = {} + for key, child in pairs(obj) do + result[key] = nesting.cloneNested(child) + end + return result end -- Fills the obj with the given value. -- The obj can be a tensor or a table with tensors. function nesting.fillNested(obj, value) - if istensor(obj) then - obj:fill(value) - else - for key, child in pairs(obj) do - nesting.fillNested(child, value) - end - end + if istensor(obj) then + obj:fill(value) + else + for key, child in pairs(obj) do + nesting.fillNested(child, value) + end + end end -- Resizes all tensors in the output. function nesting.resizeNestedAs(output, input) - if istensor(output) then - output:resizeAs(input) - else - for key, child in pairs(input) do - -- A new element is added to the output, if needed. - if not output[key] then - output[key] = nesting.cloneNested(child) - else - nesting.resizeNestedAs(output[key], child) - end - end - -- Extra elements are removed from the output. - for key, child in pairs(output) do - if not input[key] then - output[key] = nil - end - end - end + if istensor(output) then + output:resizeAs(input) + else + for key, child in pairs(input) do + -- A new element is added to the output, if needed. + if not output[key] then + output[key] = nesting.cloneNested(child) + else + nesting.resizeNestedAs(output[key], child) + end + end + -- Extra elements are removed from the output. + for key, child in pairs(output) do + if not input[key] then + output[key] = nil + end + end + end end -- Adds the input to the output. -- The input can contain nested tables. -- The output will contain the same nesting of tables. 
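-- (Editorial example, not in the original diff) With matching nesting,
-- accumulation is in place:
--   local out = {torch.zeros(3), {torch.zeros(2)}}
--   local inp = {torch.ones(3),  {torch.ones(2)}}
--   nesting.addNestedTo(out, inp) -- out[1] and out[2][1] are now all ones
-- gmodule.lua's getTotalGradOutput relies on this to sum multiple gradOutputs.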
function nesting.addNestedTo(output, input) - if istensor(output) then - output:add(input) - else - for key, child in pairs(input) do - assert(output[key] ~= nil, "missing key") - nesting.addNestedTo(output[key], child) - end - end + if istensor(output) then + output:add(input) + else + for key, child in pairs(input) do + assert(output[key] ~= nil, "missing key") + nesting.addNestedTo(output[key], child) + end + end end return nesting @@ -5,162 +5,155 @@ local istable = utils.istable local istorchclass = utils.istorchclass require 'debug' - local nnNode,parent = torch.class('nngraph.Node','graph.Node') function nnNode:__init(data) - parent.__init(self,data) - self.data.annotations = self.data.annotations or {} - self.data.mapindex = self.data.mapindex or {} - if not self.data.annotations._debugLabel then - self:_makeDebugLabel(debug.getinfo(6, 'Sl')) - end + parent.__init(self,data) + self.data.annotations = self.data.annotations or {} + self.data.mapindex = self.data.mapindex or {} + if not self.data.annotations._debugLabel then + self:_makeDebugLabel(debug.getinfo(6, 'Sl')) + end end - --[[ Build a string label which will be used a tooltip when - making a graph.]] +making a graph.]] function nnNode:_makeDebugLabel(dinfo) - if dinfo then - self.data.annotations._debugLabel = string.format('[%s]:%d', - dinfo.short_src, dinfo.currentline, dinfo.name) - end + if dinfo then + self.data.annotations._debugLabel = string.format('[%s]:%d', + dinfo.short_src, dinfo.currentline, dinfo.name) + end end - -- domap ensures that this node will keep track of the order its children are added. -- mapindex is a forward/backward list -- index = self.data.mapindex[child.data] -- child.data = self.data.mapindex[index] function nnNode:add(child,domap) - parent.add(self,child) - if domap then - local mapindex = self.data.mapindex - local data = child.data - assert(not mapindex[data], "Don't pass the same input twice.") - table.insert(mapindex,data) - mapindex[data] = #mapindex - end + parent.add(self,child) + if domap then + local mapindex = self.data.mapindex + local data = child.data + assert(not mapindex[data], "Don't pass the same input twice.") + table.insert(mapindex,data) + mapindex[data] = #mapindex + end end -- this function returns noutput number of new nodes -- that each take a single component of the output of this -- node in the order they are returned. function nnNode:split(noutput) - assert(noutput >= 2, "splitting to one output is not supported") - local debugLabel = self.data.annotations._debugLabel - local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. '-mnode'}}) - mnode:add(self,true) - - local selectnodes = {} - for i=1,noutput do - local node = nngraph.Node({selectindex=i,input={}, annotations={_debugLabel=debugLabel .. '-' .. i}}) - node:add(mnode,true) - table.insert(selectnodes,node) - end - return unpack(selectnodes) + assert(noutput >= 2, "splitting to one output is not supported") + local debugLabel = self.data.annotations._debugLabel + local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. '-mnode'}}) + mnode:add(self,true) + + local selectnodes = {} + for i=1,noutput do + local node = nngraph.Node({selectindex=i,input={}, annotations={_debugLabel=debugLabel .. '-' .. 
i}}) + node:add(mnode,true) + table.insert(selectnodes,node) + end + return unpack(selectnodes) end - function nnNode:annotate(annotations) - for k, v in pairs(annotations) do - self.data.annotations[k] = v - end + for k, v in pairs(annotations) do + self.data.annotations[k] = v + end - return self + return self end - function nnNode:graphNodeName() - if self.data.annotations.name then - return self.data.annotations.name .. ' (' .. self.id .. ')' - else - return 'Node' .. self.id - end + if self.data.annotations.name then + return self.data.annotations.name .. ' (' .. self.id .. ')' + else + return 'Node' .. self.id + end end - function nnNode:graphNodeAttributes() - self.data.annotations.graphAttributes = - self.data.annotations.graphAttributes or {} - if not self.data.annotations.graphAttributes.tooltip then - self.data.annotations.graphAttributes.tooltip = - self.data.annotations._debugLabel - end - - return self.data.annotations.graphAttributes + self.data.annotations.graphAttributes = + self.data.annotations.graphAttributes or {} + if not self.data.annotations.graphAttributes.tooltip then + self.data.annotations.graphAttributes.tooltip = + self.data.annotations._debugLabel + end + + return self.data.annotations.graphAttributes end - local function getNanFlag(data) - if data:nElement() == 0 then - return '' - end - local isNan = (data:ne(data):sum() > 0) - if isNan then - return 'NaN' - end - if data:max() == math.huge then - return 'inf' - end - if data:min() == -math.huge then - return '-inf' - end - return '' + if data:nElement() == 0 then + return '' + end + local isNan = (data:ne(data):sum() > 0) + if isNan then + return 'NaN' + end + if data:max() == math.huge then + return 'inf' + end + if data:min() == -math.huge then + return '-inf' + end + return '' end function nnNode:label() - local lbl = {} - - local function getstr(data) - if not data then return '' end - if istensor(data) then - local nanFlag = getNanFlag(data) - local tensorType = 'Tensor' - if data:type() ~= torch.Tensor():type() then - tensorType = data:type() - end - return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag - elseif istable(data) then - local tstr = {} - for i,v in ipairs(data) do - table.insert(tstr, getstr(v)) - end - return '{' .. table.concat(tstr,',') .. '}' - else - return tostring(data):gsub('\n','\\l') - end - end - local function getmapindexstr(mapindex) - local tstr = {} - for i,data in ipairs(mapindex) do - local inputId = 'Node' .. (data.forwardNodeId or '') - table.insert(tstr, inputId) - end - return '{' .. table.concat(tstr,',') .. '}' - end - - for k,v in pairs(self.data) do - local vstr = '' - if k== 'mapindex' then - if #v > 1 then - vstr = getmapindexstr(v) - table.insert(lbl, k .. ' = ' .. vstr) - end - elseif k== 'forwardNodeId' or k== 'annotations' then - -- the forwardNodeId is not displayed in the label. - else - vstr = getstr(v) - table.insert(lbl, k .. ' = ' .. vstr) - end - end - - local desc - if self.data.annotations.description then - desc = 'desc = ' .. self.data.annotations.description .. '\\n' - else - desc = '' - end - return desc .. table.concat(lbl,"\\l") + local lbl = {} + + local function getstr(data) + if not data then return '' end + if istensor(data) then + local nanFlag = getNanFlag(data) + local tensorType = 'Tensor' + if data:type() ~= torch.Tensor():type() then + tensorType = data:type() + end + return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. 
nanFlag + elseif istable(data) then + local tstr = {} + for i,v in ipairs(data) do + table.insert(tstr, getstr(v)) + end + return '{' .. table.concat(tstr,',') .. '}' + else + return tostring(data):gsub('\n','\\l') + end + end + local function getmapindexstr(mapindex) + local tstr = {} + for i,data in ipairs(mapindex) do + local inputId = 'Node' .. (data.forwardNodeId or '') + table.insert(tstr, inputId) + end + return '{' .. table.concat(tstr,',') .. '}' + end + + for k,v in pairs(self.data) do + local vstr = '' + if k== 'mapindex' then + if #v > 1 then + vstr = getmapindexstr(v) + table.insert(lbl, k .. ' = ' .. vstr) + end + elseif k== 'forwardNodeId' or k== 'annotations' then + -- the forwardNodeId is not displayed in the label. + else + vstr = getstr(v) + table.insert(lbl, k .. ' = ' .. vstr) + end + end + + local desc + if self.data.annotations.description then + desc = 'desc = ' .. self.data.annotations.description .. '\\n' + else + desc = '' + end + return desc .. table.concat(lbl,"\\l") end diff --git a/simple_print.lua b/simple_print.lua index 878db3e..87bf152 100644 --- a/simple_print.lua +++ b/simple_print.lua @@ -1,125 +1,124 @@ local function removeNodeFromEdges(node_id, edges) - local from_nodes = {} - local to_nodes = {} - -- remove edges - local idx = 1 - while idx <= #edges do - local edge = edges[idx] - if edge.source == node_id then - local to_node = edges[idx].target - table.insert(to_nodes, to_node) - table.remove(edges, idx) - elseif edge.target == node_id then - local from_node = edges[idx].source - table.insert(from_nodes, from_node) - table.remove(edges, idx) - else - idx = idx + 1 - end - end + local from_nodes = {} + local to_nodes = {} + -- remove edges + local idx = 1 + while idx <= #edges do + local edge = edges[idx] + if edge.source == node_id then + local to_node = edges[idx].target + table.insert(to_nodes, to_node) + table.remove(edges, idx) + elseif edge.target == node_id then + local from_node = edges[idx].source + table.insert(from_nodes, from_node) + table.remove(edges, idx) + else + idx = idx + 1 + end + end - -- add new edges - for _, f in pairs(from_nodes) do - for _, t in pairs(to_nodes) do - local edge = {source = f, target= t} - table.insert(edges, edge) - end - end + -- add new edges + for _, f in pairs(from_nodes) do + for _, t in pairs(to_nodes) do + local edge = {source = f, target= t} + table.insert(edges, edge) + end + end - return edges + return edges end local function isNodeGood(node) - return node.data and node.data.module and (torch.typename(node.data.module) ~= 'nn.Identity') + return node.data and node.data.module and (torch.typename(node.data.module) ~= 'nn.Identity') end local function reIndexNodes(nodes, edges) - -- make reverse map - local rev_map = {} - for idx = 1, #nodes do - rev_map[nodes[idx].id] = idx - nodes[idx].id = idx - end - for idx = 1, #edges do - local edge = edges[idx] - edge.source = rev_map[edge.source] - edge.target = rev_map[edge.target] - end - return nodes, edges + -- make reverse map + local rev_map = {} + for idx = 1, #nodes do + rev_map[nodes[idx].id] = idx + nodes[idx].id = idx + end + for idx = 1, #edges do + local edge = edges[idx] + edge.source = rev_map[edge.source] + edge.target = rev_map[edge.target] + end + return nodes, edges end local function cleanGraph(nodes, edges) - local idx = 1 - while idx <= #nodes do - local node = nodes[idx] - if isNodeGood(node.orig_node) then - idx = idx + 1 - else - local id = node.id - table.remove(nodes, idx) - edges = removeNodeFromEdges(id, edges) - end - end - 
return reIndexNodes(nodes, edges) + local idx = 1 + while idx <= #nodes do + local node = nodes[idx] + if isNodeGood(node.orig_node) then + idx = idx + 1 + else + local id = node.id + table.remove(nodes, idx) + edges = removeNodeFromEdges(id, edges) + end + end + return reIndexNodes(nodes, edges) end local function loadGraph(graph) - local nodes = {} - local edges = {} - for _, node in ipairs(graph.nodes) do - local idx = node.id - table.insert(nodes, {id=idx, orig_node = node} ) - for ich = 1, #node.children do - table.insert( edges, {source = idx, target = node.children[ich].id}) - end - end - nodes, edges = cleanGraph(nodes, edges) - return nodes , edges + local nodes = {} + local edges = {} + for _, node in ipairs(graph.nodes) do + local idx = node.id + table.insert(nodes, {id=idx, orig_node = node} ) + for ich = 1, #node.children do + table.insert( edges, {source = idx, target = node.children[ich].id}) + end + end + nodes, edges = cleanGraph(nodes, edges) + return nodes , edges end local M = {} function M.todot( graph, title ) - local nodes, edges = loadGraph(graph) - local str = {} - table.insert(str,'digraph G {\n') - if title then - table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n') - end - table.insert(str,'node [shape = oval]; ') - local nodelabels = {} - for i,node in ipairs(nodes) do - local true_node = node.orig_node - local l = '"' .. ( 'Node' .. true_node.id .. '\\n' .. true_node:label() ) .. '"' - nodelabels[i] = 'n' .. true_node.id - table.insert(str, '\n' .. nodelabels[i] .. '[label=' .. l .. '];') - end - table.insert(str,'\n') - for i,edge in ipairs(edges) do - table.insert(str,nodelabels[edge.source] .. ' -> ' .. nodelabels[edge.target] .. ';\n') - end - table.insert(str,'}') - return table.concat(str,'') + local nodes, edges = loadGraph(graph) + local str = {} + table.insert(str,'digraph G {\n') + if title then + table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n') + end + table.insert(str,'node [shape = oval]; ') + local nodelabels = {} + for i,node in ipairs(nodes) do + local true_node = node.orig_node + local l = '"' .. ( 'Node' .. true_node.id .. '\\n' .. true_node:label() ) .. '"' + nodelabels[i] = 'n' .. true_node.id + table.insert(str, '\n' .. nodelabels[i] .. '[label=' .. l .. '];') + end + table.insert(str,'\n') + for i,edge in ipairs(edges) do + table.insert(str,nodelabels[edge.source] .. ' -> ' .. nodelabels[edge.target] .. ';\n') + end + table.insert(str,'}') + return table.concat(str,'') end function M.dot(g,title,fname) - local gv = M.todot(g, title) - local fngv = (fname or os.tmpname()) .. '.dot' - local fgv = io.open(fngv,'w') - fgv:write(gv) - fgv:close() - local fnsvg = (fname or os.tmpname()) .. '.svg' - os.execute('dot -Tsvg -o ' .. fnsvg .. ' ' .. fngv) - if not fname then - require 'qtsvg' - local qs = qt.QSvgWidget(fnsvg) - qs:show() - os.remove(fngv) - os.remove(fnsvg) - -- print(fngv,fnpng) - return qs - end + local gv = M.todot(g, title) + local fngv = (fname or os.tmpname()) .. '.dot' + local fgv = io.open(fngv,'w') + fgv:write(gv) + fgv:close() + local fnsvg = (fname or os.tmpname()) .. '.svg' + os.execute('dot -Tsvg -o ' .. fnsvg .. ' ' .. 
fngv) + if not fname then + require 'qtsvg' + local qs = qt.QSvgWidget(fnsvg) + qs:show() + os.remove(fngv) + os.remove(fnsvg) + -- print(fngv,fnpng) + return qs + end end return M - diff --git a/test/speed.lua b/test/speed.lua index 355645c..7218cbe 100644 --- a/test/speed.lua +++ b/test/speed.lua @@ -1,105 +1,104 @@ require 'nngraph' - function time_benchmark(model, input, n) - local forward_timer = torch.Timer():stop():reset() - local backward_timer = torch.Timer():stop():reset() - local total_timer = torch.Timer():stop():reset() - local gradOut - total_timer:resume() - for i = 1, n do - forward_timer:resume() - local out = model:forward(input) - forward_timer:stop() - if not gradOut then - gradOut = torch.rand(out:size()) - end - backward_timer:resume() - model:backward(input, gradOut) - backward_timer:stop() - end - total_timer:stop() - - return {forward = forward_timer:time().real, - backward = backward_timer:time().real, - total = total_timer:time().real} + local forward_timer = torch.Timer():stop():reset() + local backward_timer = torch.Timer():stop():reset() + local total_timer = torch.Timer():stop():reset() + local gradOut + total_timer:resume() + for i = 1, n do + forward_timer:resume() + local out = model:forward(input) + forward_timer:stop() + if not gradOut then + gradOut = torch.rand(out:size()) + end + backward_timer:resume() + model:backward(input, gradOut) + backward_timer:stop() + end + total_timer:stop() + + return {forward = forward_timer:time().real, + backward = backward_timer:time().real, + total = total_timer:time().real} end function report_benchmark(result, title) - local nspace = (80-string.len(title))/2 - report = {string.rep('#', 80), - string.format('%s%s%s', string.rep(' ', math.floor(nspace)), title, string.rep(' ', math.ceil(nspace))), - string.format('Total Time Spent = %.2f s', result.total), - string.format(' Forward Time = %.2f s', result.forward), - string.format(' Backward Time = %.2f s', result.backward), - string.rep('#', 80) - } - print(table.concat(report, '\n')) - return result + local nspace = (80-string.len(title))/2 + report = {string.rep('#', 80), + string.format('%s%s%s', string.rep(' ', math.floor(nspace)), title, string.rep(' ', math.ceil(nspace))), + string.format('Total Time Spent = %.2f s', result.total), + string.format(' Forward Time = %.2f s', result.forward), + string.format(' Backward Time = %.2f s', result.backward), + string.rep('#', 80) +} +print(table.concat(report, '\n')) +return result end function compare_benchmarks(result, base, title) - local nspace = (80-string.len(title))/2 - report = {string.rep('#', 80), - string.format('%s%s%s', string.rep(' ', math.floor(nspace)), title, string.rep(' ', math.ceil(nspace))), - string.format('Total Time Spent = %.2f %%', result.total/base.total*100), - string.format(' Forward Time = %.2f %%', result.forward/base.forward*100), - string.format(' Backward Time = %.2f %%', result.backward/base.backward*100), - string.rep('#', 80) - } - print(table.concat(report, '\n')) - return result + local nspace = (80-string.len(title))/2 + report = {string.rep('#', 80), + string.format('%s%s%s', string.rep(' ', math.floor(nspace)), title, string.rep(' ', math.ceil(nspace))), + string.format('Total Time Spent = %.2f %%', result.total/base.total*100), + string.format(' Forward Time = %.2f %%', result.forward/base.forward*100), + string.format(' Backward Time = %.2f %%', result.backward/base.backward*100), + string.rep('#', 80) +} +print(table.concat(report, '\n')) +return result end function 
get_models(nhidden_layers, ninput, noutput, nhidden) - local function get_concat_layer(nfrom, nto) - local concat_module = nn.Sequential() - local concat_layer = nn.ConcatTable() - concat_layer:add(nn.Linear(nfrom, nto)) - concat_layer:add(nn.Linear(nfrom, nto)) - concat_module:add(concat_layer) - concat_module:add(nn.CAddTable()) - concat_module:add(nn.ReLU()) - return concat_module - end - - -- NN - local nn_model = nn.Sequential() - nn_model:add(get_concat_layer(ninput, nhidden))--nn.Linear(ninput, nhidden)):add(nn.ReLU()) - for i = 1, nhidden_layers do - nn_model:add(get_concat_layer(nhidden, nhidden))--nn.Linear(nhidden, nhidden)):add(nn.ReLU()) - end - nn_model:add(get_concat_layer(nhidden, noutput))--nn.Linear(nhidden, noutput)) - - -- NN graph - local input = nn.Identity()() - local nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(ninput, nhidden)(input), - nn.Linear(ninput, nhidden)(input)})) - for i = 1, nhidden_layers do - nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(nhidden, nhidden)(nng_model), - nn.Linear(nhidden, nhidden)(nng_model)})) - end - nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(nhidden, noutput)(nng_model), - nn.Linear(nhidden, noutput)(nng_model)})) - - nng_model = nn.gModule({input},{nng_model}) - - return {nn_model = nn_model, nng_model = nng_model} + local function get_concat_layer(nfrom, nto) + local concat_module = nn.Sequential() + local concat_layer = nn.ConcatTable() + concat_layer:add(nn.Linear(nfrom, nto)) + concat_layer:add(nn.Linear(nfrom, nto)) + concat_module:add(concat_layer) + concat_module:add(nn.CAddTable()) + concat_module:add(nn.ReLU()) + return concat_module + end + + -- NN + local nn_model = nn.Sequential() + nn_model:add(get_concat_layer(ninput, nhidden))--nn.Linear(ninput, nhidden)):add(nn.ReLU()) + for i = 1, nhidden_layers do + nn_model:add(get_concat_layer(nhidden, nhidden))--nn.Linear(nhidden, nhidden)):add(nn.ReLU()) + end + nn_model:add(get_concat_layer(nhidden, noutput))--nn.Linear(nhidden, noutput)) + + -- NN graph + local input = nn.Identity()() + local nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(ninput, nhidden)(input), + nn.Linear(ninput, nhidden)(input)})) + for i = 1, nhidden_layers do + nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(nhidden, nhidden)(nng_model), + nn.Linear(nhidden, nhidden)(nng_model)})) + end + nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(nhidden, noutput)(nng_model), + nn.Linear(nhidden, noutput)(nng_model)})) + + nng_model = nn.gModule({input},{nng_model}) + + return {nn_model = nn_model, nng_model = nng_model} end function get_options(arg) - local cmd = torch.CmdLine() - cmd:text('nngraph benchmarking') - cmd:option('-niter', 10, 'number of iterations of forward/backward for each model') - cmd:option('-nhidden_layers', 10, 'number of hidden layers') - cmd:option('-input_size', 512, 'size of input') - cmd:option('-batch_size', 16, 'size of batch') - cmd:option('-hidden_size', 1024, 'size of hidden layer') - cmd:option('-output_size', 128, 'size of output layer') - local opt = cmd:parse(arg) - return opt + local cmd = torch.CmdLine() + cmd:text('nngraph benchmarking') + cmd:option('-niter', 10, 'number of iterations of forward/backward for each model') + cmd:option('-nhidden_layers', 10, 'number of hidden layers') + cmd:option('-input_size', 512, 'size of input') + cmd:option('-batch_size', 16, 'size of batch') + cmd:option('-hidden_size', 1024, 'size of hidden layer') + cmd:option('-output_size', 128, 'size of output layer') + local opt = cmd:parse(arg) + return opt end local opt = 
get_options(arg) diff --git a/test/test_ModuleFromCriterion.lua b/test/test_ModuleFromCriterion.lua index 78d3cd2..9206905 100644 --- a/test/test_ModuleFromCriterion.lua +++ b/test/test_ModuleFromCriterion.lua @@ -5,52 +5,52 @@ local test = {} local tester = totem.Tester() function test.test_call() - local prediction = nn.Identity()() - local target = nn.Identity()() - local mse = nn.MSECriterion()({prediction, target}) - local costBits = nn.MulConstant(1/math.log(2))(mse) - local net = nn.gModule({prediction, target}, {costBits}) - - local input = {torch.randn(3, 5), torch.rand(3, 5)} - local criterion = nn.MSECriterion() - local output = net:forward(input) - criterion:forward(input[1], input[2]) - tester:eq(output[1], criterion.output/math.log(2), "output", 1e-14) - - local gradOutput = torch.randn(1) - local gradInput = net:backward(input, gradOutput) - criterion:backward(input[1], input[2]) - tester:eq(gradInput[1], criterion.gradInput:clone():mul(gradOutput[1]/math.log(2)), "gradPrediction", 1e-14) - tester:eq(gradInput[2], torch.zeros(input[2]:size()), "gradTarget") + local prediction = nn.Identity()() + local target = nn.Identity()() + local mse = nn.MSECriterion()({prediction, target}) + local costBits = nn.MulConstant(1/math.log(2))(mse) + local net = nn.gModule({prediction, target}, {costBits}) + + local input = {torch.randn(3, 5), torch.rand(3, 5)} + local criterion = nn.MSECriterion() + local output = net:forward(input) + criterion:forward(input[1], input[2]) + tester:eq(output[1], criterion.output/math.log(2), "output", 1e-14) + + local gradOutput = torch.randn(1) + local gradInput = net:backward(input, gradOutput) + criterion:backward(input[1], input[2]) + tester:eq(gradInput[1], criterion.gradInput:clone():mul(gradOutput[1]/math.log(2)), "gradPrediction", 1e-14) + tester:eq(gradInput[2], torch.zeros(input[2]:size()), "gradTarget") end function test.test_grad() - local prediction = nn.Identity()() - local zero = nn.MulConstant(0)(prediction) - -- The target is created inside of the nngraph - -- to ignore the zero gradTarget. - local target = nn.AddConstant(1.23)(zero) - local mse = nn.MSECriterion()({prediction, target}) - local net = nn.gModule({prediction}, {mse}) - - local input = torch.randn(4, 7) - totem.nn.checkGradients(tester, net, input) + local prediction = nn.Identity()() + local zero = nn.MulConstant(0)(prediction) + -- The target is created inside of the nngraph + -- to ignore the zero gradTarget. 
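-- (Editorial note, not in the original diff) The MulConstant(0) trick keeps
-- the target inside the graph as a constant, so checkGradients perturbs only
-- `prediction` and the zero gradTarget is ignored, as the comment above says.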
+ local target = nn.AddConstant(1.23)(zero) + local mse = nn.MSECriterion()({prediction, target}) + local net = nn.gModule({prediction}, {mse}) + + local input = torch.randn(4, 7) + totem.nn.checkGradients(tester, net, input) end local function module() - local module = nn.ModuleFromCriterion(nn.MSECriterion()) - local input = {torch.randn(3, 5), torch.randn(3, 5)} - return module, input + local module = nn.ModuleFromCriterion(nn.MSECriterion()) + local input = {torch.randn(3, 5), torch.randn(3, 5)} + return module, input end function test.test_serializable() - local module, input = module() - totem.nn.checkSerializable(tester, module, input) + local module, input = module() + totem.nn.checkSerializable(tester, module, input) end function test.test_typeCastable() - local module, input = module() - totem.nn.checkTypeCastable(tester, module, input) + local module, input = module() + totem.nn.checkTypeCastable(tester, module, input) end diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua index 86bf730..95a1658 100644 --- a/test/test_nngraph.lua +++ b/test/test_nngraph.lua @@ -5,379 +5,379 @@ local test = {} local tester = totem.Tester() local function checkGradients(...) - totem.nn.checkGradients(tester, ...) + totem.nn.checkGradients(tester, ...) end function test.test_oneOutput() - local in1 = nn.Identity()() - local out1 = nn.Identity()(in1) - local module = nn.gModule({in1}, {out1}) - - local input = torch.Tensor({1}) - module:forward(input) - tester:eq(module.output, torch.Tensor{1}, "output") - local gradInput = module:backward(input, torch.Tensor({-123})) - tester:eq(gradInput, torch.Tensor{-123}, "gradInput") - - local input2 = torch.Tensor({2}) - module:forward(input2) - tester:eq(module.output, torch.Tensor{2}, "output for input2") - gradInput = module:backward(input2, torch.Tensor({-2})) - tester:eq(gradInput, torch.Tensor{-2}, "expecting a recomputed gradInput") + local in1 = nn.Identity()() + local out1 = nn.Identity()(in1) + local module = nn.gModule({in1}, {out1}) + + local input = torch.Tensor({1}) + module:forward(input) + tester:eq(module.output, torch.Tensor{1}, "output") + local gradInput = module:backward(input, torch.Tensor({-123})) + tester:eq(gradInput, torch.Tensor{-123}, "gradInput") + + local input2 = torch.Tensor({2}) + module:forward(input2) + tester:eq(module.output, torch.Tensor{2}, "output for input2") + gradInput = module:backward(input2, torch.Tensor({-2})) + tester:eq(gradInput, torch.Tensor{-2}, "expecting a recomputed gradInput") end function test.test_twoOutputs() - local in1 = nn.Identity()() - local out1 = nn.Identity()(in1) - local out2 = nn.Identity()(in1) - local module = nn.gModule({in1}, {out1, out2}) - - local input = torch.Tensor({1}) - module:forward(input) - local gradInput = module:backward(input, {torch.Tensor({-2}), torch.Tensor({-3})}) - tester:eq(gradInput, torch.Tensor{-5}, "gradInput of a fork") - checkGradients(module, input) + local in1 = nn.Identity()() + local out1 = nn.Identity()(in1) + local out2 = nn.Identity()(in1) + local module = nn.gModule({in1}, {out1, out2}) + + local input = torch.Tensor({1}) + module:forward(input) + local gradInput = module:backward(input, {torch.Tensor({-2}), torch.Tensor({-3})}) + tester:eq(gradInput, torch.Tensor{-5}, "gradInput of a fork") + checkGradients(module, input) end function test.test_twoGradOutputs() - local in1 = nn.Sigmoid()() - local splitTable = nn.SplitTable(1)({in1}) - local out1, out2 = splitTable:split(2) - local module = nn.gModule({in1}, {out1, out2}) - - local input = 
torch.randn(2, 3) - local output = module:forward(input) - assert(#output == 2, "wrong number of outputs") - module:backward(input, {torch.randn(3), torch.randn(3)}) - checkGradients(module, input) + local in1 = nn.Sigmoid()() + local splitTable = nn.SplitTable(1)({in1}) + local out1, out2 = splitTable:split(2) + local module = nn.gModule({in1}, {out1, out2}) + + local input = torch.randn(2, 3) + local output = module:forward(input) + assert(#output == 2, "wrong number of outputs") + module:backward(input, {torch.randn(3), torch.randn(3)}) + checkGradients(module, input) end function test.test_twoInputs() - local in1 = nn.Identity()() - local in2 = nn.Identity()() - local prevH, prevCell = in2:split(2) - - local out1 = nn.CMulTable()({in1, prevH, prevCell}) - local module = nn.gModule({in1, in2}, {out1}) - - local input = {torch.randn(3), {torch.randn(3), torch.randn(3)}} - module:forward(input) - local gradInput = module:backward(input, torch.randn(3)) - assert(#gradInput == 2, "wrong number of gradInputs") - assert(type(gradInput[2]) == "table", "wrong gradInput[2] type") - checkGradients(module, input) + local in1 = nn.Identity()() + local in2 = nn.Identity()() + local prevH, prevCell = in2:split(2) + + local out1 = nn.CMulTable()({in1, prevH, prevCell}) + local module = nn.gModule({in1, in2}, {out1}) + + local input = {torch.randn(3), {torch.randn(3), torch.randn(3)}} + module:forward(input) + local gradInput = module:backward(input, torch.randn(3)) + assert(#gradInput == 2, "wrong number of gradInputs") + assert(type(gradInput[2]) == "table", "wrong gradInput[2] type") + checkGradients(module, input) end function test.test_twoInputs2() - local in1 = nn.Sigmoid()() - local in2 = nn.Sigmoid()() - local module = nn.gModule({in1, in2}, {in1, in2, nn.Sigmoid()(in1)}) - - local input = {torch.randn(3), torch.randn(3)} - module:forward(input) - local gradInput = module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)}) - checkGradients(module, input) + local in1 = nn.Sigmoid()() + local in2 = nn.Sigmoid()() + local module = nn.gModule({in1, in2}, {in1, in2, nn.Sigmoid()(in1)}) + + local input = {torch.randn(3), torch.randn(3)} + module:forward(input) + local gradInput = module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)}) + checkGradients(module, input) end function test.test_splitDebugLabels() - local node = nn.Identity()() - node.data.annotations._debugLabel = "node" - local node1, node2 = node:split(2) - assert(node1.data.annotations._debugLabel == "node-1") - assert(node2.data.annotations._debugLabel == "node-2") + local node = nn.Identity()() + node.data.annotations._debugLabel = "node" + local node1, node2 = node:split(2) + assert(node1.data.annotations._debugLabel == "node-1") + assert(node2.data.annotations._debugLabel == "node-2") end function test.test_identity() - local in1 = nn.Identity()() - local in2 = nn.Identity()() - local module = nn.gModule({in1, in2}, {in1, in2, nn.Identity()(in1)}) - - local input = {torch.randn(3), torch.randn(3)} - module:forward(input) - module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)}) - checkGradients(module, input) + local in1 = nn.Identity()() + local in2 = nn.Identity()() + local module = nn.gModule({in1, in2}, {in1, in2, nn.Identity()(in1)}) + + local input = {torch.randn(3), torch.randn(3)} + module:forward(input) + module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)}) + checkGradients(module, input) end function test.test_gradInputType() - local xInput = 
torch.randn(3) - local h = torch.randn(3) - - local x = nn.Identity()() - local prevRnnState = nn.Identity()() - local prevH1, prevCell = prevRnnState:split(2) - local prevH = prevH1 - - local cellOut = nn.CAddTable()({ - nn.CMulTable()({x, prevH}), - nn.CMulTable()({prevH, prevCell})}) - local module = nn.gModule({x, prevRnnState}, {cellOut}) - - local c = torch.randn(h:size()) - local prevRnnState = {h, c} - local input = {xInput, prevRnnState} - local output = module:forward(input) - - local gradOutput = torch.randn(h:size()) - local gradInput = module:backward(input, gradOutput) - - local gradX, gradPrevState = unpack(gradInput) - local gradPrevH, gradPrevCell = unpack(gradPrevState) - assert(type(gradPrevH) == type(h), "wrong gradPrevH type") - - tester:eq(type(gradPrevH), type(h), "wrong gradPrevH type") - tester:eq(gradPrevH:size(), h:size(), "wrong gradPrevH size") - checkGradients(module, input) -end - -function test.test_tabularInput() - local in1 = nn.SplitTable(1)() - local out1 = nn.CAddTable()(in1) - local module = nn.gModule({in1}, {out1}) - - local input = torch.randn(2, 3) - checkGradients(module, input) -end - -function test.test_extraTable() - local in1 = nn.Identity()() - local out1 = nn.Identity()(in1) - local module = nn.gModule({in1}, {out1}) - - local input = torch.Tensor({123}) - tester:eq(module:forward(input), input, "simple output") - tester:eq(module:forward({input}), {input}, "tabular output") -end - -function test.test_accGradParameters() - local input = torch.randn(10) - - local in1 = nn.CMul(input:nElement())() - local out1 = nn.Identity()(in1) - local out2 = nn.Identity()(in1) - local module = nn.gModule({in1}, {out1, out2}) - checkGradients(module, input) -end - -function test.test_example1() - local x1 = nn.Linear(20,10)() - local mout = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(x1)))) - local mlp = nn.gModule({x1},{mout}) - - local x = torch.rand(20) - checkGradients(mlp, x) -end - -function test.test_example2() - local x1=nn.Linear(20,20)() - local x2=nn.Linear(10,10)() - local m0=nn.Linear(20,1)(nn.Tanh()(x1)) - local m1=nn.Linear(10,1)(nn.Tanh()(x2)) - local madd=nn.CAddTable()({m0,m1}) - local m2=nn.Sigmoid()(madd) - local m3=nn.Tanh()(madd) - local gmod = nn.gModule({x1,x2},{m2,m3}) - - local x = torch.rand(20) - local y = torch.rand(10) - checkGradients(gmod, {x, y}) -end - -function test.test_example3() - local m = nn.Sequential() - m:add(nn.SplitTable(1)) - m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30))) - local input = nn.Identity()() - local input1,input2 = m(input):split(2) - local m3 = nn.JoinTable(1)({input1,input2}) - local g = nn.gModule({input},{m3}) - - local indata = torch.rand(2,10) - checkGradients(g, indata) -end - -function test.test_example4() - local input = nn.Identity()() - local L1 = nn.Tanh()(nn.Linear(1,2)(input)) - local L2 = nn.Tanh()(nn.Linear(3,6)(nn.JoinTable(1)({input,L1}))) - local L3 = nn.Tanh()(nn.Linear(8,16)(nn.JoinTable(1)({L1,L2}))) - local g = nn.gModule({input},{L3}) - - local indata = torch.rand(1) - checkGradients(g, indata) -end - -function test.test_type() - local in1 = nn.Linear(20,10)() - local out1 = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(in1)))) - local module = nn.gModule({in1}, {out1}) - local input = torch.rand(20) - local output = module:forward(input) - module:backward(input, output) - tester:eq(torch.typename(output), "torch.DoubleTensor") - tester:eq(torch.typename(module.output), "torch.DoubleTensor") - tester:eq(torch.typename(module.gradInput), 
"torch.DoubleTensor") - tester:eq(torch.typename(module.innode.data.input[1]), "torch.DoubleTensor") - tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor") - tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor") - - module:float() - local output = module:forward(input:float()) - tester:eq(torch.typename(output), "torch.FloatTensor") - tester:eq(torch.typename(module.output), "torch.FloatTensor") - tester:eq(torch.typename(module.gradInput), "torch.FloatTensor") - tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor") - tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor") - tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor") -end - -function test.test_nestedGradInput() - local x = nn.Identity()() - local h1 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Tanh()) - local h2 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Identity()) - local out = nn.CAddTable()({h1(x), h2(x)}) - - local model = nn.gModule({x}, {out}) - - local input = {} - input[1] = torch.randn(3, 3) - input[2] = torch.randn(3, 3) - input[3] = torch.randn(3, 3) - - checkGradients(model, input) - - local input = {} - input[1] = torch.randn(2, 3) - input[2] = torch.randn(2, 3) - input[3] = torch.randn(2, 3) - - checkGradients(model, input) -end - -function test.test_unusedInput() - local x = nn.Identity()() - local h = nn.Identity()() - local h2 = nn.Identity()() - - local ok, result = pcall(nn.gModule, {x, h}, {x}) - assert(not ok, "the unused input should be detected") -end - -function test.test_unusedChild() - local prevState = nn.Identity()() - local h, cell = prevState:split(2) - - local ok, result = pcall(nn.gModule, {prevState}, {h}) - assert(not ok, "the unused cell should be detected") -end - -function test.test_nilInput() - local ok, result = pcall(function() nn.Sigmoid()(nil) end) - assert(not ok, "the nil input should be detected") -end - -function test.test_unusedNode() - local in1 = nn.Identity()() - local in2 = nn.Identity()() - local middleResult = nn.Sigmoid()(in2) - local out1 = nn.Sigmoid()(in1) - - local ok, result = pcall(nn.gModule, {in1, in2}, {out1}) - assert(not ok, "the unused middleResult should be detected") -end - -function test.test_usageAfterSplit() - local prevState = nn.Identity()() - local h, cell = prevState:split(2) - local nextState = nn.Identity()(prevState) - local transformed = nn.Sigmoid()(cell) - - local model = nn.gModule({prevState}, {h, nextState, transformed}) - local nHidden = 10 - local input = {torch.randn(nHidden), torch.randn(nHidden)} - checkGradients(model, input) -end - -function test.test_resizeNestedAs() - local in1 = nn.Identity()() - local out1 = nn.Identity()(in1) - local out2 = nn.Identity()(in1) - - local net = nn.gModule({in1}, {out1, out2}) - local input = {torch.randn(10), {torch.randn(3), torch.randn(4)}} - net:forward(input) - net:backward(input, net.output) - checkGradients(net, input) - - input = {torch.randn(10), {torch.randn(3), torch.randn(4), torch.randn(5)}} - net:forward(input) - net:backward(input, net.output) - checkGradients(net, input) - - input = {torch.randn(10), {torch.randn(3), torch.randn(4)}} - net:forward(input) - local gradInput = net:backward(input, net.output) - tester:eq(#(gradInput[2]), 2, "gradInput[2] size") - checkGradients(net, input) -end - - -function test.test_annotateGraph() - local input = nn.Identity()():annotate( + local xInput = torch.randn(3) + local h = torch.randn(3) + + local x = 
nn.Identity()() + local prevRnnState = nn.Identity()() + local prevH1, prevCell = prevRnnState:split(2) + local prevH = prevH1 + + local cellOut = nn.CAddTable()({ + nn.CMulTable()({x, prevH}), + nn.CMulTable()({prevH, prevCell})}) + local module = nn.gModule({x, prevRnnState}, {cellOut}) + + local c = torch.randn(h:size()) + local prevRnnState = {h, c} + local input = {xInput, prevRnnState} + local output = module:forward(input) + + local gradOutput = torch.randn(h:size()) + local gradInput = module:backward(input, gradOutput) + + local gradX, gradPrevState = unpack(gradInput) + local gradPrevH, gradPrevCell = unpack(gradPrevState) + assert(type(gradPrevH) == type(h), "wrong gradPrevH type") + + tester:eq(type(gradPrevH), type(h), "wrong gradPrevH type") + tester:eq(gradPrevH:size(), h:size(), "wrong gradPrevH size") + checkGradients(module, input) + end + + function test.test_tabularInput() + local in1 = nn.SplitTable(1)() + local out1 = nn.CAddTable()(in1) + local module = nn.gModule({in1}, {out1}) + + local input = torch.randn(2, 3) + checkGradients(module, input) + end + + function test.test_extraTable() + local in1 = nn.Identity()() + local out1 = nn.Identity()(in1) + local module = nn.gModule({in1}, {out1}) + + local input = torch.Tensor({123}) + tester:eq(module:forward(input), input, "simple output") + tester:eq(module:forward({input}), {input}, "tabular output") + end + + function test.test_accGradParameters() + local input = torch.randn(10) + + local in1 = nn.CMul(input:nElement())() + local out1 = nn.Identity()(in1) + local out2 = nn.Identity()(in1) + local module = nn.gModule({in1}, {out1, out2}) + checkGradients(module, input) + end + + function test.test_example1() + local x1 = nn.Linear(20,10)() + local mout = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(x1)))) + local mlp = nn.gModule({x1},{mout}) + + local x = torch.rand(20) + checkGradients(mlp, x) + end + + function test.test_example2() + local x1=nn.Linear(20,20)() + local x2=nn.Linear(10,10)() + local m0=nn.Linear(20,1)(nn.Tanh()(x1)) + local m1=nn.Linear(10,1)(nn.Tanh()(x2)) + local madd=nn.CAddTable()({m0,m1}) + local m2=nn.Sigmoid()(madd) + local m3=nn.Tanh()(madd) + local gmod = nn.gModule({x1,x2},{m2,m3}) + + local x = torch.rand(20) + local y = torch.rand(10) + checkGradients(gmod, {x, y}) + end + + function test.test_example3() + local m = nn.Sequential() + m:add(nn.SplitTable(1)) + m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30))) + local input = nn.Identity()() + local input1,input2 = m(input):split(2) + local m3 = nn.JoinTable(1)({input1,input2}) + local g = nn.gModule({input},{m3}) + + local indata = torch.rand(2,10) + checkGradients(g, indata) + end + + function test.test_example4() + local input = nn.Identity()() + local L1 = nn.Tanh()(nn.Linear(1,2)(input)) + local L2 = nn.Tanh()(nn.Linear(3,6)(nn.JoinTable(1)({input,L1}))) + local L3 = nn.Tanh()(nn.Linear(8,16)(nn.JoinTable(1)({L1,L2}))) + local g = nn.gModule({input},{L3}) + + local indata = torch.rand(1) + checkGradients(g, indata) + end + + function test.test_type() + local in1 = nn.Linear(20,10)() + local out1 = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(in1)))) + local module = nn.gModule({in1}, {out1}) + local input = torch.rand(20) + local output = module:forward(input) + module:backward(input, output) + tester:eq(torch.typename(output), "torch.DoubleTensor") + tester:eq(torch.typename(module.output), "torch.DoubleTensor") + tester:eq(torch.typename(module.gradInput), "torch.DoubleTensor") + 
tester:eq(torch.typename(module.innode.data.input[1]), "torch.DoubleTensor") + tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor") + tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor") + + module:float() + local output = module:forward(input:float()) + tester:eq(torch.typename(output), "torch.FloatTensor") + tester:eq(torch.typename(module.output), "torch.FloatTensor") + tester:eq(torch.typename(module.gradInput), "torch.FloatTensor") + tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor") + tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor") + tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor") + end + + function test.test_nestedGradInput() + local x = nn.Identity()() + local h1 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Tanh()) + local h2 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Identity()) + local out = nn.CAddTable()({h1(x), h2(x)}) + + local model = nn.gModule({x}, {out}) + + local input = {} + input[1] = torch.randn(3, 3) + input[2] = torch.randn(3, 3) + input[3] = torch.randn(3, 3) + + checkGradients(model, input) + + local input = {} + input[1] = torch.randn(2, 3) + input[2] = torch.randn(2, 3) + input[3] = torch.randn(2, 3) + + checkGradients(model, input) + end + + function test.test_unusedInput() + local x = nn.Identity()() + local h = nn.Identity()() + local h2 = nn.Identity()() + + local ok, result = pcall(nn.gModule, {x, h}, {x}) + assert(not ok, "the unused input should be detected") + end + + function test.test_unusedChild() + local prevState = nn.Identity()() + local h, cell = prevState:split(2) + + local ok, result = pcall(nn.gModule, {prevState}, {h}) + assert(not ok, "the unused cell should be detected") + end + + function test.test_nilInput() + local ok, result = pcall(function() nn.Sigmoid()(nil) end) + assert(not ok, "the nil input should be detected") + end + + function test.test_unusedNode() + local in1 = nn.Identity()() + local in2 = nn.Identity()() + local middleResult = nn.Sigmoid()(in2) + local out1 = nn.Sigmoid()(in1) + + local ok, result = pcall(nn.gModule, {in1, in2}, {out1}) + assert(not ok, "the unused middleResult should be detected") + end + + function test.test_usageAfterSplit() + local prevState = nn.Identity()() + local h, cell = prevState:split(2) + local nextState = nn.Identity()(prevState) + local transformed = nn.Sigmoid()(cell) + + local model = nn.gModule({prevState}, {h, nextState, transformed}) + local nHidden = 10 + local input = {torch.randn(nHidden), torch.randn(nHidden)} + checkGradients(model, input) + end + + function test.test_resizeNestedAs() + local in1 = nn.Identity()() + local out1 = nn.Identity()(in1) + local out2 = nn.Identity()(in1) + + local net = nn.gModule({in1}, {out1, out2}) + local input = {torch.randn(10), {torch.randn(3), torch.randn(4)}} + net:forward(input) + net:backward(input, net.output) + checkGradients(net, input) + + input = {torch.randn(10), {torch.randn(3), torch.randn(4), torch.randn(5)}} + net:forward(input) + net:backward(input, net.output) + checkGradients(net, input) + + input = {torch.randn(10), {torch.randn(3), torch.randn(4)}} + net:forward(input) + local gradInput = net:backward(input, net.output) + tester:eq(#(gradInput[2]), 2, "gradInput[2] size") + checkGradients(net, input) + end + + + function test.test_annotateGraph() + local input = nn.Identity()():annotate( {name = 'Input', description = 'DescA', - graphAttributes = {color = 'red'}}) + 
graphAttributes = {color = 'red'}}) - local hidden_a = nn.Linear(10, 10)(input):annotate( + local hidden_a = nn.Linear(10, 10)(input):annotate( {name = 'Hidden A', description = 'DescB', - graphAttributes = {color = 'blue', fontcolor='green', tooltip = 'I am green'}}) - local hidden_b = nn.Sigmoid()(hidden_a) - local output = nn.Linear(10, 10)(hidden_b) - local net = nn.gModule({input}, {output}) - - tester:assert(hidden_a:label():match('DescB')) - local fg_tmpfile = os.tmpname() - local bg_tmpfile = os.tmpname() - graph.dot(net.fg, 'Test', fg_tmpfile) - graph.dot(net.bg, 'Test BG', bg_tmpfile) - - local function checkDotFile(tmpfile) - local dotcontent = io.open(tmpfile .. '.dot', 'r'):read("*all") - tester:assert( - dotcontent:match('%[color=red.*label=%"Input.*DescA.*%".*%]')) - tester:assert( - dotcontent:match( - '%[.*fontcolor=green.*label=%"Hidden A.*DescB.*%".*%]')) - tester:assert( - dotcontent:match('%[color=blue.*label=%".*DescB.*%".*%]')) - tester:assert( - dotcontent:match( - '%[.*label=%".*DescB.*%".*tooltip=%"I am green%".*%]')) - end - - checkDotFile(fg_tmpfile) - checkDotFile(bg_tmpfile) -end - -function test.test_splitMore() - local nSplits = 2 - local in1 = nn.Identity()() - local out1, out2 = nn.SplitTable(2)(in1):split(nSplits) - - local model = nn.gModule({in1}, {out1, out2}) - local input = torch.randn(10, nSplits + 1) - local ok, result = pcall(model.forward, model, input) - assert(not ok, "the extra input to split should be detected") -end - -function test.test_splitLess() - local nSplits = 3 - local in1 = nn.Identity()() - local out1, out2, out3 = nn.SplitTable(2)(in1):split(nSplits) - - local model = nn.gModule({in1}, {out1, out2, out3}) - local input = torch.randn(10, nSplits - 1) - local ok, result = pcall(model.forward, model, input) - assert(not ok, "the missing input to split should be detected") -end - -tester:add(test):run() + graphAttributes = {color = 'blue', fontcolor='green', tooltip = 'I am green'}}) + local hidden_b = nn.Sigmoid()(hidden_a) + local output = nn.Linear(10, 10)(hidden_b) + local net = nn.gModule({input}, {output}) + + tester:assert(hidden_a:label():match('DescB')) + local fg_tmpfile = os.tmpname() + local bg_tmpfile = os.tmpname() + graph.dot(net.fg, 'Test', fg_tmpfile) + graph.dot(net.bg, 'Test BG', bg_tmpfile) + + local function checkDotFile(tmpfile) + local dotcontent = io.open(tmpfile .. 
'.dot', 'r'):read("*all") + tester:assert( + dotcontent:match('%[color=red.*label=%"Input.*DescA.*%".*%]')) + tester:assert( + dotcontent:match( + '%[.*fontcolor=green.*label=%"Hidden A.*DescB.*%".*%]')) + tester:assert( + dotcontent:match('%[color=blue.*label=%".*DescB.*%".*%]')) + tester:assert( + dotcontent:match( + '%[.*label=%".*DescB.*%".*tooltip=%"I am green%".*%]')) + end + + checkDotFile(fg_tmpfile) + checkDotFile(bg_tmpfile) + end + + function test.test_splitMore() + local nSplits = 2 + local in1 = nn.Identity()() + local out1, out2 = nn.SplitTable(2)(in1):split(nSplits) + + local model = nn.gModule({in1}, {out1, out2}) + local input = torch.randn(10, nSplits + 1) + local ok, result = pcall(model.forward, model, input) + assert(not ok, "the extra input to split should be detected") + end + + function test.test_splitLess() + local nSplits = 3 + local in1 = nn.Identity()() + local out1, out2, out3 = nn.SplitTable(2)(in1):split(nSplits) + + local model = nn.gModule({in1}, {out1, out2, out3}) + local input = torch.randn(10, nSplits - 1) + local ok, result = pcall(model.forward, model, input) + assert(not ok, "the missing input to split should be detected") + end + + tester:add(test):run() diff --git a/test/test_old.lua b/test/test_old.lua index b9e2f6c..1a1e862 100644 --- a/test/test_old.lua +++ b/test/test_old.lua @@ -1,227 +1,227 @@ require 'nngraph' function t1() - local x1 = nn.Linear(20,20)() - local x2 = nn.Linear(10,10)() - local m0=nn.Linear(20,1)(nn.Tanh()(x1)) - local m1=nn.Linear(10,1)(nn.Tanh()(x2)) - local madd=nn.CAddTable()({m0,m1}) - local m2=nn.Sigmoid()(madd) - local m3=nn.Tanh()(madd) - local x = torch.rand(20) - local y = torch.rand(10) - gmod = nn.gModule({x1,x2},{m2,m3}) - gmod.verbose = true - print('forward') - gmod:updateOutput({x,y}) - print('updateGradInput') - gmod:updateGradInput({x,y},{torch.rand(1),torch.rand(1)}) - graph.dot(gmod.fg,'forward') - graph.dot(gmod.bg,'backward') + local x1 = nn.Linear(20,20)() + local x2 = nn.Linear(10,10)() + local m0=nn.Linear(20,1)(nn.Tanh()(x1)) + local m1=nn.Linear(10,1)(nn.Tanh()(x2)) + local madd=nn.CAddTable()({m0,m1}) + local m2=nn.Sigmoid()(madd) + local m3=nn.Tanh()(madd) + local x = torch.rand(20) + local y = torch.rand(10) + gmod = nn.gModule({x1,x2},{m2,m3}) + gmod.verbose = true + print('forward') + gmod:updateOutput({x,y}) + print('updateGradInput') + gmod:updateGradInput({x,y},{torch.rand(1),torch.rand(1)}) + graph.dot(gmod.fg,'forward') + graph.dot(gmod.bg,'backward') end function t2() - print('compare') - local m0 = nn.Linear(5,10)() - local m1 = nn.Linear(10,20)() - local m2 = nn.Linear(30,50)(nn.JoinTable(1){m0,m1}) - gmod = nn.gModule({m0,m1},{m2}) - - local nn0 = nn.Linear(5,10) - local nn1 = nn.Linear(10,20) - local nn2 = nn.Linear(30,50) - local nnmod = nn.Sequential():add(nn.ParallelTable():add(nn0):add(nn1)):add(nn.JoinTable(1)):add(nn2) - - nn0.weight:copy(m0.data.module.weight) - nn0.bias:copy(m0.data.module.bias) - nn1.weight:copy(m1.data.module.weight) - nn1.bias:copy(m1.data.module.bias) - nn2.weight:copy(m2.data.module.weight) - nn2.bias:copy(m2.data.module.bias) - - - for i=1,5 do - local x,y = torch.rand(5),torch.rand(10) - local xx,yy = x:clone(),y:clone() - - gmod:updateOutput({x,y}) - nnmod:updateOutput({xx,yy}) - print('fdiff = ', torch.dist(gmod.output,nnmod.output)) - - local odx = torch.rand(50) - local odxx = odx:clone() - - gmod:updateGradInput({x,y},odx) - nnmod:updateGradInput({xx,yy},odxx) - graph.dot(gmod.fg,tostring(i)) - for i,v in ipairs(gmod.gradInput) do - print('bdiff 
[' ..i.. '] = ', torch.dist(gmod.gradInput[i],nnmod.gradInput[i])) - end - end - - local gms = {m0,m1,m2} - local nms = {nn0,nn1,nn2} - - for i=1,5 do - local x,y = torch.rand(5),torch.rand(10) - local xx,yy = x:clone(),y:clone() - - gmod:updateOutput({x,y}) - nnmod:updateOutput({xx,yy}) - print('fdiff = ', torch.dist(gmod.output,nnmod.output)) - - local odx = torch.rand(50) - local odxx = odx:clone() - - gmod:zeroGradParameters() - nnmod:zeroGradParameters() - - gmod:updateGradInput({x,y},odx) - nnmod:updateGradInput({xx,yy},odxx) - - gmod:accGradParameters({x,y},odx) - nnmod:accGradParameters({xx,yy},odxx) - graph.dot(gmod.fg) - for i,v in ipairs(gms) do - print('accdiff [' ..i.. '] = ', torch.dist(gms[i].data.module.gradWeight,nms[i].gradWeight)) - print('accdiff [' ..i.. '] = ', torch.dist(gms[i].data.module.gradBias,nms[i].gradBias)) - end - end + print('compare') + local m0 = nn.Linear(5,10)() + local m1 = nn.Linear(10,20)() + local m2 = nn.Linear(30,50)(nn.JoinTable(1){m0,m1}) + gmod = nn.gModule({m0,m1},{m2}) + + local nn0 = nn.Linear(5,10) + local nn1 = nn.Linear(10,20) + local nn2 = nn.Linear(30,50) + local nnmod = nn.Sequential():add(nn.ParallelTable():add(nn0):add(nn1)):add(nn.JoinTable(1)):add(nn2) + + nn0.weight:copy(m0.data.module.weight) + nn0.bias:copy(m0.data.module.bias) + nn1.weight:copy(m1.data.module.weight) + nn1.bias:copy(m1.data.module.bias) + nn2.weight:copy(m2.data.module.weight) + nn2.bias:copy(m2.data.module.bias) + + + for i=1,5 do + local x,y = torch.rand(5),torch.rand(10) + local xx,yy = x:clone(),y:clone() + + gmod:updateOutput({x,y}) + nnmod:updateOutput({xx,yy}) + print('fdiff = ', torch.dist(gmod.output,nnmod.output)) + + local odx = torch.rand(50) + local odxx = odx:clone() + + gmod:updateGradInput({x,y},odx) + nnmod:updateGradInput({xx,yy},odxx) + graph.dot(gmod.fg,tostring(i)) + for i,v in ipairs(gmod.gradInput) do + print('bdiff [' ..i.. '] = ', torch.dist(gmod.gradInput[i],nnmod.gradInput[i])) + end + end + + local gms = {m0,m1,m2} + local nms = {nn0,nn1,nn2} + + for i=1,5 do + local x,y = torch.rand(5),torch.rand(10) + local xx,yy = x:clone(),y:clone() + + gmod:updateOutput({x,y}) + nnmod:updateOutput({xx,yy}) + print('fdiff = ', torch.dist(gmod.output,nnmod.output)) + + local odx = torch.rand(50) + local odxx = odx:clone() + + gmod:zeroGradParameters() + nnmod:zeroGradParameters() + + gmod:updateGradInput({x,y},odx) + nnmod:updateGradInput({xx,yy},odxx) + + gmod:accGradParameters({x,y},odx) + nnmod:accGradParameters({xx,yy},odxx) + graph.dot(gmod.fg) + for i,v in ipairs(gms) do + print('accdiff [' ..i.. '] = ', torch.dist(gms[i].data.module.gradWeight,nms[i].gradWeight)) + print('accdiff [' ..i.. '] = ', torch.dist(gms[i].data.module.gradBias,nms[i].gradBias)) + end + end end function t3() - mlp=nn.Sequential(); --Create a network that takes a Tensor as input - mlp:add(nn.SplitTable(2)) - c=nn.ParallelTable() --The two Tensors go through two different Linear - c:add(nn.Linear(10,3)) --Layers in Parallel - c:add(nn.Linear(10,7)) - mlp:add(c) --Outputting a table with 2 elements - p=nn.ParallelTable() --These tables go through two more linear layers - p:add(nn.Linear(3,2)) -- separately. - p:add(nn.Linear(7,1)) - mlp:add(p) - mlp:add(nn.JoinTable(1)) --Finally, the tables are joined together and output. - - pred=mlp:forward(torch.randn(10,2)) - print(pred) - - for i=1,25 do -- A few steps of training such a network..
- x=torch.ones(10,2); - y=torch.Tensor(3); y:copy(x:select(2,1):narrow(1,1,3)) - pred=mlp:forward(x) - - criterion= nn.MSECriterion() - local err=criterion:forward(pred,y) - local gradCriterion = criterion:backward(pred,y); - print(x,y) - mlp:zeroGradParameters(); - mlp:backward(x, gradCriterion); - mlp:updateParameters(0.05); - - print(err) - end + mlp=nn.Sequential(); --Create a network that takes a Tensor as input + mlp:add(nn.SplitTable(2)) + c=nn.ParallelTable() --The two Tensors go through two different Linear + c:add(nn.Linear(10,3)) --Layers in Parallel + c:add(nn.Linear(10,7)) + mlp:add(c) --Outputting a table with 2 elements + p=nn.ParallelTable() --These tables go through two more linear layers + p:add(nn.Linear(3,2)) -- separately. + p:add(nn.Linear(7,1)) + mlp:add(p) + mlp:add(nn.JoinTable(1)) --Finally, the tables are joined together and output. + + pred=mlp:forward(torch.randn(10,2)) + print(pred) + + for i=1,25 do -- A few steps of training such a network.. + x=torch.ones(10,2); + y=torch.Tensor(3); y:copy(x:select(2,1):narrow(1,1,3)) + pred=mlp:forward(x) + + criterion= nn.MSECriterion() + local err=criterion:forward(pred,y) + local gradCriterion = criterion:backward(pred,y); + print(x,y) + mlp:zeroGradParameters(); + mlp:backward(x, gradCriterion); + mlp:updateParameters(0.05); + + print(err) + end end function t4() - local getInput1 = nn.Identity()() - local getInput2 = nn.Identity()() - local mlp = nn.Tanh()(getInput1) - net = nn.gModule({getInput1, getInput2}, {mlp, getInput2}) + local getInput1 = nn.Identity()() + local getInput2 = nn.Identity()() + local mlp = nn.Tanh()(getInput1) + net = nn.gModule({getInput1, getInput2}, {mlp, getInput2}) - local input1 = torch.randn(2) - local input2 = torch.randn(5) + local input1 = torch.randn(2) + local input2 = torch.randn(5) - net:forward({input1, input2}) - local gradInput = net:backward({input1, input2}, - {torch.randn(input1:size()), torch.randn(input2:size())}) - print("gradInput[1]:", gradInput[1]) - print("gradInput[2]:", gradInput[2]) - graph.dot(net.fg) - assert(gradInput[1]:nElement() == input1:nElement(), "size mismatch") + net:forward({input1, input2}) + local gradInput = net:backward({input1, input2}, + {torch.randn(input1:size()), torch.randn(input2:size())}) + print("gradInput[1]:", gradInput[1]) + print("gradInput[2]:", gradInput[2]) + graph.dot(net.fg) + assert(gradInput[1]:nElement() == input1:nElement(), "size mismatch") end function t5() - local m = nn.Sequential() - m:add(nn.SplitTable(1)) - m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30))) - local input = nn.Identity()() - local input1,input2 = m(input):split(2) - local m3 = nn.JoinTable(1)({input1,input2}) - - g = nn.gModule({input},{m3}) - graph.dot(g.fg,'init forward') - - local indata = torch.rand(2,10) - local gdata = torch.rand(50) - g:forward(indata) - g:backward(indata,gdata) - - graph.dot(g.fg,'forward') - graph.dot(g.bg,'backward') + local m = nn.Sequential() + m:add(nn.SplitTable(1)) + m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30))) + local input = nn.Identity()() + local input1,input2 = m(input):split(2) + local m3 = nn.JoinTable(1)({input1,input2}) + + g = nn.gModule({input},{m3}) + graph.dot(g.fg,'init forward') + + local indata = torch.rand(2,10) + local gdata = torch.rand(50) + g:forward(indata) + g:backward(indata,gdata) + + graph.dot(g.fg,'forward') + graph.dot(g.bg,'backward') end function topsort(a) - -- first clone the graph - -- local g = self:clone() - -- local nodes = g.nodes - -- local edges = g.edges
- -- for i,node in ipairs(nodes) do - -- node.children = {} - -- end - - -- reverse the graph - rg,map = a:reverse() - local rmap = {} - for k,v in pairs(map) do - rmap[v] = k - end - - -- work on the sorted graph - sortednodes = {} - rootnodes = rg:roots() - - if #rootnodes == 0 then - print('Graph has cycles') - end - - -- run - for i,root in ipairs(rootnodes) do - root:dfs(function(node) - print(node.id,rmap[node].id) - -- print(rmap[node]) - table.insert(sortednodes,rmap[node]) end) - end - - if #sortednodes ~= #a.nodes then - print('Graph has cycles') - end - return sortednodes,rg,rootnodes -end - -local my={eq = - function(a,b,s) - if a:dist(b) == 0 then - print('ok') - else - print('error : ' .. s) - print('a : ');print(a) - print('b : ');print(b) - end - end} - -function t8() - local in1 = nn.Identity()() - local m = nn.Linear(10,10)(in1) - local out1 = nn.Tanh()(m) - local out2 = nn.Tanh()(m) - local out = nn.CAddTable(){out1, out2} - local mod = nn.gModule({in1}, {out}) - - local dot = nngraph.simple_print.todot(mod.fg, 'bogus') - print (dot) - nngraph.simple_print.dot(mod.fg, 'bogus', 'new') - graph.dot(mod.fg, 'bogus', 'old') -end --- t2() -t8() + -- first clone the graph + -- local g = self:clone() + -- local nodes = g.nodes + -- local edges = g.edges + -- for i,node in ipairs(nodes) do + -- node.children = {} + -- end + + -- reverse the graph + rg,map = a:reverse() + local rmap = {} + for k,v in pairs(map) do + rmap[v] = k + end + + -- work on the sorted graph + sortednodes = {} + rootnodes = rg:roots() + + if #rootnodes == 0 then + print('Graph has cycles') + end + + -- run + for i,root in ipairs(rootnodes) do + root:dfs(function(node) + print(node.id,rmap[node].id) + -- print(rmap[node]) + table.insert(sortednodes,rmap[node]) end) + end + + if #sortednodes ~= #a.nodes then + print('Graph has cycles') + end + return sortednodes,rg,rootnodes + end + + local my={eq = + function(a,b,s) + if a:dist(b) == 0 then + print('ok') + else + print('error : ' .. s) + print('a : ');print(a) + print('b : ');print(b) + end + end} + + function t8() + local in1 = nn.Identity()() + local m = nn.Linear(10,10)(in1) + local out1 = nn.Tanh()(m) + local out2 = nn.Tanh()(m) + local out = nn.CAddTable(){out1, out2} + local mod = nn.gModule({in1}, {out}) + + local dot = nngraph.simple_print.todot(mod.fg, 'bogus') + print (dot) + nngraph.simple_print.dot(mod.fg, 'bogus', 'new') + graph.dot(mod.fg, 'bogus', 'old') + end + -- t2() + t8() diff --git a/utils.lua b/utils.lua --- a/utils.lua +++ b/utils.lua @@ -1,20 +1,18 @@ local utils = {} function utils.istensor(x) - if torch.typename(x) and torch.typename(x):find('Tensor') then - return true - end - return false + if torch.typename(x) and torch.typename(x):find('Tensor') then + return true + end + return false end function utils.istorchclass(x) - return type(x) == 'table' and torch.typename(x) + return type(x) == 'table' and torch.typename(x) end function utils.istable(x) - return type(x) == 'table' and not torch.typename(x) + return type(x) == 'table' and not torch.typename(x) end return utils - -