diff options
author | Yori Zwols <yori@google.com> | 2015-03-12 17:23:42 +0300 |
---|---|---|
committer | Yori Zwols <yori@google.com> | 2015-03-12 17:23:42 +0300 |
commit | 0e9f7d4e46b0a15f249b4045f86c9d4fbaf9dbcb (patch) | |
tree | 2a08d61a569177d6f45e79ab52de8ace27507756 | |
parent | 53e4bedf82ece70c91dabe9b81a780e288fddecb (diff) |
Fixing problem with debug labels for nodes
-rw-r--r-- | node.lua | 11 | ||||
-rw-r--r-- | test/test_nngraph.lua | 8 |
2 files changed, 15 insertions, 4 deletions
@@ -21,8 +21,10 @@ end --[[ Build a string label which will be used a tooltip when making a graph.]] function nnNode:_makeDebugLabel(dinfo) - self.data.annotations._debugLabel = string.format('[%s]:%d', - dinfo.short_src, dinfo.currentline, dinfo.name) + if dinfo then + self.data.annotations._debugLabel = string.format('[%s]:%d', + dinfo.short_src, dinfo.currentline, dinfo.name) + end end @@ -46,12 +48,13 @@ end -- node in the order they are returned. function nnNode:split(noutput) assert(noutput >= 2, "splitting to one output is not supported") - local mnode = nngraph.Node({nSplitOutputs=noutput}) + local debugLabel = self.data.annotations._debugLabel + local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. '-mnode'}}) mnode:add(self,true) local selectnodes = {} for i=1,noutput do - local node = nngraph.Node({selectindex=i,input={}}) + local node = nngraph.Node({selectindex=i,input={}, annotations={_debugLabel=debugLabel .. '-' .. i}}) node:add(mnode,true) table.insert(selectnodes,node) end diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua index 4062c5a..9073036 100644 --- a/test/test_nngraph.lua +++ b/test/test_nngraph.lua @@ -80,6 +80,14 @@ function test.test_twoInputs2() checkGradients(module, input) end +function test.test_splitDebugLabels() + local node = nn.Identity()() + node.data.annotations._debugLabel = "node" + local node1, node2 = node:split(2) + assert(node1.data.annotations._debugLabel == "node-1") + assert(node2.data.annotations._debugLabel == "node-2") +end + function test.test_identity() local in1 = nn.Identity()() local in2 = nn.Identity()() |