
github.com/torch/nngraph.git
author    Clement Farabet <cfarabet@twitter.com>  2015-09-04 23:31:01 +0300
committer Clement Farabet <cfarabet@twitter.com>  2015-09-04 23:31:01 +0300
commit    ba7f3ec6ffe7e60e5e07b2886178aec54e6305e5
tree      23aaf8f288fdce23b7361a1e7252254a690a6fb9
parent    72f74d39257a344a2c3237c83f8a828b916817e4
Whitespace cleanup.
-rw-r--r--  gmodule.lua                         624
-rw-r--r--  graphinspecting.lua                 212
-rw-r--r--  init.lua                             44
-rw-r--r--  nesting.lua                          83
-rw-r--r--  node.lua                            235
-rw-r--r--  simple_print.lua                    197
-rw-r--r--  test/speed.lua                      169
-rw-r--r--  test/test_ModuleFromCriterion.lua    68
-rw-r--r--  test/test_nngraph.lua               696
-rw-r--r--  test/test_old.lua                   412
-rw-r--r--  utils.lua                            14
11 files changed, 1370 insertions, 1384 deletions
diff --git a/gmodule.lua b/gmodule.lua
index 6489f61..79f3033 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -6,21 +6,21 @@ local istable = utils.istable
local istorchclass = utils.istorchclass
local function getTotalGradOutput(node)
- local gradOutput = node.data.gradOutput
- assert(istable(gradOutput), "expecting gradients to sum")
- if #gradOutput > 1 then
- node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1])
- local gobuff = node.data.gradOutputBuffer
- nesting.resizeNestedAs(gobuff, gradOutput[1])
- nesting.fillNested(gobuff, 0)
- for i=1,#gradOutput do
- nesting.addNestedTo(gobuff, gradOutput[i])
- end
- gradOutput = gobuff
- else
- gradOutput = gradOutput[1]
- end
- return gradOutput
+ local gradOutput = node.data.gradOutput
+ assert(istable(gradOutput), "expecting gradients to sum")
+ if #gradOutput > 1 then
+ node.data.gradOutputBuffer = node.data.gradOutputBuffer or nesting.cloneNested(gradOutput[1])
+ local gobuff = node.data.gradOutputBuffer
+ nesting.resizeNestedAs(gobuff, gradOutput[1])
+ nesting.fillNested(gobuff, 0)
+ for i=1,#gradOutput do
+ nesting.addNestedTo(gobuff, gradOutput[i])
+ end
+ gradOutput = gobuff
+ else
+ gradOutput = gradOutput[1]
+ end
+ return gradOutput
end
-- The gModule allows to have a general non-cyclic graph of modules.
@@ -43,84 +43,84 @@ end
local gModule, parent = torch.class('nn.gModule','nn.Module')
function gModule:__init(inputs,outputs)
- parent.__init(self)
- -- the graph is defined backwards, we have the output modules as input here
- -- we will define a dummy output node that connects all output modules
- -- into itself. This will be the output for the forward graph and
- -- input point for the backward graph
- local outnode = nngraph.Node({input={}})
- for i,n in ipairs(outputs) do
- if torch.typename(n) ~= 'nngraph.Node' then
- error(string.format('what is this in the outputs[%s]? %s',
- i, tostring(n)))
- end
- outnode:add(n,true)
- end
- for i,n in ipairs(inputs) do
- if torch.typename(n) ~= 'nngraph.Node' then
- error(string.format('what is this in the inputs[%s]? %s',
- i, tostring(n)))
- end
- end
- -- We also add a dummy input node.
- -- The input node will be split to feed the passed input nodes.
- local innode = nngraph.Node({input={}})
- assert(#inputs > 0, "no inputs are not supported")
- if #inputs == 1 then
- inputs[1]:add(innode,true)
- else
- local splits = {innode:split(#inputs)}
- for i = 1, #inputs do
- assert(#inputs[i].children == 0, "an input should have no inputs")
- end
- for i = 1, #inputs do
- inputs[i]:add(splits[i],true)
- end
- end
-
- -- the backward graph (bg) is for gradients
- -- the forward graph (fg) is for function evaluation
- self.bg = outnode:graph()
- self.fg = self.bg:reverse()
-
- -- the complete graph is constructed
- -- now regenerate the graphs with the additional nodes
- assert(#self.fg:roots() == 1, "expecting only one start")
- self.innode = self.fg:roots()[1]
- assert(self.innode.data == innode.data, "expecting the forward innode")
- self.outnode = outnode
- self.verbose = false
- self.nInputs = #inputs
-
- -- computation on the graph is done through topsort of forward and backward graphs
- self.forwardnodes = self.fg:topsort()
- self.backwardnodes = self.bg:topsort()
- -- Checking for unused inputs or unused split() outputs.
- for i,forwardNode in ipairs(self.forwardnodes) do
- if forwardNode.data.nSplitOutputs and forwardNode.data.nSplitOutputs ~= #forwardNode.children then
- local nUnused = forwardNode.data.nSplitOutputs - #forwardNode.children
- error(string.format("%s of split(%s) outputs are unused", nUnused,
- forwardNode.data.nSplitOutputs))
- end
- end
- -- Adding data.forwardNodeId for nicer node:label() output.
- for i,forwardNode in ipairs(self.forwardnodes) do
- forwardNode.data.forwardNodeId = forwardNode.id
- end
-
- self.output = nil
- self.gradInput = nil
- if #self.outnode.children > 1 then
- self.output = self.outnode.data.input
- end
+ parent.__init(self)
+ -- the graph is defined backwards, we have the output modules as input here
+ -- we will define a dummy output node that connects all output modules
+ -- into itself. This will be the output for the forward graph and
+ -- input point for the backward graph
+ local outnode = nngraph.Node({input={}})
+ for i,n in ipairs(outputs) do
+ if torch.typename(n) ~= 'nngraph.Node' then
+ error(string.format('what is this in the outputs[%s]? %s',
+ i, tostring(n)))
+ end
+ outnode:add(n,true)
+ end
+ for i,n in ipairs(inputs) do
+ if torch.typename(n) ~= 'nngraph.Node' then
+ error(string.format('what is this in the inputs[%s]? %s',
+ i, tostring(n)))
+ end
+ end
+ -- We also add a dummy input node.
+ -- The input node will be split to feed the passed input nodes.
+ local innode = nngraph.Node({input={}})
+ assert(#inputs > 0, "no inputs are not supported")
+ if #inputs == 1 then
+ inputs[1]:add(innode,true)
+ else
+ local splits = {innode:split(#inputs)}
+ for i = 1, #inputs do
+ assert(#inputs[i].children == 0, "an input should have no inputs")
+ end
+ for i = 1, #inputs do
+ inputs[i]:add(splits[i],true)
+ end
+ end
+
+ -- the backward graph (bg) is for gradients
+ -- the forward graph (fg) is for function evaluation
+ self.bg = outnode:graph()
+ self.fg = self.bg:reverse()
+
+ -- the complete graph is constructed
+ -- now regenerate the graphs with the additional nodes
+ assert(#self.fg:roots() == 1, "expecting only one start")
+ self.innode = self.fg:roots()[1]
+ assert(self.innode.data == innode.data, "expecting the forward innode")
+ self.outnode = outnode
+ self.verbose = false
+ self.nInputs = #inputs
+
+ -- computation on the graph is done through topsort of forward and backward graphs
+ self.forwardnodes = self.fg:topsort()
+ self.backwardnodes = self.bg:topsort()
+ -- Checking for unused inputs or unused split() outputs.
+ for i,forwardNode in ipairs(self.forwardnodes) do
+ if forwardNode.data.nSplitOutputs and forwardNode.data.nSplitOutputs ~= #forwardNode.children then
+ local nUnused = forwardNode.data.nSplitOutputs - #forwardNode.children
+ error(string.format("%s of split(%s) outputs are unused", nUnused,
+ forwardNode.data.nSplitOutputs))
+ end
+ end
+ -- Adding data.forwardNodeId for nicer node:label() output.
+ for i,forwardNode in ipairs(self.forwardnodes) do
+ forwardNode.data.forwardNodeId = forwardNode.id
+ end
+
+ self.output = nil
+ self.gradInput = nil
+ if #self.outnode.children > 1 then
+ self.output = self.outnode.data.input
+ end
end
function gModule:apply(func)
- for i,node in ipairs(self.forwardnodes) do
- if node.data.module then
- func(node.data.module)
- end
- end
+ for i,node in ipairs(self.forwardnodes) do
+ if node.data.module then
+ func(node.data.module)
+ end
+ end
end
function gModule:map(gm, func)
@@ -149,266 +149,266 @@ end
function gModule:share(gm, ...)
local args = {...}
self:map(gm,
- function(subnet1, subnet2)
- subnet1:share(subnet2, unpack(args))
+ function(subnet1, subnet2)
+ subnet1:share(subnet2, unpack(args))
end)
return self
end
function gModule:training()
- self:apply(function(module) module:training() end)
+ self:apply(function(module) module:training() end)
end
function gModule:evaluate()
- self:apply(function(module) module:evaluate() end)
+ self:apply(function(module) module:evaluate() end)
end
--[[ Recursively applies type(type_str) to any tensors in the argument. If the
argument is a tensor, type(type_str) is applied; if the argument is an array,
this function recurses into it. ]]
local function recursiveType(param, type_str)
- if torch.type(param) == 'table' then
- for i = 1, #param do
- param[i] = recursiveType(param[i], type_str)
- end
- elseif torch.typename(param) and
- torch.typename(param):find('torch%..+Tensor') then
- param = param:type(type_str)
- end
- return param
+ if torch.type(param) == 'table' then
+ for i = 1, #param do
+ param[i] = recursiveType(param[i], type_str)
+ end
+ elseif torch.typename(param) and
+ torch.typename(param):find('torch%..+Tensor') then
+ param = param:type(type_str)
+ end
+ return param
end
function gModule:type(type)
- local function applyTypeToTable(table)
- for key, value in pairs(table) do
- table[key] = recursiveType(table[key], type)
- end
- end
-
- -- Convert any stored data in self, and in the in and out nodes
- applyTypeToTable(self)
- if self.innode then applyTypeToTable(self.innode.data) end
- if self.outnode then applyTypeToTable(self.outnode.data) end
-
- -- Loop through modules and convert data
- self:apply(function(module) module:type(type) end)
-
- for i,node in ipairs(self.backwardnodes) do
- if node.data.gradOutputBuffer ~= nil then
- node.data.gradOutputBuffer = node.data.gradOutputBuffer:type(type)
- end
- end
-
- return self
+ local function applyTypeToTable(table)
+ for key, value in pairs(table) do
+ table[key] = recursiveType(table[key], type)
+ end
+ end
+
+ -- Convert any stored data in self, and in the in and out nodes
+ applyTypeToTable(self)
+ if self.innode then applyTypeToTable(self.innode.data) end
+ if self.outnode then applyTypeToTable(self.outnode.data) end
+
+ -- Loop through modules and convert data
+ self:apply(function(module) module:type(type) end)
+
+ for i,node in ipairs(self.backwardnodes) do
+ if node.data.gradOutputBuffer ~= nil then
+ node.data.gradOutputBuffer = node.data.gradOutputBuffer:type(type)
+ end
+ end
+
+ return self
end
function gModule:zeroGradParameters()
- self:apply(function(module) module:zeroGradParameters() end)
+ self:apply(function(module) module:zeroGradParameters() end)
end
function gModule:updateOutput(input)
- return self:runForwardFunction('updateOutput',input)
+ return self:runForwardFunction('updateOutput',input)
end
function gModule:runForwardFunction(func,input)
- if type(func) == "string" then
- local func_name = func
- func = function(module,input) return module[func_name](module,input) end
- end
- -- For backward compatibility, we allow self.nInputs to be missing.
- local nInputs = self.nInputs or #self.innode.children
- -- We see the input as a list of inputs.
- if nInputs <= 1 then
- input={input}
- elseif type(input) ~= "table" then
- error(string.format("expecting %s inputs", nInputs))
- end
- local function neteval(node)
- local function propagate(node,x)
- for i,child in ipairs(node.children) do
- child.data.input = child.data.input or {}
- local mapindex = child.data.mapindex[node.data]
- assert(not child.data.input[mapindex], "each input should have one source")
- child.data.input[mapindex] = x
- end
- end
- if node.data.selectindex then
- assert(not node.data.module, "the selectindex-handling nodes should have no module")
- local input = node.data.input
- assert(#input == 1, "only the splitted node should be the input")
- assert(istable(input[1]), "the input for a split should be a table")
- input = input[1][node.data.selectindex]
- propagate(node,input)
- else
- local input = node.data.input
- if #input == 1 then
- input = input[1]
- end
- -- forward through this node
- -- If no module is present, the node behaves like nn.Identity.
- local output
- if not node.data.module then
- output = input
- else
- output = func(node.data.module,input)
- end
- if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #output then
- error(string.format("split(%s) cannot split %s outputs",
- node.data.nSplitOutputs,
- #output))
- end
- -- propagate the output to children
- propagate(node,output)
- end
- if self.verbose then
- print(' V : ' .. node:label())
- end
- end
-
- local innode = self.innode
- if #input ~= nInputs then
- error(string.format('Got %s inputs instead of %s', #input, nInputs))
- end
- -- first clear the input states
- for _,node in ipairs(self.forwardnodes) do
- local input = node.data.input
- while input and #input>0 do
- table.remove(input)
- end
- end
- -- Set the starting input.
- -- We do copy instead of modifying the passed input.
- innode.data.input = innode.data.input or {}
- for i, item in ipairs(input) do
- innode.data.input[i] = item
- end
-
- -- run the forward pass
- for i,node in ipairs(self.forwardnodes) do
- neteval(node)
- end
-
- self.output = self.outnode.data.input
- if #self.outnode.children == 1 then
- self.output = self.output[1]
- end
- return self.output
+ if type(func) == "string" then
+ local func_name = func
+ func = function(module,input) return module[func_name](module,input) end
+ end
+ -- For backward compatibility, we allow self.nInputs to be missing.
+ local nInputs = self.nInputs or #self.innode.children
+ -- We see the input as a list of inputs.
+ if nInputs <= 1 then
+ input={input}
+ elseif type(input) ~= "table" then
+ error(string.format("expecting %s inputs", nInputs))
+ end
+ local function neteval(node)
+ local function propagate(node,x)
+ for i,child in ipairs(node.children) do
+ child.data.input = child.data.input or {}
+ local mapindex = child.data.mapindex[node.data]
+ assert(not child.data.input[mapindex], "each input should have one source")
+ child.data.input[mapindex] = x
+ end
+ end
+ if node.data.selectindex then
+ assert(not node.data.module, "the selectindex-handling nodes should have no module")
+ local input = node.data.input
+ assert(#input == 1, "only the splitted node should be the input")
+ assert(istable(input[1]), "the input for a split should be a table")
+ input = input[1][node.data.selectindex]
+ propagate(node,input)
+ else
+ local input = node.data.input
+ if #input == 1 then
+ input = input[1]
+ end
+ -- forward through this node
+ -- If no module is present, the node behaves like nn.Identity.
+ local output
+ if not node.data.module then
+ output = input
+ else
+ output = func(node.data.module,input)
+ end
+ if node.data.nSplitOutputs and node.data.nSplitOutputs ~= #output then
+ error(string.format("split(%s) cannot split %s outputs",
+ node.data.nSplitOutputs,
+ #output))
+ end
+ -- propagate the output to children
+ propagate(node,output)
+ end
+ if self.verbose then
+ print(' V : ' .. node:label())
+ end
+ end
+
+ local innode = self.innode
+ if #input ~= nInputs then
+ error(string.format('Got %s inputs instead of %s', #input, nInputs))
+ end
+ -- first clear the input states
+ for _,node in ipairs(self.forwardnodes) do
+ local input = node.data.input
+ while input and #input>0 do
+ table.remove(input)
+ end
+ end
+ -- Set the starting input.
+ -- We do copy instead of modifying the passed input.
+ innode.data.input = innode.data.input or {}
+ for i, item in ipairs(input) do
+ innode.data.input[i] = item
+ end
+
+ -- run the forward pass
+ for i,node in ipairs(self.forwardnodes) do
+ neteval(node)
+ end
+
+ self.output = self.outnode.data.input
+ if #self.outnode.children == 1 then
+ self.output = self.output[1]
+ end
+ return self.output
end
function gModule:updateGradInput(input,gradOutput)
- local function neteval(node)
- if node.data.selectindex then
- assert(not node.data.module, "the selectindex-handling nodes should have no module")
- assert(#node.children == 1, "only the splitted node should be the input")
- local child = node.children[1]
- local go = getTotalGradOutput(node)
- child.data.gradOutput = child.data.gradOutput or {}
- assert(#child.data.gradOutput <= 1, "the splitted node should be used only once")
- -- The data.gradOutput holds the to-be-summed gradients.
- child.data.gradOutput[1] = child.data.gradOutput[1] or {}
- assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet")
- child.data.gradOutput[1][node.data.selectindex] = go
- else
- local gradOutput = getTotalGradOutput(node)
- -- updateGradInput through this node
- -- If no module is present, the node behaves like nn.Identity.
- local gradInput
- if not node.data.module then
- gradInput = gradOutput
- else
- local input = node.data.input
- if #input == 1 then
- input = input[1]
- end
- local module = node.data.module
- gradInput = module:updateGradInput(input,gradOutput)
- end
- -- propagate the output to children
- for i,child in ipairs(node.children) do
- child.data.gradOutput = child.data.gradOutput or {}
- local mapindex = node.data.mapindex[child.data]
- local gi
- if #node.children == 1 then
- gi = gradInput
- else
- gi = gradInput[mapindex]
- end
- table.insert(child.data.gradOutput,gi)
- end
- end
- if self.verbose then
- print(' V : ' .. node:label())
- end
- end
- local outnode = self.outnode
- if #outnode.children > 1 and #gradOutput ~= #outnode.children then
- error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
- end
- for _,node in ipairs(self.backwardnodes) do
- local gradOutput = node.data.gradOutput
- while gradOutput and #gradOutput >0 do
- table.remove(gradOutput)
- end
- end
- -- Set the starting gradOutput.
- outnode.data.gradOutput = outnode.data.gradOutput or {}
- outnode.data.gradOutput[1] = gradOutput
-
- for i,node in ipairs(self.backwardnodes) do
- neteval(node)
- end
-
- assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once")
- self.gradInput = self.innode.data.gradOutput[1]
- return self.gradInput
+ local function neteval(node)
+ if node.data.selectindex then
+ assert(not node.data.module, "the selectindex-handling nodes should have no module")
+ assert(#node.children == 1, "only the splitted node should be the input")
+ local child = node.children[1]
+ local go = getTotalGradOutput(node)
+ child.data.gradOutput = child.data.gradOutput or {}
+ assert(#child.data.gradOutput <= 1, "the splitted node should be used only once")
+ -- The data.gradOutput holds the to-be-summed gradients.
+ child.data.gradOutput[1] = child.data.gradOutput[1] or {}
+ assert(not child.data.gradOutput[1][node.data.selectindex], "no gradOutput should be assigned yet")
+ child.data.gradOutput[1][node.data.selectindex] = go
+ else
+ local gradOutput = getTotalGradOutput(node)
+ -- updateGradInput through this node
+ -- If no module is present, the node behaves like nn.Identity.
+ local gradInput
+ if not node.data.module then
+ gradInput = gradOutput
+ else
+ local input = node.data.input
+ if #input == 1 then
+ input = input[1]
+ end
+ local module = node.data.module
+ gradInput = module:updateGradInput(input,gradOutput)
+ end
+ -- propagate the output to children
+ for i,child in ipairs(node.children) do
+ child.data.gradOutput = child.data.gradOutput or {}
+ local mapindex = node.data.mapindex[child.data]
+ local gi
+ if #node.children == 1 then
+ gi = gradInput
+ else
+ gi = gradInput[mapindex]
+ end
+ table.insert(child.data.gradOutput,gi)
+ end
+ end
+ if self.verbose then
+ print(' V : ' .. node:label())
+ end
+ end
+ local outnode = self.outnode
+ if #outnode.children > 1 and #gradOutput ~= #outnode.children then
+ error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
+ end
+ for _,node in ipairs(self.backwardnodes) do
+ local gradOutput = node.data.gradOutput
+ while gradOutput and #gradOutput >0 do
+ table.remove(gradOutput)
+ end
+ end
+ -- Set the starting gradOutput.
+ outnode.data.gradOutput = outnode.data.gradOutput or {}
+ outnode.data.gradOutput[1] = gradOutput
+
+ for i,node in ipairs(self.backwardnodes) do
+ neteval(node)
+ end
+
+ assert(#self.innode.data.gradOutput == 1, "expecting the innode to be used only once")
+ self.gradInput = self.innode.data.gradOutput[1]
+ return self.gradInput
end
function gModule:accGradParameters(input,gradOutput,lr)
- local function neteval(node)
- if node.data.module then
- local module = node.data.module
- local gradOutput = node.data.gradOutput[1]
- if #node.data.gradOutput > 1 then
- gradOutput = node.data.gradOutputBuffer
- end
- local input = node.data.input
- if #input == 1 then
- input = input[1]
- end
- -- accGradParameters through this node
- module:accGradParameters(input,gradOutput,lr)
- end
- if self.verbose then
- print(' V : ' .. node:label())
- end
- end
- local outnode = self.outnode
- if #outnode.children > 1 and #gradOutput ~= #outnode.children then
- error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
- end
- for i,node in ipairs(self.backwardnodes) do
- neteval(node)
- end
+ local function neteval(node)
+ if node.data.module then
+ local module = node.data.module
+ local gradOutput = node.data.gradOutput[1]
+ if #node.data.gradOutput > 1 then
+ gradOutput = node.data.gradOutputBuffer
+ end
+ local input = node.data.input
+ if #input == 1 then
+ input = input[1]
+ end
+ -- accGradParameters through this node
+ module:accGradParameters(input,gradOutput,lr)
+ end
+ if self.verbose then
+ print(' V : ' .. node:label())
+ end
+ end
+ local outnode = self.outnode
+ if #outnode.children > 1 and #gradOutput ~= #outnode.children then
+ error(string.format('Got %s gradOutputs instead of %s', #gradOutput, #outnode.children))
+ end
+ for i,node in ipairs(self.backwardnodes) do
+ neteval(node)
+ end
end
function gModule:parameters()
- local p,gp = {},{}
- for _,node in ipairs(self.forwardnodes) do
- if node.data.module then
- local mp,mgp = node.data.module:parameters()
- if mp and mgp then
- for i = 1,#mp do
- table.insert(p,mp[i])
- table.insert(gp,mgp[i])
- end
- end
- end
- end
- return p,gp
+ local p,gp = {},{}
+ for _,node in ipairs(self.forwardnodes) do
+ if node.data.module then
+ local mp,mgp = node.data.module:parameters()
+ if mp and mgp then
+ for i = 1,#mp do
+ table.insert(p,mp[i])
+ table.insert(gp,mgp[i])
+ end
+ end
+ end
+ end
+ return p,gp
end
function gModule:__tostring__()
- return self.name or torch.type(self)
+ return self.name or torch.type(self)
end
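
A quick orientation for the file above (the hunks are re-indentation only):
gModule:__init builds topologically sorted forward and backward node lists
once, and updateOutput/updateGradInput just walk those lists. A minimal usage
sketch of the API, assuming the stock nn and nngraph packages (not part of
this commit):

   require 'nngraph'

   local x = nn.Identity()()                   -- graph input node
   local h = nn.Tanh()(nn.Linear(10, 20)(x))   -- Module:__call__ builds nodes
   local y = nn.Linear(20, 1)(h)
   local net = nn.gModule({x}, {y})            -- graph from inputs to outputs

   local out = net:forward(torch.randn(10))    -- evaluates the topsorted nodes
   local gi  = net:backward(torch.randn(10), torch.randn(1))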
diff --git a/graphinspecting.lua b/graphinspecting.lua
index 7c8d462..0ccb168 100644
--- a/graphinspecting.lua
+++ b/graphinspecting.lua
@@ -2,98 +2,98 @@
-- The findCurrentNode() depends on the names of the
-- local variables in the nngraph.gModule source code.
local function findCurrentNode()
- for level = 2, math.huge do
- local info = debug.getinfo(level, "n")
- if info == nil then
- return nil
- end
-
- local funcName = info.name
- if funcName == "neteval" then
- local varName, node = debug.getlocal(level, 1)
- if varName == "node" then
- return node
- end
- end
- end
+ for level = 2, math.huge do
+ local info = debug.getinfo(level, "n")
+ if info == nil then
+ return nil
+ end
+
+ local funcName = info.name
+ if funcName == "neteval" then
+ local varName, node = debug.getlocal(level, 1)
+ if varName == "node" then
+ return node
+ end
+ end
+ end
end
-- Runs the func and calls onError(failedNode, ...) on an error.
-- The stack trace is inspected to find the failedNode.
local function runChecked(func, onError, ...)
- -- The current node needs to be searched-for, before unrolling the stack.
- local failedNode
- local function errorHandler(message)
- -- The stack traceback is added only if not already present.
- if not string.find(message, 'stack traceback:\n', 1, true) then
- message = debug.traceback(message, 2)
- end
- failedNode = findCurrentNode()
- return message
- end
-
- local ok, result = xpcall(func, errorHandler)
- if ok then
- return result
- end
-
- onError(failedNode, ...)
- -- Passing the level 0 avoids adding an additional error position info
- -- to the message.
- error(result, 0)
+ -- The current node needs to be searched-for, before unrolling the stack.
+ local failedNode
+ local function errorHandler(message)
+ -- The stack traceback is added only if not already present.
+ if not string.find(message, 'stack traceback:\n', 1, true) then
+ message = debug.traceback(message, 2)
+ end
+ failedNode = findCurrentNode()
+ return message
+ end
+
+ local ok, result = xpcall(func, errorHandler)
+ if ok then
+ return result
+ end
+
+ onError(failedNode, ...)
+ -- Passing the level 0 avoids adding an additional error position info
+ -- to the message.
+ error(result, 0)
end
local function customToDot(graph, title, failedNode)
- local str = graph:todot(title)
- if not failedNode then
- return str
- end
-
- local failedNodeId = nil
- for i, node in ipairs(graph.nodes) do
- if node.data == failedNode.data then
- failedNodeId = node.id
- break
- end
- end
-
- if failedNodeId ~= nil then
- -- The closing '}' is removed.
- -- And red fillcolor is specified for the failedNode.
- str = string.gsub(str, '}%s*$', '')
- str = str .. string.format('n%s[style=filled, fillcolor=red];\n}',
- failedNodeId)
- end
- return str
+ local str = graph:todot(title)
+ if not failedNode then
+ return str
+ end
+
+ local failedNodeId = nil
+ for i, node in ipairs(graph.nodes) do
+ if node.data == failedNode.data then
+ failedNodeId = node.id
+ break
+ end
+ end
+
+ if failedNodeId ~= nil then
+ -- The closing '}' is removed.
+ -- And red fillcolor is specified for the failedNode.
+ str = string.gsub(str, '}%s*$', '')
+ str = str .. string.format('n%s[style=filled, fillcolor=red];\n}',
+ failedNodeId)
+ end
+ return str
end
local function saveSvg(svgPathPrefix, dotStr)
- io.stderr:write(string.format("saving %s.svg\n", svgPathPrefix))
- local dotPath = svgPathPrefix .. '.dot'
- local dotFile = io.open(dotPath, 'w')
- dotFile:write(dotStr)
- dotFile:close()
-
- local svgPath = svgPathPrefix .. '.svg'
- local cmd = string.format('dot -Tsvg -o %s %s', svgPath, dotPath)
- os.execute(cmd)
+ io.stderr:write(string.format("saving %s.svg\n", svgPathPrefix))
+ local dotPath = svgPathPrefix .. '.dot'
+ local dotFile = io.open(dotPath, 'w')
+ dotFile:write(dotStr)
+ dotFile:close()
+
+ local svgPath = svgPathPrefix .. '.svg'
+ local cmd = string.format('dot -Tsvg -o %s %s', svgPath, dotPath)
+ os.execute(cmd)
end
local function onError(failedNode, gmodule)
- local nInputs = gmodule.nInputs or #gmodule.innode.children
- local svgPathPrefix = gmodule.name or string.format(
- 'nngraph_%sin_%sout', nInputs, #gmodule.outnode.children)
- if paths.filep(svgPathPrefix .. '.svg') then
- svgPathPrefix = svgPathPrefix .. '_' .. paths.basename(os.tmpname())
- end
- local dotStr = customToDot(gmodule.fg, svgPathPrefix, failedNode)
- saveSvg(svgPathPrefix, dotStr)
+ local nInputs = gmodule.nInputs or #gmodule.innode.children
+ local svgPathPrefix = gmodule.name or string.format(
+ 'nngraph_%sin_%sout', nInputs, #gmodule.outnode.children)
+ if paths.filep(svgPathPrefix .. '.svg') then
+ svgPathPrefix = svgPathPrefix .. '_' .. paths.basename(os.tmpname())
+ end
+ local dotStr = customToDot(gmodule.fg, svgPathPrefix, failedNode)
+ saveSvg(svgPathPrefix, dotStr)
end
local origFuncs = {
- runForwardFunction = nn.gModule.runForwardFunction,
- updateGradInput = nn.gModule.updateGradInput,
- accGradParameters = nn.gModule.accGradParameters,
+ runForwardFunction = nn.gModule.runForwardFunction,
+ updateGradInput = nn.gModule.updateGradInput,
+ accGradParameters = nn.gModule.accGradParameters,
}
-- When debug is enabled,
@@ -101,42 +101,42 @@ local origFuncs = {
-- if an exception occurs in a graph execution.
-- The problematic node will be marked by red color.
function nngraph.setDebug(enable)
- if not enable then
- -- When debug is disabled,
- -- the origFuncs are restored on nn.gModule.
- for funcName, origFunc in pairs(origFuncs) do
- nn.gModule[funcName] = origFunc
- end
- return
- end
-
- for funcName, origFunc in pairs(origFuncs) do
- nn.gModule[funcName] = function(...)
- local args = {...}
- local gmodule = args[1]
- return runChecked(function()
- return origFunc(unpack(args))
- end, onError, gmodule)
- end
- end
+ if not enable then
+ -- When debug is disabled,
+ -- the origFuncs are restored on nn.gModule.
+ for funcName, origFunc in pairs(origFuncs) do
+ nn.gModule[funcName] = origFunc
+ end
+ return
+ end
+
+ for funcName, origFunc in pairs(origFuncs) do
+ nn.gModule[funcName] = function(...)
+ local args = {...}
+ local gmodule = args[1]
+ return runChecked(function()
+ return origFunc(unpack(args))
+ end, onError, gmodule)
+ end
+ end
end
-- Sets node.data.annotations.name for the found nodes.
-- The local variables at the given stack level are inspected.
-- The default stack level is 1 (the function that called annotateNodes()).
function nngraph.annotateNodes(stackLevel)
- stackLevel = stackLevel or 1
- for index = 1, math.huge do
- local varName, varValue = debug.getlocal(stackLevel + 1, index)
- if not varName then
- break
- end
- if torch.typename(varValue) == "nngraph.Node" then
- -- An explicit name is preserved.
- if not varValue.data.annotations.name then
- varValue:annotate({name = varName})
- end
- end
- end
+ stackLevel = stackLevel or 1
+ for index = 1, math.huge do
+ local varName, varValue = debug.getlocal(stackLevel + 1, index)
+ if not varName then
+ break
+ end
+ if torch.typename(varValue) == "nngraph.Node" then
+ -- An explicit name is preserved.
+ if not varValue.data.annotations.name then
+ varValue:annotate({name = varName})
+ end
+ end
+ end
end
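
Usage sketch for the debugging helpers above (not part of this commit): with
setDebug enabled, a failure inside forward or backward writes an .svg of the
forward graph with the failing node filled red, then re-raises the error.

   nngraph.setDebug(true)
   local x = nn.Identity()()
   local y = nn.Linear(5, 3)(x)
   local net = nn.gModule({x}, {y})
   net.name = 'mynet'                 -- onError uses this as the svg prefix
   nngraph.annotateNodes()            -- names the nodes after the locals x, y
   -- feeding a 4-element input into Linear(5, 3) fails and saves mynet.svg
   local ok, err = pcall(net.forward, net, torch.randn(4))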
diff --git a/init.lua b/init.lua
index 4d340f3..ad154bb 100644
--- a/init.lua
+++ b/init.lua
@@ -1,4 +1,3 @@
-
require 'nn'
require 'graph'
@@ -18,33 +17,32 @@ local istorchclass = utils.istorchclass
-- simpler todot functions
nngraph.simple_print = paths.dofile('simple_print.lua')
-
-- Modify the __call function to hack into nn.Module
local Module = torch.getmetatable('nn.Module')
function Module:__call__(...)
- local nArgs = select("#", ...)
- assert(nArgs <= 1, 'Use {input1, input2} to pass multiple inputs.')
-
- local input = ...
- if nArgs == 1 and input == nil then
- error('what is this in the input? nil')
- end
- if not istable(input) then
- input = {input}
- end
- local mnode = nngraph.Node({module=self})
-
- for i,dnode in ipairs(input) do
- if torch.typename(dnode) ~= 'nngraph.Node' then
- error('what is this in the input? ' .. tostring(dnode))
- end
- mnode:add(dnode,true)
- end
-
- return mnode
+ local nArgs = select("#", ...)
+ assert(nArgs <= 1, 'Use {input1, input2} to pass multiple inputs.')
+
+ local input = ...
+ if nArgs == 1 and input == nil then
+ error('what is this in the input? nil')
+ end
+ if not istable(input) then
+ input = {input}
+ end
+ local mnode = nngraph.Node({module=self})
+
+ for i,dnode in ipairs(input) do
+ if torch.typename(dnode) ~= 'nngraph.Node' then
+ error('what is this in the input? ' .. tostring(dnode))
+ end
+ mnode:add(dnode,true)
+ end
+
+ return mnode
end
local Criterion = torch.getmetatable('nn.Criterion')
function Criterion:__call__(...)
- return nn.ModuleFromCriterion(self)(...)
+ return nn.ModuleFromCriterion(self)(...)
end
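
Usage sketch (not part of this commit): with these __call__ overloads, calling
a module or a criterion on nodes wires it into the graph, mirroring
test_ModuleFromCriterion.lua below.

   local prediction = nn.Identity()()
   local target = nn.Identity()()
   local loss = nn.MSECriterion()({prediction, target})  -- Criterion:__call__
   local net = nn.gModule({prediction, target}, {loss})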
diff --git a/nesting.lua b/nesting.lua
index ab63e70..0a61b36 100644
--- a/nesting.lua
+++ b/nesting.lua
@@ -4,66 +4,65 @@ local nesting = {}
local utils = paths.dofile('utils.lua')
local istensor = utils.istensor
-
-- Creates a clone of a tensor or of a table with tensors.
function nesting.cloneNested(obj)
- if istensor(obj) then
- return obj:clone()
- end
+ if istensor(obj) then
+ return obj:clone()
+ end
- local result = {}
- for key, child in pairs(obj) do
- result[key] = nesting.cloneNested(child)
- end
- return result
+ local result = {}
+ for key, child in pairs(obj) do
+ result[key] = nesting.cloneNested(child)
+ end
+ return result
end
-- Fills the obj with the given value.
-- The obj can be a tensor or a table with tensors.
function nesting.fillNested(obj, value)
- if istensor(obj) then
- obj:fill(value)
- else
- for key, child in pairs(obj) do
- nesting.fillNested(child, value)
- end
- end
+ if istensor(obj) then
+ obj:fill(value)
+ else
+ for key, child in pairs(obj) do
+ nesting.fillNested(child, value)
+ end
+ end
end
-- Resizes all tensors in the output.
function nesting.resizeNestedAs(output, input)
- if istensor(output) then
- output:resizeAs(input)
- else
- for key, child in pairs(input) do
- -- A new element is added to the output, if needed.
- if not output[key] then
- output[key] = nesting.cloneNested(child)
- else
- nesting.resizeNestedAs(output[key], child)
- end
- end
- -- Extra elements are removed from the output.
- for key, child in pairs(output) do
- if not input[key] then
- output[key] = nil
- end
- end
- end
+ if istensor(output) then
+ output:resizeAs(input)
+ else
+ for key, child in pairs(input) do
+ -- A new element is added to the output, if needed.
+ if not output[key] then
+ output[key] = nesting.cloneNested(child)
+ else
+ nesting.resizeNestedAs(output[key], child)
+ end
+ end
+ -- Extra elements are removed from the output.
+ for key, child in pairs(output) do
+ if not input[key] then
+ output[key] = nil
+ end
+ end
+ end
end
-- Adds the input to the output.
-- The input can contain nested tables.
-- The output will contain the same nesting of tables.
function nesting.addNestedTo(output, input)
- if istensor(output) then
- output:add(input)
- else
- for key, child in pairs(input) do
- assert(output[key] ~= nil, "missing key")
- nesting.addNestedTo(output[key], child)
- end
- end
+ if istensor(output) then
+ output:add(input)
+ else
+ for key, child in pairs(input) do
+ assert(output[key] ~= nil, "missing key")
+ nesting.addNestedTo(output[key], child)
+ end
+ end
end
return nesting
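
Sketch of how gmodule.lua drives these helpers when it sums several incoming
gradOutputs into one buffer (compare getTotalGradOutput above; not part of
this commit):

   local grads  = {torch.ones(2), torch.ones(2)}  -- two incoming gradients
   local buffer = nesting.cloneNested(grads[1])
   nesting.resizeNestedAs(buffer, grads[1])
   nesting.fillNested(buffer, 0)
   for i = 1, #grads do
      nesting.addNestedTo(buffer, grads[i])
   end
   -- buffer now holds the elementwise sum, torch.Tensor{2, 2}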
diff --git a/node.lua b/node.lua
index b620456..7ef55be 100644
--- a/node.lua
+++ b/node.lua
@@ -5,162 +5,155 @@ local istable = utils.istable
local istorchclass = utils.istorchclass
require 'debug'
-
local nnNode,parent = torch.class('nngraph.Node','graph.Node')
function nnNode:__init(data)
- parent.__init(self,data)
- self.data.annotations = self.data.annotations or {}
- self.data.mapindex = self.data.mapindex or {}
- if not self.data.annotations._debugLabel then
- self:_makeDebugLabel(debug.getinfo(6, 'Sl'))
- end
+ parent.__init(self,data)
+ self.data.annotations = self.data.annotations or {}
+ self.data.mapindex = self.data.mapindex or {}
+ if not self.data.annotations._debugLabel then
+ self:_makeDebugLabel(debug.getinfo(6, 'Sl'))
+ end
end
-
--[[ Build a string label which will be used as a tooltip when
- making a graph.]]
+making a graph.]]
function nnNode:_makeDebugLabel(dinfo)
- if dinfo then
- self.data.annotations._debugLabel = string.format('[%s]:%d',
- dinfo.short_src, dinfo.currentline, dinfo.name)
- end
+ if dinfo then
+ self.data.annotations._debugLabel = string.format('[%s]:%d',
+ dinfo.short_src, dinfo.currentline, dinfo.name)
+ end
end
-
-- domap ensures that this node will keep track of the order in which its children are added.
-- mapindex is a forward/backward list
-- index = self.data.mapindex[child.data]
-- child.data = self.data.mapindex[index]
function nnNode:add(child,domap)
- parent.add(self,child)
- if domap then
- local mapindex = self.data.mapindex
- local data = child.data
- assert(not mapindex[data], "Don't pass the same input twice.")
- table.insert(mapindex,data)
- mapindex[data] = #mapindex
- end
+ parent.add(self,child)
+ if domap then
+ local mapindex = self.data.mapindex
+ local data = child.data
+ assert(not mapindex[data], "Don't pass the same input twice.")
+ table.insert(mapindex,data)
+ mapindex[data] = #mapindex
+ end
end
-- this function returns noutput number of new nodes
-- that each take a single component of the output of this
-- node in the order they are returned.
function nnNode:split(noutput)
- assert(noutput >= 2, "splitting to one output is not supported")
- local debugLabel = self.data.annotations._debugLabel
- local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. '-mnode'}})
- mnode:add(self,true)
-
- local selectnodes = {}
- for i=1,noutput do
- local node = nngraph.Node({selectindex=i,input={}, annotations={_debugLabel=debugLabel .. '-' .. i}})
- node:add(mnode,true)
- table.insert(selectnodes,node)
- end
- return unpack(selectnodes)
+ assert(noutput >= 2, "splitting to one output is not supported")
+ local debugLabel = self.data.annotations._debugLabel
+ local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. '-mnode'}})
+ mnode:add(self,true)
+
+ local selectnodes = {}
+ for i=1,noutput do
+ local node = nngraph.Node({selectindex=i,input={}, annotations={_debugLabel=debugLabel .. '-' .. i}})
+ node:add(mnode,true)
+ table.insert(selectnodes,node)
+ end
+ return unpack(selectnodes)
end
-
function nnNode:annotate(annotations)
- for k, v in pairs(annotations) do
- self.data.annotations[k] = v
- end
+ for k, v in pairs(annotations) do
+ self.data.annotations[k] = v
+ end
- return self
+ return self
end
-
function nnNode:graphNodeName()
- if self.data.annotations.name then
- return self.data.annotations.name .. ' (' .. self.id .. ')'
- else
- return 'Node' .. self.id
- end
+ if self.data.annotations.name then
+ return self.data.annotations.name .. ' (' .. self.id .. ')'
+ else
+ return 'Node' .. self.id
+ end
end
-
function nnNode:graphNodeAttributes()
- self.data.annotations.graphAttributes =
- self.data.annotations.graphAttributes or {}
- if not self.data.annotations.graphAttributes.tooltip then
- self.data.annotations.graphAttributes.tooltip =
- self.data.annotations._debugLabel
- end
-
- return self.data.annotations.graphAttributes
+ self.data.annotations.graphAttributes =
+ self.data.annotations.graphAttributes or {}
+ if not self.data.annotations.graphAttributes.tooltip then
+ self.data.annotations.graphAttributes.tooltip =
+ self.data.annotations._debugLabel
+ end
+
+ return self.data.annotations.graphAttributes
end
-
local function getNanFlag(data)
- if data:nElement() == 0 then
- return ''
- end
- local isNan = (data:ne(data):sum() > 0)
- if isNan then
- return 'NaN'
- end
- if data:max() == math.huge then
- return 'inf'
- end
- if data:min() == -math.huge then
- return '-inf'
- end
- return ''
+ if data:nElement() == 0 then
+ return ''
+ end
+ local isNan = (data:ne(data):sum() > 0)
+ if isNan then
+ return 'NaN'
+ end
+ if data:max() == math.huge then
+ return 'inf'
+ end
+ if data:min() == -math.huge then
+ return '-inf'
+ end
+ return ''
end
function nnNode:label()
- local lbl = {}
-
- local function getstr(data)
- if not data then return '' end
- if istensor(data) then
- local nanFlag = getNanFlag(data)
- local tensorType = 'Tensor'
- if data:type() ~= torch.Tensor():type() then
- tensorType = data:type()
- end
- return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag
- elseif istable(data) then
- local tstr = {}
- for i,v in ipairs(data) do
- table.insert(tstr, getstr(v))
- end
- return '{' .. table.concat(tstr,',') .. '}'
- else
- return tostring(data):gsub('\n','\\l')
- end
- end
- local function getmapindexstr(mapindex)
- local tstr = {}
- for i,data in ipairs(mapindex) do
- local inputId = 'Node' .. (data.forwardNodeId or '')
- table.insert(tstr, inputId)
- end
- return '{' .. table.concat(tstr,',') .. '}'
- end
-
- for k,v in pairs(self.data) do
- local vstr = ''
- if k== 'mapindex' then
- if #v > 1 then
- vstr = getmapindexstr(v)
- table.insert(lbl, k .. ' = ' .. vstr)
- end
- elseif k== 'forwardNodeId' or k== 'annotations' then
- -- the forwardNodeId is not displayed in the label.
- else
- vstr = getstr(v)
- table.insert(lbl, k .. ' = ' .. vstr)
- end
- end
-
- local desc
- if self.data.annotations.description then
- desc = 'desc = ' .. self.data.annotations.description .. '\\n'
- else
- desc = ''
- end
- return desc .. table.concat(lbl,"\\l")
+ local lbl = {}
+
+ local function getstr(data)
+ if not data then return '' end
+ if istensor(data) then
+ local nanFlag = getNanFlag(data)
+ local tensorType = 'Tensor'
+ if data:type() ~= torch.Tensor():type() then
+ tensorType = data:type()
+ end
+ return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag
+ elseif istable(data) then
+ local tstr = {}
+ for i,v in ipairs(data) do
+ table.insert(tstr, getstr(v))
+ end
+ return '{' .. table.concat(tstr,',') .. '}'
+ else
+ return tostring(data):gsub('\n','\\l')
+ end
+ end
+ local function getmapindexstr(mapindex)
+ local tstr = {}
+ for i,data in ipairs(mapindex) do
+ local inputId = 'Node' .. (data.forwardNodeId or '')
+ table.insert(tstr, inputId)
+ end
+ return '{' .. table.concat(tstr,',') .. '}'
+ end
+
+ for k,v in pairs(self.data) do
+ local vstr = ''
+ if k== 'mapindex' then
+ if #v > 1 then
+ vstr = getmapindexstr(v)
+ table.insert(lbl, k .. ' = ' .. vstr)
+ end
+ elseif k== 'forwardNodeId' or k== 'annotations' then
+ -- the forwardNodeId is not displayed in the label.
+ else
+ vstr = getstr(v)
+ table.insert(lbl, k .. ' = ' .. vstr)
+ end
+ end
+
+ local desc
+ if self.data.annotations.description then
+ desc = 'desc = ' .. self.data.annotations.description .. '\\n'
+ else
+ desc = ''
+ end
+ return desc .. table.concat(lbl,"\\l")
end
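
Usage sketch for nnNode:split, mirroring test_nngraph.lua further down (not
part of this commit): each returned node selects one component of a
table-valued output.

   local input = nn.Identity()()
   local parts = nn.SplitTable(1)(input)
   local first, second = parts:split(2)   -- two selectindex nodes
   local net = nn.gModule({input}, {first, second})
   local out = net:forward(torch.randn(2, 3))
   -- out is a table of two 3-element tensors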
diff --git a/simple_print.lua b/simple_print.lua
index 878db3e..87bf152 100644
--- a/simple_print.lua
+++ b/simple_print.lua
@@ -1,125 +1,124 @@
local function removeNodeFromEdges(node_id, edges)
- local from_nodes = {}
- local to_nodes = {}
- -- remove edges
- local idx = 1
- while idx <= #edges do
- local edge = edges[idx]
- if edge.source == node_id then
- local to_node = edges[idx].target
- table.insert(to_nodes, to_node)
- table.remove(edges, idx)
- elseif edge.target == node_id then
- local from_node = edges[idx].source
- table.insert(from_nodes, from_node)
- table.remove(edges, idx)
- else
- idx = idx + 1
- end
- end
+ local from_nodes = {}
+ local to_nodes = {}
+ -- remove edges
+ local idx = 1
+ while idx <= #edges do
+ local edge = edges[idx]
+ if edge.source == node_id then
+ local to_node = edges[idx].target
+ table.insert(to_nodes, to_node)
+ table.remove(edges, idx)
+ elseif edge.target == node_id then
+ local from_node = edges[idx].source
+ table.insert(from_nodes, from_node)
+ table.remove(edges, idx)
+ else
+ idx = idx + 1
+ end
+ end
- -- add new edges
- for _, f in pairs(from_nodes) do
- for _, t in pairs(to_nodes) do
- local edge = {source = f, target= t}
- table.insert(edges, edge)
- end
- end
+ -- add new edges
+ for _, f in pairs(from_nodes) do
+ for _, t in pairs(to_nodes) do
+ local edge = {source = f, target= t}
+ table.insert(edges, edge)
+ end
+ end
- return edges
+ return edges
end
local function isNodeGood(node)
- return node.data and node.data.module and (torch.typename(node.data.module) ~= 'nn.Identity')
+ return node.data and node.data.module and (torch.typename(node.data.module) ~= 'nn.Identity')
end
local function reIndexNodes(nodes, edges)
- -- make reverse map
- local rev_map = {}
- for idx = 1, #nodes do
- rev_map[nodes[idx].id] = idx
- nodes[idx].id = idx
- end
- for idx = 1, #edges do
- local edge = edges[idx]
- edge.source = rev_map[edge.source]
- edge.target = rev_map[edge.target]
- end
- return nodes, edges
+ -- make reverse map
+ local rev_map = {}
+ for idx = 1, #nodes do
+ rev_map[nodes[idx].id] = idx
+ nodes[idx].id = idx
+ end
+ for idx = 1, #edges do
+ local edge = edges[idx]
+ edge.source = rev_map[edge.source]
+ edge.target = rev_map[edge.target]
+ end
+ return nodes, edges
end
local function cleanGraph(nodes, edges)
- local idx = 1
- while idx <= #nodes do
- local node = nodes[idx]
- if isNodeGood(node.orig_node) then
- idx = idx + 1
- else
- local id = node.id
- table.remove(nodes, idx)
- edges = removeNodeFromEdges(id, edges)
- end
- end
- return reIndexNodes(nodes, edges)
+ local idx = 1
+ while idx <= #nodes do
+ local node = nodes[idx]
+ if isNodeGood(node.orig_node) then
+ idx = idx + 1
+ else
+ local id = node.id
+ table.remove(nodes, idx)
+ edges = removeNodeFromEdges(id, edges)
+ end
+ end
+ return reIndexNodes(nodes, edges)
end
local function loadGraph(graph)
- local nodes = {}
- local edges = {}
- for _, node in ipairs(graph.nodes) do
- local idx = node.id
- table.insert(nodes, {id=idx, orig_node = node} )
- for ich = 1, #node.children do
- table.insert( edges, {source = idx, target = node.children[ich].id})
- end
- end
- nodes, edges = cleanGraph(nodes, edges)
- return nodes , edges
+ local nodes = {}
+ local edges = {}
+ for _, node in ipairs(graph.nodes) do
+ local idx = node.id
+ table.insert(nodes, {id=idx, orig_node = node} )
+ for ich = 1, #node.children do
+ table.insert( edges, {source = idx, target = node.children[ich].id})
+ end
+ end
+ nodes, edges = cleanGraph(nodes, edges)
+ return nodes , edges
end
local M = {}
function M.todot( graph, title )
- local nodes, edges = loadGraph(graph)
- local str = {}
- table.insert(str,'digraph G {\n')
- if title then
- table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n')
- end
- table.insert(str,'node [shape = oval]; ')
- local nodelabels = {}
- for i,node in ipairs(nodes) do
- local true_node = node.orig_node
- local l = '"' .. ( 'Node' .. true_node.id .. '\\n' .. true_node:label() ) .. '"'
- nodelabels[i] = 'n' .. true_node.id
- table.insert(str, '\n' .. nodelabels[i] .. '[label=' .. l .. '];')
- end
- table.insert(str,'\n')
- for i,edge in ipairs(edges) do
- table.insert(str,nodelabels[edge.source] .. ' -> ' .. nodelabels[edge.target] .. ';\n')
- end
- table.insert(str,'}')
- return table.concat(str,'')
+ local nodes, edges = loadGraph(graph)
+ local str = {}
+ table.insert(str,'digraph G {\n')
+ if title then
+ table.insert(str,'labelloc="t";\nlabel="' .. title .. '";\n')
+ end
+ table.insert(str,'node [shape = oval]; ')
+ local nodelabels = {}
+ for i,node in ipairs(nodes) do
+ local true_node = node.orig_node
+ local l = '"' .. ( 'Node' .. true_node.id .. '\\n' .. true_node:label() ) .. '"'
+ nodelabels[i] = 'n' .. true_node.id
+ table.insert(str, '\n' .. nodelabels[i] .. '[label=' .. l .. '];')
+ end
+ table.insert(str,'\n')
+ for i,edge in ipairs(edges) do
+ table.insert(str,nodelabels[edge.source] .. ' -> ' .. nodelabels[edge.target] .. ';\n')
+ end
+ table.insert(str,'}')
+ return table.concat(str,'')
end
function M.dot(g,title,fname)
- local gv = M.todot(g, title)
- local fngv = (fname or os.tmpname()) .. '.dot'
- local fgv = io.open(fngv,'w')
- fgv:write(gv)
- fgv:close()
- local fnsvg = (fname or os.tmpname()) .. '.svg'
- os.execute('dot -Tsvg -o ' .. fnsvg .. ' ' .. fngv)
- if not fname then
- require 'qtsvg'
- local qs = qt.QSvgWidget(fnsvg)
- qs:show()
- os.remove(fngv)
- os.remove(fnsvg)
- -- print(fngv,fnpng)
- return qs
- end
+ local gv = M.todot(g, title)
+ local fngv = (fname or os.tmpname()) .. '.dot'
+ local fgv = io.open(fngv,'w')
+ fgv:write(gv)
+ fgv:close()
+ local fnsvg = (fname or os.tmpname()) .. '.svg'
+ os.execute('dot -Tsvg -o ' .. fnsvg .. ' ' .. fngv)
+ if not fname then
+ require 'qtsvg'
+ local qs = qt.QSvgWidget(fnsvg)
+ qs:show()
+ os.remove(fngv)
+ os.remove(fnsvg)
+ -- print(fngv,fnpng)
+ return qs
+ end
end
return M
-
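
Usage sketch for the module above (not part of this commit); init.lua exposes
it as nngraph.simple_print:

   local x = nn.Identity()()
   local y = nn.Tanh()(nn.Linear(4, 4)(x))
   local net = nn.gModule({x}, {y})
   net:forward(torch.randn(4))             -- fills node labels with sizes
   local dotStr = nngraph.simple_print.todot(net.fg, 'my net')
   -- nngraph.simple_print.dot(net.fg, 'my net', '/tmp/mynet') also renders an svg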
diff --git a/test/speed.lua b/test/speed.lua
index 355645c..7218cbe 100644
--- a/test/speed.lua
+++ b/test/speed.lua
@@ -1,105 +1,104 @@
require 'nngraph'
-
function time_benchmark(model, input, n)
- local forward_timer = torch.Timer():stop():reset()
- local backward_timer = torch.Timer():stop():reset()
- local total_timer = torch.Timer():stop():reset()
- local gradOut
- total_timer:resume()
- for i = 1, n do
- forward_timer:resume()
- local out = model:forward(input)
- forward_timer:stop()
- if not gradOut then
- gradOut = torch.rand(out:size())
- end
- backward_timer:resume()
- model:backward(input, gradOut)
- backward_timer:stop()
- end
- total_timer:stop()
-
- return {forward = forward_timer:time().real,
- backward = backward_timer:time().real,
- total = total_timer:time().real}
+ local forward_timer = torch.Timer():stop():reset()
+ local backward_timer = torch.Timer():stop():reset()
+ local total_timer = torch.Timer():stop():reset()
+ local gradOut
+ total_timer:resume()
+ for i = 1, n do
+ forward_timer:resume()
+ local out = model:forward(input)
+ forward_timer:stop()
+ if not gradOut then
+ gradOut = torch.rand(out:size())
+ end
+ backward_timer:resume()
+ model:backward(input, gradOut)
+ backward_timer:stop()
+ end
+ total_timer:stop()
+
+ return {forward = forward_timer:time().real,
+ backward = backward_timer:time().real,
+ total = total_timer:time().real}
end
function report_benchmark(result, title)
- local nspace = (80-string.len(title))/2
- report = {string.rep('#', 80),
- string.format('%s%s%s', string.rep(' ', math.floor(nspace)), title, string.rep(' ', math.ceil(nspace))),
- string.format('Total Time Spent = %.2f s', result.total),
- string.format(' Forward Time = %.2f s', result.forward),
- string.format(' Backward Time = %.2f s', result.backward),
- string.rep('#', 80)
- }
- print(table.concat(report, '\n'))
- return result
+ local nspace = (80-string.len(title))/2
+ report = {string.rep('#', 80),
+ string.format('%s%s%s', string.rep(' ', math.floor(nspace)), title, string.rep(' ', math.ceil(nspace))),
+ string.format('Total Time Spent = %.2f s', result.total),
+ string.format(' Forward Time = %.2f s', result.forward),
+ string.format(' Backward Time = %.2f s', result.backward),
+ string.rep('#', 80)
+ }
+ print(table.concat(report, '\n'))
+ return result
end
function compare_benchmarks(result, base, title)
- local nspace = (80-string.len(title))/2
- report = {string.rep('#', 80),
- string.format('%s%s%s', string.rep(' ', math.floor(nspace)), title, string.rep(' ', math.ceil(nspace))),
- string.format('Total Time Spent = %.2f %%', result.total/base.total*100),
- string.format(' Forward Time = %.2f %%', result.forward/base.forward*100),
- string.format(' Backward Time = %.2f %%', result.backward/base.backward*100),
- string.rep('#', 80)
- }
- print(table.concat(report, '\n'))
- return result
+ local nspace = (80-string.len(title))/2
+ report = {string.rep('#', 80),
+ string.format('%s%s%s', string.rep(' ', math.floor(nspace)), title, string.rep(' ', math.ceil(nspace))),
+ string.format('Total Time Spent = %.2f %%', result.total/base.total*100),
+ string.format(' Forward Time = %.2f %%', result.forward/base.forward*100),
+ string.format(' Backward Time = %.2f %%', result.backward/base.backward*100),
+ string.rep('#', 80)
+ }
+ print(table.concat(report, '\n'))
+ return result
end
function get_models(nhidden_layers, ninput, noutput, nhidden)
- local function get_concat_layer(nfrom, nto)
- local concat_module = nn.Sequential()
- local concat_layer = nn.ConcatTable()
- concat_layer:add(nn.Linear(nfrom, nto))
- concat_layer:add(nn.Linear(nfrom, nto))
- concat_module:add(concat_layer)
- concat_module:add(nn.CAddTable())
- concat_module:add(nn.ReLU())
- return concat_module
- end
-
- -- NN
- local nn_model = nn.Sequential()
- nn_model:add(get_concat_layer(ninput, nhidden))--nn.Linear(ninput, nhidden)):add(nn.ReLU())
- for i = 1, nhidden_layers do
- nn_model:add(get_concat_layer(nhidden, nhidden))--nn.Linear(nhidden, nhidden)):add(nn.ReLU())
- end
- nn_model:add(get_concat_layer(nhidden, noutput))--nn.Linear(nhidden, noutput))
-
- -- NN graph
- local input = nn.Identity()()
- local nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(ninput, nhidden)(input),
- nn.Linear(ninput, nhidden)(input)}))
- for i = 1, nhidden_layers do
- nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(nhidden, nhidden)(nng_model),
- nn.Linear(nhidden, nhidden)(nng_model)}))
- end
- nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(nhidden, noutput)(nng_model),
- nn.Linear(nhidden, noutput)(nng_model)}))
-
- nng_model = nn.gModule({input},{nng_model})
-
- return {nn_model = nn_model, nng_model = nng_model}
+ local function get_concat_layer(nfrom, nto)
+ local concat_module = nn.Sequential()
+ local concat_layer = nn.ConcatTable()
+ concat_layer:add(nn.Linear(nfrom, nto))
+ concat_layer:add(nn.Linear(nfrom, nto))
+ concat_module:add(concat_layer)
+ concat_module:add(nn.CAddTable())
+ concat_module:add(nn.ReLU())
+ return concat_module
+ end
+
+ -- NN
+ local nn_model = nn.Sequential()
+ nn_model:add(get_concat_layer(ninput, nhidden))--nn.Linear(ninput, nhidden)):add(nn.ReLU())
+ for i = 1, nhidden_layers do
+ nn_model:add(get_concat_layer(nhidden, nhidden))--nn.Linear(nhidden, nhidden)):add(nn.ReLU())
+ end
+ nn_model:add(get_concat_layer(nhidden, noutput))--nn.Linear(nhidden, noutput))
+
+ -- NN graph
+ local input = nn.Identity()()
+ local nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(ninput, nhidden)(input),
+ nn.Linear(ninput, nhidden)(input)}))
+ for i = 1, nhidden_layers do
+ nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(nhidden, nhidden)(nng_model),
+ nn.Linear(nhidden, nhidden)(nng_model)}))
+ end
+ nng_model = nn.ReLU()(nn.CAddTable()({nn.Linear(nhidden, noutput)(nng_model),
+ nn.Linear(nhidden, noutput)(nng_model)}))
+
+ nng_model = nn.gModule({input},{nng_model})
+
+ return {nn_model = nn_model, nng_model = nng_model}
end
function get_options(arg)
- local cmd = torch.CmdLine()
- cmd:text('nngraph benchmarking')
- cmd:option('-niter', 10, 'number of iterations of forward/backward for each model')
- cmd:option('-nhidden_layers', 10, 'number of hidden layers')
- cmd:option('-input_size', 512, 'size of input')
- cmd:option('-batch_size', 16, 'size of batch')
- cmd:option('-hidden_size', 1024, 'size of hidden layer')
- cmd:option('-output_size', 128, 'size of output layer')
- local opt = cmd:parse(arg)
- return opt
+ local cmd = torch.CmdLine()
+ cmd:text('nngraph benchmarking')
+ cmd:option('-niter', 10, 'number of iterations of forward/backward for each model')
+ cmd:option('-nhidden_layers', 10, 'number of hidden layers')
+ cmd:option('-input_size', 512, 'size of input')
+ cmd:option('-batch_size', 16, 'size of batch')
+ cmd:option('-hidden_size', 1024, 'size of hidden layer')
+ cmd:option('-output_size', 128, 'size of output layer')
+ local opt = cmd:parse(arg)
+ return opt
end
local opt = get_options(arg)
diff --git a/test/test_ModuleFromCriterion.lua b/test/test_ModuleFromCriterion.lua
index 78d3cd2..9206905 100644
--- a/test/test_ModuleFromCriterion.lua
+++ b/test/test_ModuleFromCriterion.lua
@@ -5,52 +5,52 @@ local test = {}
local tester = totem.Tester()
function test.test_call()
- local prediction = nn.Identity()()
- local target = nn.Identity()()
- local mse = nn.MSECriterion()({prediction, target})
- local costBits = nn.MulConstant(1/math.log(2))(mse)
- local net = nn.gModule({prediction, target}, {costBits})
-
- local input = {torch.randn(3, 5), torch.rand(3, 5)}
- local criterion = nn.MSECriterion()
- local output = net:forward(input)
- criterion:forward(input[1], input[2])
- tester:eq(output[1], criterion.output/math.log(2), "output", 1e-14)
-
- local gradOutput = torch.randn(1)
- local gradInput = net:backward(input, gradOutput)
- criterion:backward(input[1], input[2])
- tester:eq(gradInput[1], criterion.gradInput:clone():mul(gradOutput[1]/math.log(2)), "gradPrediction", 1e-14)
- tester:eq(gradInput[2], torch.zeros(input[2]:size()), "gradTarget")
+ local prediction = nn.Identity()()
+ local target = nn.Identity()()
+ local mse = nn.MSECriterion()({prediction, target})
+ local costBits = nn.MulConstant(1/math.log(2))(mse)
+ local net = nn.gModule({prediction, target}, {costBits})
+
+ local input = {torch.randn(3, 5), torch.rand(3, 5)}
+ local criterion = nn.MSECriterion()
+ local output = net:forward(input)
+ criterion:forward(input[1], input[2])
+ tester:eq(output[1], criterion.output/math.log(2), "output", 1e-14)
+
+ local gradOutput = torch.randn(1)
+ local gradInput = net:backward(input, gradOutput)
+ criterion:backward(input[1], input[2])
+ tester:eq(gradInput[1], criterion.gradInput:clone():mul(gradOutput[1]/math.log(2)), "gradPrediction", 1e-14)
+ tester:eq(gradInput[2], torch.zeros(input[2]:size()), "gradTarget")
end
function test.test_grad()
- local prediction = nn.Identity()()
- local zero = nn.MulConstant(0)(prediction)
- -- The target is created inside of the nngraph
- -- to ignore the zero gradTarget.
- local target = nn.AddConstant(1.23)(zero)
- local mse = nn.MSECriterion()({prediction, target})
- local net = nn.gModule({prediction}, {mse})
-
- local input = torch.randn(4, 7)
- totem.nn.checkGradients(tester, net, input)
+ local prediction = nn.Identity()()
+ local zero = nn.MulConstant(0)(prediction)
+ -- The target is created inside of the nngraph
+ -- to ignore the zero gradTarget.
+ local target = nn.AddConstant(1.23)(zero)
+ local mse = nn.MSECriterion()({prediction, target})
+ local net = nn.gModule({prediction}, {mse})
+
+ local input = torch.randn(4, 7)
+ totem.nn.checkGradients(tester, net, input)
end
local function module()
- local module = nn.ModuleFromCriterion(nn.MSECriterion())
- local input = {torch.randn(3, 5), torch.randn(3, 5)}
- return module, input
+ local module = nn.ModuleFromCriterion(nn.MSECriterion())
+ local input = {torch.randn(3, 5), torch.randn(3, 5)}
+ return module, input
end
function test.test_serializable()
- local module, input = module()
- totem.nn.checkSerializable(tester, module, input)
+ local module, input = module()
+ totem.nn.checkSerializable(tester, module, input)
end
function test.test_typeCastable()
- local module, input = module()
- totem.nn.checkTypeCastable(tester, module, input)
+ local module, input = module()
+ totem.nn.checkTypeCastable(tester, module, input)
end
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index 86bf730..95a1658 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -5,379 +5,379 @@ local test = {}
local tester = totem.Tester()
local function checkGradients(...)
- totem.nn.checkGradients(tester, ...)
+ totem.nn.checkGradients(tester, ...)
end
function test.test_oneOutput()
- local in1 = nn.Identity()()
- local out1 = nn.Identity()(in1)
- local module = nn.gModule({in1}, {out1})
-
- local input = torch.Tensor({1})
- module:forward(input)
- tester:eq(module.output, torch.Tensor{1}, "output")
- local gradInput = module:backward(input, torch.Tensor({-123}))
- tester:eq(gradInput, torch.Tensor{-123}, "gradInput")
-
- local input2 = torch.Tensor({2})
- module:forward(input2)
- tester:eq(module.output, torch.Tensor{2}, "output for input2")
- gradInput = module:backward(input2, torch.Tensor({-2}))
- tester:eq(gradInput, torch.Tensor{-2}, "expecting a recomputed gradInput")
+ local in1 = nn.Identity()()
+ local out1 = nn.Identity()(in1)
+ local module = nn.gModule({in1}, {out1})
+
+ local input = torch.Tensor({1})
+ module:forward(input)
+ tester:eq(module.output, torch.Tensor{1}, "output")
+ local gradInput = module:backward(input, torch.Tensor({-123}))
+ tester:eq(gradInput, torch.Tensor{-123}, "gradInput")
+
+ local input2 = torch.Tensor({2})
+ module:forward(input2)
+ tester:eq(module.output, torch.Tensor{2}, "output for input2")
+ gradInput = module:backward(input2, torch.Tensor({-2}))
+ tester:eq(gradInput, torch.Tensor{-2}, "expecting a recomputed gradInput")
end
function test.test_twoOutputs()
- local in1 = nn.Identity()()
- local out1 = nn.Identity()(in1)
- local out2 = nn.Identity()(in1)
- local module = nn.gModule({in1}, {out1, out2})
-
- local input = torch.Tensor({1})
- module:forward(input)
- local gradInput = module:backward(input, {torch.Tensor({-2}), torch.Tensor({-3})})
- tester:eq(gradInput, torch.Tensor{-5}, "gradInput of a fork")
- checkGradients(module, input)
+ local in1 = nn.Identity()()
+ local out1 = nn.Identity()(in1)
+ local out2 = nn.Identity()(in1)
+ local module = nn.gModule({in1}, {out1, out2})
+
+ local input = torch.Tensor({1})
+ module:forward(input)
+ local gradInput = module:backward(input, {torch.Tensor({-2}), torch.Tensor({-3})})
+ tester:eq(gradInput, torch.Tensor{-5}, "gradInput of a fork")
+ checkGradients(module, input)
end
function test.test_twoGradOutputs()
- local in1 = nn.Sigmoid()()
- local splitTable = nn.SplitTable(1)({in1})
- local out1, out2 = splitTable:split(2)
- local module = nn.gModule({in1}, {out1, out2})
-
- local input = torch.randn(2, 3)
- local output = module:forward(input)
- assert(#output == 2, "wrong number of outputs")
- module:backward(input, {torch.randn(3), torch.randn(3)})
- checkGradients(module, input)
+ local in1 = nn.Sigmoid()()
+ local splitTable = nn.SplitTable(1)({in1})
+ local out1, out2 = splitTable:split(2)
+ local module = nn.gModule({in1}, {out1, out2})
+
+ local input = torch.randn(2, 3)
+ local output = module:forward(input)
+ assert(#output == 2, "wrong number of outputs")
+ module:backward(input, {torch.randn(3), torch.randn(3)})
+ checkGradients(module, input)
end
function test.test_twoInputs()
- local in1 = nn.Identity()()
- local in2 = nn.Identity()()
- local prevH, prevCell = in2:split(2)
-
- local out1 = nn.CMulTable()({in1, prevH, prevCell})
- local module = nn.gModule({in1, in2}, {out1})
-
- local input = {torch.randn(3), {torch.randn(3), torch.randn(3)}}
- module:forward(input)
- local gradInput = module:backward(input, torch.randn(3))
- assert(#gradInput == 2, "wrong number of gradInputs")
- assert(type(gradInput[2]) == "table", "wrong gradInput[2] type")
- checkGradients(module, input)
+ local in1 = nn.Identity()()
+ local in2 = nn.Identity()()
+ local prevH, prevCell = in2:split(2)
+
+ local out1 = nn.CMulTable()({in1, prevH, prevCell})
+ local module = nn.gModule({in1, in2}, {out1})
+
+ local input = {torch.randn(3), {torch.randn(3), torch.randn(3)}}
+ module:forward(input)
+ local gradInput = module:backward(input, torch.randn(3))
+ assert(#gradInput == 2, "wrong number of gradInputs")
+ assert(type(gradInput[2]) == "table", "wrong gradInput[2] type")
+ checkGradients(module, input)
end
function test.test_twoInputs2()
- local in1 = nn.Sigmoid()()
- local in2 = nn.Sigmoid()()
- local module = nn.gModule({in1, in2}, {in1, in2, nn.Sigmoid()(in1)})
-
- local input = {torch.randn(3), torch.randn(3)}
- module:forward(input)
- local gradInput = module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)})
- checkGradients(module, input)
+ local in1 = nn.Sigmoid()()
+ local in2 = nn.Sigmoid()()
+ local module = nn.gModule({in1, in2}, {in1, in2, nn.Sigmoid()(in1)})
+
+ local input = {torch.randn(3), torch.randn(3)}
+ module:forward(input)
+ local gradInput = module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)})
+ checkGradients(module, input)
end
function test.test_splitDebugLabels()
- local node = nn.Identity()()
- node.data.annotations._debugLabel = "node"
- local node1, node2 = node:split(2)
- assert(node1.data.annotations._debugLabel == "node-1")
- assert(node2.data.annotations._debugLabel == "node-2")
+ local node = nn.Identity()()
+ node.data.annotations._debugLabel = "node"
+ local node1, node2 = node:split(2)
+ assert(node1.data.annotations._debugLabel == "node-1")
+ assert(node2.data.annotations._debugLabel == "node-2")
end
function test.test_identity()
- local in1 = nn.Identity()()
- local in2 = nn.Identity()()
- local module = nn.gModule({in1, in2}, {in1, in2, nn.Identity()(in1)})
-
- local input = {torch.randn(3), torch.randn(3)}
- module:forward(input)
- module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)})
- checkGradients(module, input)
+ local in1 = nn.Identity()()
+ local in2 = nn.Identity()()
+ local module = nn.gModule({in1, in2}, {in1, in2, nn.Identity()(in1)})
+
+ local input = {torch.randn(3), torch.randn(3)}
+ module:forward(input)
+ module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)})
+ checkGradients(module, input)
end
function test.test_gradInputType()
- local xInput = torch.randn(3)
- local h = torch.randn(3)
-
- local x = nn.Identity()()
- local prevRnnState = nn.Identity()()
- local prevH1, prevCell = prevRnnState:split(2)
- local prevH = prevH1
-
- local cellOut = nn.CAddTable()({
- nn.CMulTable()({x, prevH}),
- nn.CMulTable()({prevH, prevCell})})
- local module = nn.gModule({x, prevRnnState}, {cellOut})
-
- local c = torch.randn(h:size())
- local prevRnnState = {h, c}
- local input = {xInput, prevRnnState}
- local output = module:forward(input)
-
- local gradOutput = torch.randn(h:size())
- local gradInput = module:backward(input, gradOutput)
-
- local gradX, gradPrevState = unpack(gradInput)
- local gradPrevH, gradPrevCell = unpack(gradPrevState)
- assert(type(gradPrevH) == type(h), "wrong gradPrevH type")
-
- tester:eq(type(gradPrevH), type(h), "wrong gradPrevH type")
- tester:eq(gradPrevH:size(), h:size(), "wrong gradPrevH size")
- checkGradients(module, input)
-end
-
-function test.test_tabularInput()
- local in1 = nn.SplitTable(1)()
- local out1 = nn.CAddTable()(in1)
- local module = nn.gModule({in1}, {out1})
-
- local input = torch.randn(2, 3)
- checkGradients(module, input)
-end
-
-function test.test_extraTable()
- local in1 = nn.Identity()()
- local out1 = nn.Identity()(in1)
- local module = nn.gModule({in1}, {out1})
-
- local input = torch.Tensor({123})
- tester:eq(module:forward(input), input, "simple output")
- tester:eq(module:forward({input}), {input}, "tabular output")
-end
-
-function test.test_accGradParameters()
- local input = torch.randn(10)
-
- local in1 = nn.CMul(input:nElement())()
- local out1 = nn.Identity()(in1)
- local out2 = nn.Identity()(in1)
- local module = nn.gModule({in1}, {out1, out2})
- checkGradients(module, input)
-end
-
-function test.test_example1()
- local x1 = nn.Linear(20,10)()
- local mout = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(x1))))
- local mlp = nn.gModule({x1},{mout})
-
- local x = torch.rand(20)
- checkGradients(mlp, x)
-end
-
-function test.test_example2()
- local x1=nn.Linear(20,20)()
- local x2=nn.Linear(10,10)()
- local m0=nn.Linear(20,1)(nn.Tanh()(x1))
- local m1=nn.Linear(10,1)(nn.Tanh()(x2))
- local madd=nn.CAddTable()({m0,m1})
- local m2=nn.Sigmoid()(madd)
- local m3=nn.Tanh()(madd)
- local gmod = nn.gModule({x1,x2},{m2,m3})
-
- local x = torch.rand(20)
- local y = torch.rand(10)
- checkGradients(gmod, {x, y})
-end
-
-function test.test_example3()
- local m = nn.Sequential()
- m:add(nn.SplitTable(1))
- m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30)))
- local input = nn.Identity()()
- local input1,input2 = m(input):split(2)
- local m3 = nn.JoinTable(1)({input1,input2})
- local g = nn.gModule({input},{m3})
-
- local indata = torch.rand(2,10)
- checkGradients(g, indata)
-end
-
-function test.test_example4()
- local input = nn.Identity()()
- local L1 = nn.Tanh()(nn.Linear(1,2)(input))
- local L2 = nn.Tanh()(nn.Linear(3,6)(nn.JoinTable(1)({input,L1})))
- local L3 = nn.Tanh()(nn.Linear(8,16)(nn.JoinTable(1)({L1,L2})))
- local g = nn.gModule({input},{L3})
-
- local indata = torch.rand(1)
- checkGradients(g, indata)
-end
-
-function test.test_type()
- local in1 = nn.Linear(20,10)()
- local out1 = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(in1))))
- local module = nn.gModule({in1}, {out1})
- local input = torch.rand(20)
- local output = module:forward(input)
- module:backward(input, output)
- tester:eq(torch.typename(output), "torch.DoubleTensor")
- tester:eq(torch.typename(module.output), "torch.DoubleTensor")
- tester:eq(torch.typename(module.gradInput), "torch.DoubleTensor")
- tester:eq(torch.typename(module.innode.data.input[1]), "torch.DoubleTensor")
- tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor")
- tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor")
-
- module:float()
- local output = module:forward(input:float())
- tester:eq(torch.typename(output), "torch.FloatTensor")
- tester:eq(torch.typename(module.output), "torch.FloatTensor")
- tester:eq(torch.typename(module.gradInput), "torch.FloatTensor")
- tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor")
- tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor")
- tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor")
-end
-
-function test.test_nestedGradInput()
- local x = nn.Identity()()
- local h1 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Tanh())
- local h2 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Identity())
- local out = nn.CAddTable()({h1(x), h2(x)})
-
- local model = nn.gModule({x}, {out})
-
- local input = {}
- input[1] = torch.randn(3, 3)
- input[2] = torch.randn(3, 3)
- input[3] = torch.randn(3, 3)
-
- checkGradients(model, input)
-
- local input = {}
- input[1] = torch.randn(2, 3)
- input[2] = torch.randn(2, 3)
- input[3] = torch.randn(2, 3)
-
- checkGradients(model, input)
-end
-
-function test.test_unusedInput()
- local x = nn.Identity()()
- local h = nn.Identity()()
- local h2 = nn.Identity()()
-
- local ok, result = pcall(nn.gModule, {x, h}, {x})
- assert(not ok, "the unused input should be detected")
-end
-
-function test.test_unusedChild()
- local prevState = nn.Identity()()
- local h, cell = prevState:split(2)
-
- local ok, result = pcall(nn.gModule, {prevState}, {h})
- assert(not ok, "the unused cell should be detected")
-end
-
-function test.test_nilInput()
- local ok, result = pcall(function() nn.Sigmoid()(nil) end)
- assert(not ok, "the nil input should be detected")
-end
-
-function test.test_unusedNode()
- local in1 = nn.Identity()()
- local in2 = nn.Identity()()
- local middleResult = nn.Sigmoid()(in2)
- local out1 = nn.Sigmoid()(in1)
-
- local ok, result = pcall(nn.gModule, {in1, in2}, {out1})
- assert(not ok, "the unused middleResult should be detected")
-end
-
-function test.test_usageAfterSplit()
- local prevState = nn.Identity()()
- local h, cell = prevState:split(2)
- local nextState = nn.Identity()(prevState)
- local transformed = nn.Sigmoid()(cell)
-
- local model = nn.gModule({prevState}, {h, nextState, transformed})
- local nHidden = 10
- local input = {torch.randn(nHidden), torch.randn(nHidden)}
- checkGradients(model, input)
-end
-
-function test.test_resizeNestedAs()
- local in1 = nn.Identity()()
- local out1 = nn.Identity()(in1)
- local out2 = nn.Identity()(in1)
-
- local net = nn.gModule({in1}, {out1, out2})
- local input = {torch.randn(10), {torch.randn(3), torch.randn(4)}}
- net:forward(input)
- net:backward(input, net.output)
- checkGradients(net, input)
-
- input = {torch.randn(10), {torch.randn(3), torch.randn(4), torch.randn(5)}}
- net:forward(input)
- net:backward(input, net.output)
- checkGradients(net, input)
-
- input = {torch.randn(10), {torch.randn(3), torch.randn(4)}}
- net:forward(input)
- local gradInput = net:backward(input, net.output)
- tester:eq(#(gradInput[2]), 2, "gradInput[2] size")
- checkGradients(net, input)
-end
-
-
-function test.test_annotateGraph()
- local input = nn.Identity()():annotate(
+ local xInput = torch.randn(3)
+ local h = torch.randn(3)
+
+ local x = nn.Identity()()
+ local prevRnnState = nn.Identity()()
+ local prevH1, prevCell = prevRnnState:split(2)
+ local prevH = prevH1
+
+ local cellOut = nn.CAddTable()({
+ nn.CMulTable()({x, prevH}),
+ nn.CMulTable()({prevH, prevCell})})
+ local module = nn.gModule({x, prevRnnState}, {cellOut})
+
+ local c = torch.randn(h:size())
+ local prevRnnState = {h, c}
+ local input = {xInput, prevRnnState}
+ local output = module:forward(input)
+
+ local gradOutput = torch.randn(h:size())
+ local gradInput = module:backward(input, gradOutput)
+
+ local gradX, gradPrevState = unpack(gradInput)
+ local gradPrevH, gradPrevCell = unpack(gradPrevState)
+ assert(type(gradPrevH) == type(h), "wrong gradPrevH type")
+
+ tester:eq(type(gradPrevH), type(h), "wrong gradPrevH type")
+ tester:eq(gradPrevH:size(), h:size(), "wrong gradPrevH size")
+ checkGradients(module, input)
+ end
+
+ function test.test_tabularInput()
+ local in1 = nn.SplitTable(1)()
+ local out1 = nn.CAddTable()(in1)
+ local module = nn.gModule({in1}, {out1})
+
+ local input = torch.randn(2, 3)
+ checkGradients(module, input)
+ end
+
+ function test.test_extraTable()
+ local in1 = nn.Identity()()
+ local out1 = nn.Identity()(in1)
+ local module = nn.gModule({in1}, {out1})
+
+ local input = torch.Tensor({123})
+ tester:eq(module:forward(input), input, "simple output")
+ tester:eq(module:forward({input}), {input}, "tabular output")
+ end
+
+ function test.test_accGradParameters()
+ local input = torch.randn(10)
+
+ local in1 = nn.CMul(input:nElement())()
+ local out1 = nn.Identity()(in1)
+ local out2 = nn.Identity()(in1)
+ local module = nn.gModule({in1}, {out1, out2})
+ checkGradients(module, input)
+ end
+
+ function test.test_example1()
+ local x1 = nn.Linear(20,10)()
+ local mout = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(x1))))
+ local mlp = nn.gModule({x1},{mout})
+
+ local x = torch.rand(20)
+ checkGradients(mlp, x)
+ end
+
+ function test.test_example2()
+ local x1=nn.Linear(20,20)()
+ local x2=nn.Linear(10,10)()
+ local m0=nn.Linear(20,1)(nn.Tanh()(x1))
+ local m1=nn.Linear(10,1)(nn.Tanh()(x2))
+ local madd=nn.CAddTable()({m0,m1})
+ local m2=nn.Sigmoid()(madd)
+ local m3=nn.Tanh()(madd)
+ local gmod = nn.gModule({x1,x2},{m2,m3})
+
+ local x = torch.rand(20)
+ local y = torch.rand(10)
+ checkGradients(gmod, {x, y})
+ end
+
+ function test.test_example3()
+ local m = nn.Sequential()
+ m:add(nn.SplitTable(1))
+ m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30)))
+ local input = nn.Identity()()
+ local input1,input2 = m(input):split(2)
+ local m3 = nn.JoinTable(1)({input1,input2})
+ local g = nn.gModule({input},{m3})
+
+ local indata = torch.rand(2,10)
+ checkGradients(g, indata)
+ end
+
+ function test.test_example4()
+ local input = nn.Identity()()
+ local L1 = nn.Tanh()(nn.Linear(1,2)(input))
+ local L2 = nn.Tanh()(nn.Linear(3,6)(nn.JoinTable(1)({input,L1})))
+ local L3 = nn.Tanh()(nn.Linear(8,16)(nn.JoinTable(1)({L1,L2})))
+ local g = nn.gModule({input},{L3})
+
+ local indata = torch.rand(1)
+ checkGradients(g, indata)
+ end
+
+ function test.test_type()
+ local in1 = nn.Linear(20,10)()
+ local out1 = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(in1))))
+ local module = nn.gModule({in1}, {out1})
+ local input = torch.rand(20)
+ local output = module:forward(input)
+ module:backward(input, output)
+ tester:eq(torch.typename(output), "torch.DoubleTensor")
+ tester:eq(torch.typename(module.output), "torch.DoubleTensor")
+ tester:eq(torch.typename(module.gradInput), "torch.DoubleTensor")
+ tester:eq(torch.typename(module.innode.data.input[1]), "torch.DoubleTensor")
+ tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor")
+ tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor")
+
+ module:float()
+ local output = module:forward(input:float())
+ tester:eq(torch.typename(output), "torch.FloatTensor")
+ tester:eq(torch.typename(module.output), "torch.FloatTensor")
+ tester:eq(torch.typename(module.gradInput), "torch.FloatTensor")
+ tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor")
+ tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor")
+ tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor")
+ end
+
+ function test.test_nestedGradInput()
+ local x = nn.Identity()()
+ local h1 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Tanh())
+ local h2 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Identity())
+ local out = nn.CAddTable()({h1(x), h2(x)})
+
+ local model = nn.gModule({x}, {out})
+
+ local input = {}
+ input[1] = torch.randn(3, 3)
+ input[2] = torch.randn(3, 3)
+ input[3] = torch.randn(3, 3)
+
+ checkGradients(model, input)
+
+ local input = {}
+ input[1] = torch.randn(2, 3)
+ input[2] = torch.randn(2, 3)
+ input[3] = torch.randn(2, 3)
+
+ checkGradients(model, input)
+ end
+
+ function test.test_unusedInput()
+ local x = nn.Identity()()
+ local h = nn.Identity()()
+ local h2 = nn.Identity()()
+
+ local ok, result = pcall(nn.gModule, {x, h}, {x})
+ assert(not ok, "the unused input should be detected")
+ end
+
+ function test.test_unusedChild()
+ local prevState = nn.Identity()()
+ local h, cell = prevState:split(2)
+
+ local ok, result = pcall(nn.gModule, {prevState}, {h})
+ assert(not ok, "the unused cell should be detected")
+ end
+
+ function test.test_nilInput()
+ local ok, result = pcall(function() nn.Sigmoid()(nil) end)
+ assert(not ok, "the nil input should be detected")
+ end
+
+ function test.test_unusedNode()
+ local in1 = nn.Identity()()
+ local in2 = nn.Identity()()
+ local middleResult = nn.Sigmoid()(in2)
+ local out1 = nn.Sigmoid()(in1)
+
+ local ok, result = pcall(nn.gModule, {in1, in2}, {out1})
+ assert(not ok, "the unused middleResult should be detected")
+ end
+
+ function test.test_usageAfterSplit()
+ local prevState = nn.Identity()()
+ local h, cell = prevState:split(2)
+ local nextState = nn.Identity()(prevState)
+ local transformed = nn.Sigmoid()(cell)
+
+ local model = nn.gModule({prevState}, {h, nextState, transformed})
+ local nHidden = 10
+ local input = {torch.randn(nHidden), torch.randn(nHidden)}
+ checkGradients(model, input)
+ end
+
+ function test.test_resizeNestedAs()
+ local in1 = nn.Identity()()
+ local out1 = nn.Identity()(in1)
+ local out2 = nn.Identity()(in1)
+
+ local net = nn.gModule({in1}, {out1, out2})
+ local input = {torch.randn(10), {torch.randn(3), torch.randn(4)}}
+ net:forward(input)
+ net:backward(input, net.output)
+ checkGradients(net, input)
+
+ input = {torch.randn(10), {torch.randn(3), torch.randn(4), torch.randn(5)}}
+ net:forward(input)
+ net:backward(input, net.output)
+ checkGradients(net, input)
+
+ input = {torch.randn(10), {torch.randn(3), torch.randn(4)}}
+ net:forward(input)
+ local gradInput = net:backward(input, net.output)
+ tester:eq(#(gradInput[2]), 2, "gradInput[2] size")
+ checkGradients(net, input)
+ end
+
+
+ function test.test_annotateGraph()
+ local input = nn.Identity()():annotate(
{name = 'Input', description = 'DescA',
- graphAttributes = {color = 'red'}})
+ graphAttributes = {color = 'red'}})
- local hidden_a = nn.Linear(10, 10)(input):annotate(
+ local hidden_a = nn.Linear(10, 10)(input):annotate(
{name = 'Hidden A', description = 'DescB',
- graphAttributes = {color = 'blue', fontcolor='green', tooltip = 'I am green'}})
- local hidden_b = nn.Sigmoid()(hidden_a)
- local output = nn.Linear(10, 10)(hidden_b)
- local net = nn.gModule({input}, {output})
-
- tester:assert(hidden_a:label():match('DescB'))
- local fg_tmpfile = os.tmpname()
- local bg_tmpfile = os.tmpname()
- graph.dot(net.fg, 'Test', fg_tmpfile)
- graph.dot(net.fg, 'Test BG', bg_tmpfile)
-
- local function checkDotFile(tmpfile)
- local dotcontent = io.open(tmpfile .. '.dot', 'r'):read("*all")
- tester:assert(
- dotcontent:match('%[color=red.*label=%"Input.*DescA.*%".*%]'))
- tester:assert(
- dotcontent:match(
- '%[.*fontcolor=green.*label=%"Hidden A.*DescB.*%".*%]'))
- tester:assert(
- dotcontent:match('%[color=blue.*label=%".*DescB.*%".*%]'))
- tester:assert(
- dotcontent:match(
- '%[.*label=%".*DescB.*%".*tooltip=%"I am green%".*%]'))
- end
-
- checkDotFile(fg_tmpfile)
- checkDotFile(bg_tmpfile)
-end
-
-function test.test_splitMore()
- local nSplits = 2
- local in1 = nn.Identity()()
- local out1, out2 = nn.SplitTable(2)(in1):split(nSplits)
-
- local model = nn.gModule({in1}, {out1, out2})
- local input = torch.randn(10, nSplits + 1)
- local ok, result = pcall(model.forward, model, input)
- assert(not ok, "the extra input to split should be detected")
-end
-
-function test.test_splitLess()
- local nSplits = 3
- local in1 = nn.Identity()()
- local out1, out2, out3 = nn.SplitTable(2)(in1):split(nSplits)
-
- local model = nn.gModule({in1}, {out1, out2, out3})
- local input = torch.randn(10, nSplits - 1)
- local ok, result = pcall(model.forward, model, input)
- assert(not ok, "the missing input to split should be detected")
-end
-
-tester:add(test):run()
+ graphAttributes = {color = 'blue', fontcolor='green', tooltip = 'I am green'}})
+ local hidden_b = nn.Sigmoid()(hidden_a)
+ local output = nn.Linear(10, 10)(hidden_b)
+ local net = nn.gModule({input}, {output})
+
+ tester:assert(hidden_a:label():match('DescB'))
+ local fg_tmpfile = os.tmpname()
+ local bg_tmpfile = os.tmpname()
+ graph.dot(net.fg, 'Test', fg_tmpfile)
+ graph.dot(net.bg, 'Test BG', bg_tmpfile)
+
+ local function checkDotFile(tmpfile)
+ local dotcontent = io.open(tmpfile .. '.dot', 'r'):read("*all")
+ tester:assert(
+ dotcontent:match('%[color=red.*label=%"Input.*DescA.*%".*%]'))
+ tester:assert(
+ dotcontent:match(
+ '%[.*fontcolor=green.*label=%"Hidden A.*DescB.*%".*%]'))
+ tester:assert(
+ dotcontent:match('%[color=blue.*label=%".*DescB.*%".*%]'))
+ tester:assert(
+ dotcontent:match(
+ '%[.*label=%".*DescB.*%".*tooltip=%"I am green%".*%]'))
+ end
+
+ checkDotFile(fg_tmpfile)
+ checkDotFile(bg_tmpfile)
+ end
+
+ function test.test_splitMore()
+ local nSplits = 2
+ local in1 = nn.Identity()()
+ local out1, out2 = nn.SplitTable(2)(in1):split(nSplits)
+
+ local model = nn.gModule({in1}, {out1, out2})
+ local input = torch.randn(10, nSplits + 1)
+ local ok, result = pcall(model.forward, model, input)
+ assert(not ok, "the extra input to split should be detected")
+ end
+
+ function test.test_splitLess()
+ local nSplits = 3
+ local in1 = nn.Identity()()
+ local out1, out2, out3 = nn.SplitTable(2)(in1):split(nSplits)
+
+ local model = nn.gModule({in1}, {out1, out2, out3})
+ local input = torch.randn(10, nSplits - 1)
+ local ok, result = pcall(model.forward, model, input)
+ assert(not ok, "the missing input to split should be detected")
+ end
+
+ tester:add(test):run()
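
Several of the tests above pin down the semantics of node:split(n): it produces n nodes selecting the n elements of the table the parent node outputs, and forward fails if the table does not have exactly n elements (test_splitMore, test_splitLess). A minimal sketch:

    require 'nngraph'

    local in1 = nn.Identity()()
    local a, b = nn.SplitTable(2)(in1):split(2)
    local model = nn.gModule({in1}, {a, b})

    local out = model:forward(torch.randn(10, 2))  -- two columns match split(2)
    print(out[1]:size(), out[2]:size())            -- two vectors of size 10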
diff --git a/test/test_old.lua b/test/test_old.lua
index b9e2f6c..1a1e862 100644
--- a/test/test_old.lua
+++ b/test/test_old.lua
@@ -1,227 +1,227 @@
require 'nngraph'
function t1()
- local x1 = nn.Linear(20,20)()
- local x2 = nn.Linear(10,10)()
- local m0=nn.Linear(20,1)(nn.Tanh()(x1))
- local m1=nn.Linear(10,1)(nn.Tanh()(x2))
- local madd=nn.CAddTable()({m0,m1})
- local m2=nn.Sigmoid()(madd)
- local m3=nn.Tanh()(madd)
- local x = torch.rand(20)
- local y = torch.rand(10)
- gmod = nn.gModule({x1,x2},{m2,m3})
- gmod.verbose = true
- print('forward')
- gmod:updateOutput({x,y})
- print('updateGradInput')
- gmod:updateGradInput({x,y},{torch.rand(1),torch.rand(1)})
- graph.dot(gmod.fg,'forward')
- graph.dot(gmod.bg,'backward')
+ local x1 = nn.Linear(20,20)()
+ local x2 = nn.Linear(10,10)()
+ local m0=nn.Linear(20,1)(nn.Tanh()(x1))
+ local m1=nn.Linear(10,1)(nn.Tanh()(x2))
+ local madd=nn.CAddTable()({m0,m1})
+ local m2=nn.Sigmoid()(madd)
+ local m3=nn.Tanh()(madd)
+ local x = torch.rand(20)
+ local y = torch.rand(10)
+ gmod = nn.gModule({x1,x2},{m2,m3})
+ gmod.verbose = true
+ print('forward')
+ gmod:updateOutput({x,y})
+ print('updateGradInput')
+ gmod:updateGradInput({x,y},{torch.rand(1),torch.rand(1)})
+ graph.dot(gmod.fg,'forward')
+ graph.dot(gmod.bg,'backward')
end
function t2()
- print('compare')
- local m0 = nn.Linear(5,10)()
- local m1 = nn.Linear(10,20)()
- local m2 = nn.Linear(30,50)(nn.JoinTable(1){m0,m1})
- gmod = nn.gModule({m0,m1},{m2})
-
- local nn0 = nn.Linear(5,10)
- local nn1 = nn.Linear(10,20)
- local nn2 = nn.Linear(30,50)
- local nnmod = nn.Sequential():add(nn.ParallelTable():add(nn0):add(nn1)):add(nn.JoinTable(1)):add(nn2)
-
- nn0.weight:copy(m0.data.module.weight)
- nn0.bias:copy(m0.data.module.bias)
- nn1.weight:copy(m1.data.module.weight)
- nn1.bias:copy(m1.data.module.bias)
- nn2.weight:copy(m2.data.module.weight)
- nn2.bias:copy(m2.data.module.bias)
-
-
- for i=1,5 do
- local x,y = torch.rand(5),torch.rand(10)
- local xx,yy = x:clone(),y:clone()
-
- gmod:updateOutput({x,y})
- nnmod:updateOutput({xx,yy})
- print('fdiff = ', torch.dist(gmod.output,nnmod.output))
-
- local odx = torch.rand(50)
- local odxx = odx:clone()
-
- gmod:updateGradInput({x,y},odx)
- nnmod:updateGradInput({xx,yy},odxx)
- graph.dot(gmod.fg,tostring(i))
- for i,v in ipairs(gmod.gradInput) do
- print('bdiff [' ..i.. '] = ', torch.dist(gmod.gradInput[i],nnmod.gradInput[i]))
- end
- end
-
- local gms = {m0,m1,m2}
- local nms = {nn0,nn1,nn2}
-
- for i=1,5 do
- local x,y = torch.rand(5),torch.rand(10)
- local xx,yy = x:clone(),y:clone()
-
- gmod:updateOutput({x,y})
- nnmod:updateOutput({xx,yy})
- print('fdiff = ', torch.dist(gmod.output,nnmod.output))
-
- local odx = torch.rand(50)
- local odxx = odx:clone()
-
- gmod:zeroGradParameters()
- nnmod:zeroGradParameters()
-
- gmod:updateGradInput({x,y},odx)
- nnmod:updateGradInput({xx,yy},odxx)
-
- gmod:accGradParameters({x,y},odx)
- nnmod:accGradParameters({xx,yy},odxx)
- graph.dot(gmod.fg)
- for i,v in ipairs(gms) do
- print('accdiff [' ..i.. '] = ', torch.dist(gms[i].data.module.gradWeight,nms[i].gradWeight))
- print('accdiff [' ..i.. '] = ', torch.dist(gms[i].data.module.gradBias,nms[i].gradBias))
- end
- end
+ print('compare')
+ local m0 = nn.Linear(5,10)()
+ local m1 = nn.Linear(10,20)()
+ local m2 = nn.Linear(30,50)(nn.JoinTable(1){m0,m1})
+ gmod = nn.gModule({m0,m1},{m2})
+
+ local nn0 = nn.Linear(5,10)
+ local nn1 = nn.Linear(10,20)
+ local nn2 = nn.Linear(30,50)
+ local nnmod = nn.Sequential():add(nn.ParallelTable():add(nn0):add(nn1)):add(nn.JoinTable(1)):add(nn2)
+
+ nn0.weight:copy(m0.data.module.weight)
+ nn0.bias:copy(m0.data.module.bias)
+ nn1.weight:copy(m1.data.module.weight)
+ nn1.bias:copy(m1.data.module.bias)
+ nn2.weight:copy(m2.data.module.weight)
+ nn2.bias:copy(m2.data.module.bias)
+
+
+ for i=1,5 do
+ local x,y = torch.rand(5),torch.rand(10)
+ local xx,yy = x:clone(),y:clone()
+
+ gmod:updateOutput({x,y})
+ nnmod:updateOutput({xx,yy})
+ print('fdiff = ', torch.dist(gmod.output,nnmod.output))
+
+ local odx = torch.rand(50)
+ local odxx = odx:clone()
+
+ gmod:updateGradInput({x,y},odx)
+ nnmod:updateGradInput({xx,yy},odxx)
+ graph.dot(gmod.fg,tostring(i))
+ for i,v in ipairs(gmod.gradInput) do
+ print('bdiff [' ..i.. '] = ', torch.dist(gmod.gradInput[i],nnmod.gradInput[i]))
+ end
+ end
+
+ local gms = {m0,m1,m2}
+ local nms = {nn0,nn1,nn2}
+
+ for i=1,5 do
+ local x,y = torch.rand(5),torch.rand(10)
+ local xx,yy = x:clone(),y:clone()
+
+ gmod:updateOutput({x,y})
+ nnmod:updateOutput({xx,yy})
+ print('fdiff = ', torch.dist(gmod.output,nnmod.output))
+
+ local odx = torch.rand(50)
+ local odxx = odx:clone()
+
+ gmod:zeroGradParameters()
+ nnmod:zeroGradParameters()
+
+ gmod:updateGradInput({x,y},odx)
+ nnmod:updateGradInput({xx,yy},odxx)
+
+ gmod:accGradParameters({x,y},odx)
+ nnmod:accGradParameters({xx,yy},odxx)
+ graph.dot(gmod.fg)
+ for i,v in ipairs(gms) do
+ print('accdiff [' ..i.. '] = ', torch.dist(gms[i].data.module.gradWeight,nms[i].gradWeight))
+ print('accdiff [' ..i.. '] = ', torch.dist(gms[i].data.module.gradBias,nms[i].gradBias))
+ end
+ end
end
function t3()
- mlp=nn.Sequential(); --Create a network that takes a Tensor as input
- mlp:add(nn.SplitTable(2))
- c=nn.ParallelTable() --The two Tensors go through two different Linear
- c:add(nn.Linear(10,3)) --Layers in Parallel
- c:add(nn.Linear(10,7))
- mlp:add(c) --Outputting a table with 2 elements
- p=nn.ParallelTable() --These tables go through two more linear layers
- p:add(nn.Linear(3,2)) -- separately.
- p:add(nn.Linear(7,1))
- mlp:add(p)
- mlp:add(nn.JoinTable(1)) --Finally, the tables are joined together and output.
-
- pred=mlp:forward(torch.randn(10,2))
- print(pred)
-
- for i=1,25 do -- A few steps of training such a network..
- x=torch.ones(10,2);
- y=torch.Tensor(3); y:copy(x:select(2,1,1):narrow(1,1,3))
- pred=mlp:forward(x)
-
- criterion= nn.MSECriterion()
- local err=criterion:forward(pred,y)
- local gradCriterion = criterion:backward(pred,y);
- print(x,y)
- mlp:zeroGradParameters();
- mlp:backward(x, gradCriterion);
- mlp:updateParameters(0.05);
-
- print(err)
- end
+ mlp=nn.Sequential(); --Create a network that takes a Tensor as input
+ mlp:add(nn.SplitTable(2))
+ c=nn.ParallelTable() --The two Tensors go through two different Linear
+ c:add(nn.Linear(10,3)) --Layers in Parallel
+ c:add(nn.Linear(10,7))
+ mlp:add(c) --Outputting a table with 2 elements
+ p=nn.ParallelTable() --These tables go through two more linear layers
+ p:add(nn.Linear(3,2)) -- separately.
+ p:add(nn.Linear(7,1))
+ mlp:add(p)
+ mlp:add(nn.JoinTable(1)) --Finally, the tables are joined together and output.
+
+ pred=mlp:forward(torch.randn(10,2))
+ print(pred)
+
+ for i=1,25 do -- A few steps of training such a network..
+ x=torch.ones(10,2);
+ y=torch.Tensor(3); y:copy(x:select(2,1):narrow(1,1,3))
+ pred=mlp:forward(x)
+
+ criterion= nn.MSECriterion()
+ local err=criterion:forward(pred,y)
+ local gradCriterion = criterion:backward(pred,y);
+ print(x,y)
+ mlp:zeroGradParameters();
+ mlp:backward(x, gradCriterion);
+ mlp:updateParameters(0.05);
+
+ print(err)
+ end
end
function t4()
- local getInput1 = nn.Identity()()
- local getInput2 = nn.Identity()()
- local mlp = nn.Tanh()(getInput1)
- net = nn.gModule({getInput1, getInput2}, {mlp, getInput2})
+ local getInput1 = nn.Identity()()
+ local getInput2 = nn.Identity()()
+ local mlp = nn.Tanh()(getInput1)
+ net = nn.gModule({getInput1, getInput2}, {mlp, getInput2})
- local input1 = torch.randn(2)
- local input2 = torch.randn(5)
+ local input1 = torch.randn(2)
+ local input2 = torch.randn(5)
- net:forward({input1, input2})
- local gradInput = net:backward({input1, input2},
- {torch.randn(input1:size()), torch.randn(input2:size())})
- print("gradInput[1]:", gradInput[1])
- print("gradInput[2]:", gradInput[2])
- graph.dot(net.fg)
- assert(gradInput[1]:nElement() == input1:nElement(), "size mismatch")
+ net:forward({input1, input2})
+ local gradInput = net:backward({input1, input2},
+ {torch.randn(input1:size()), torch.randn(input2:size())})
+ print("gradInput[1]:", gradInput[1])
+ print("gradInput[2]:", gradInput[2])
+ graph.dot(net.fg)
+ assert(gradInput[1]:nElement() == input1:nElement(), "size mismatch")
end
function t5()
- local m = nn.Sequential()
- m:add(nn.SplitTable(1))
- m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30)))
- local input = nn.Identity()()
- local input1,input2 = m(input):split(2)
- local m3 = nn.JoinTable(1)({input1,input2})
-
- g = nn.gModule({input},{m3})
- graph.dot(g.fg,'init forward')
-
- local indata = torch.rand(2,10)
- local gdata = torch.rand(50)
- g:forward(indata)
- g:backward(indata,gdata)
-
- graph.dot(g.fg,'forward')
- graph.dot(g.bg,'backward')
+ local m = nn.Sequential()
+ m:add(nn.SplitTable(1))
+ m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30)))
+ local input = nn.Identity()()
+ local input1,input2 = m(input):split(2)
+ local m3 = nn.JoinTable(1)({input1,input2})
+
+ g = nn.gModule({input},{m3})
+ graph.dot(g.fg,'init forward')
+
+ local indata = torch.rand(2,10)
+ local gdata = torch.rand(50)
+ g:forward(indata)
+ g:backward(indata,gdata)
+
+ graph.dot(g.fg,'forward')
+ graph.dot(g.bg,'backward')
end
function topsort(a)
- -- first clone the graph
- -- local g = self:clone()
- -- local nodes = g.nodes
- -- local edges = g.edges
- -- for i,node in ipairs(nodes) do
- -- node.children = {}
- -- end
-
- -- reverse the graph
- rg,map = a:reverse()
- local rmap = {}
- for k,v in pairs(map) do
- rmap[v] = k
- end
-
- -- work on the sorted graph
- sortednodes = {}
- rootnodes = rg:roots()
-
- if #rootnodes == 0 then
- print('Graph has cycles')
- end
-
- -- run
- for i,root in ipairs(rootnodes) do
- root:dfs(function(node)
- print(node.id,rmap[node].id)
- -- print(rmap[node])
- table.insert(sortednodes,rmap[node]) end)
- end
-
- if #sortednodes ~= #a.nodes then
- print('Graph has cycles')
- end
- return sortednodes,rg,rootnodes
-end
-
-local my={eq =
- function(a,b,s)
- if a:dist(b) == 0 then
- print('ok')
- else
- print('error : ' .. s)
- print('a : ');print(a)
- print('b : ');print(b)
- end
- end}
-
-function t8()
- local in1 = nn.Identity()()
- local m = nn.Linear(10,10)(in1)
- local out1 = nn.Tanh()(m)
- local out2 = nn.Tanh()(m)
- local out = nn.CAddTable(){out1, out2}
- local mod = nn.gModule({in1}, {out})
-
- local dot = nngraph.simple_print.todot(mod.fg, 'bogus')
- print (dot)
- nngraph.simple_print.dot(mod.fg, 'bogus', 'new')
- graph.dot(mod.fg, 'bogus', 'old')
-end
--- t2()
-t8()
+ -- first clone the graph
+ -- local g = self:clone()
+ -- local nodes = g.nodes
+ -- local edges = g.edges
+ -- for i,node in ipairs(nodes) do
+ -- node.children = {}
+ -- end
+
+ -- reverse the graph
+ rg,map = a:reverse()
+ local rmap = {}
+ for k,v in pairs(map) do
+ rmap[v] = k
+ end
+
+ -- work on the reversed graph
+ sortednodes = {}
+ rootnodes = rg:roots()
+
+ if #rootnodes == 0 then
+ print('Graph has cycles')
+ end
+
+ -- run
+ for i,root in ipairs(rootnodes) do
+ root:dfs(function(node)
+ print(node.id,rmap[node].id)
+ -- print(rmap[node])
+ table.insert(sortednodes,rmap[node]) end)
+ end
+
+ if #sortednodes ~= #a.nodes then
+ print('Graph has cycles')
+ end
+ return sortednodes,rg,rootnodes
+ end
+
+ local my={eq =
+ function(a,b,s)
+ if a:dist(b) == 0 then
+ print('ok')
+ else
+ print('error : ' .. s)
+ print('a : ');print(a)
+ print('b : ');print(b)
+ end
+ end}
+
+ function t8()
+ local in1 = nn.Identity()()
+ local m = nn.Linear(10,10)(in1)
+ local out1 = nn.Tanh()(m)
+ local out2 = nn.Tanh()(m)
+ local out = nn.CAddTable(){out1, out2}
+ local mod = nn.gModule({in1}, {out})
+
+ local dot = nngraph.simple_print.todot(mod.fg, 'bogus')
+ print(dot)
+ nngraph.simple_print.dot(mod.fg, 'bogus', 'new')
+ graph.dot(mod.fg, 'bogus', 'old')
+ end
+ -- t2()
+ t8()
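
t5 and t8 above rely on the two visualization paths: graph.dot(g, title, fname) renders a graph and, when given a file prefix, writes fname.dot (test_annotateGraph reads this file back), while nngraph.simple_print.todot returns the dot source as a string. A minimal sketch that only prints dot source, so it needs no graphviz setup:

    require 'nngraph'

    local in1 = nn.Identity()()
    local out = nn.Tanh()(nn.Linear(10, 10)(in1))
    local mod = nn.gModule({in1}, {out})

    -- print the dot source of the forward graph rather than writing files
    print(nngraph.simple_print.todot(mod.fg, 'example'))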
diff --git a/utils.lua b/utils.lua
index d358157..7b1ba07 100644
--- a/utils.lua
+++ b/utils.lua
@@ -1,20 +1,18 @@
local utils = {}
function utils.istensor(x)
- if torch.typename(x) and torch.typename(x):find('Tensor') then
- return true
- end
- return false
+ if torch.typename(x) and torch.typename(x):find('Tensor') then
+ return true
+ end
+ return false
end
function utils.istorchclass(x)
- return type(x) == 'table' and torch.typename(x)
+ return type(x) == 'table' and torch.typename(x)
end
function utils.istable(x)
- return type(x) == 'table' and not torch.typename(x)
+ return type(x) == 'table' and not torch.typename(x)
end
return utils
-
-
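
The three predicates in utils.lua partition values cleanly: istensor is true for torch tensors, istable only for plain Lua tables (a torch typename disqualifies), and istorchclass returns the typename (a truthy string, not a boolean) for torch classes implemented as tables, such as nn modules. A quick check, assuming the file is loaded with dofile from the nngraph source tree:

    require 'nn'
    local utils = dofile('utils.lua')  -- utils.lua ends with `return utils`

    print(utils.istensor(torch.Tensor(3)))      --> true
    print(utils.istable({1, 2, 3}))             --> true
    print(utils.istable(nn.Linear(2, 2)))       --> false: it has a torch typename
    print(utils.istorchclass(nn.Linear(2, 2)))  --> 'nn.Linear' (truthy)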