Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkoray kavukcuoglu <koray@kavukcuoglu.org>2015-03-16 13:42:01 +0300
committerkoray kavukcuoglu <koray@kavukcuoglu.org>2015-03-16 13:42:01 +0300
commit78c3649f7903454aa9043dcb8d639537c6d5f604 (patch)
tree9e6e1cd0dba49e5c7fb924b2f7656bed0292f2a9
parentbcc714d6b40cfe5be2dea4916412446b259f7601 (diff)
parent0e9f7d4e46b0a15f249b4045f86c9d4fbaf9dbcb (diff)
Merge pull request #42 from yozw/debug-labels
Fixing problem with debug labels for nodes
-rw-r--r--node.lua11
-rw-r--r--test/test_nngraph.lua8
2 files changed, 15 insertions, 4 deletions
diff --git a/node.lua b/node.lua
index 07bcf17..c01bdae 100644
--- a/node.lua
+++ b/node.lua
@@ -21,8 +21,10 @@ end
--[[ Build a string label which will be used a tooltip when
making a graph.]]
function nnNode:_makeDebugLabel(dinfo)
- self.data.annotations._debugLabel = string.format('[%s]:%d',
- dinfo.short_src, dinfo.currentline, dinfo.name)
+ if dinfo then
+ self.data.annotations._debugLabel = string.format('[%s]:%d',
+ dinfo.short_src, dinfo.currentline, dinfo.name)
+ end
end
@@ -46,12 +48,13 @@ end
-- node in the order they are returned.
function nnNode:split(noutput)
assert(noutput >= 2, "splitting to one output is not supported")
- local mnode = nngraph.Node({nSplitOutputs=noutput})
+ local debugLabel = self.data.annotations._debugLabel
+ local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. '-mnode'}})
mnode:add(self,true)
local selectnodes = {}
for i=1,noutput do
- local node = nngraph.Node({selectindex=i,input={}})
+ local node = nngraph.Node({selectindex=i,input={}, annotations={_debugLabel=debugLabel .. '-' .. i}})
node:add(mnode,true)
table.insert(selectnodes,node)
end
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index b709cda..ac47be2 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -80,6 +80,14 @@ function test.test_twoInputs2()
checkGradients(module, input)
end
+function test.test_splitDebugLabels()
+ local node = nn.Identity()()
+ node.data.annotations._debugLabel = "node"
+ local node1, node2 = node:split(2)
+ assert(node1.data.annotations._debugLabel == "node-1")
+ assert(node2.data.annotations._debugLabel == "node-2")
+end
+
function test.test_identity()
local in1 = nn.Identity()()
local in2 = nn.Identity()()