Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkoray kavukcuoglu <koray@kavukcuoglu.org>2014-04-01 14:34:29 +0400
committerkoray kavukcuoglu <koray@kavukcuoglu.org>2014-04-01 14:34:29 +0400
commit4615d9a443937a2a9b2e4294828bcd17858e733c (patch)
tree3b8bcf8081eee225a2d17f5c47afef6da15415a3
parent05a9ca42259cb5657209fe2d6e4fa27542193844 (diff)
add contructor and functional access to set the name of a nodenamed_nnnode
-rw-r--r--init.lua25
-rw-r--r--node.lua9
2 files changed, 27 insertions, 7 deletions
diff --git a/init.lua b/init.lua
index 273de8c..98c53e6 100644
--- a/init.lua
+++ b/init.lua
@@ -20,16 +20,29 @@ local istorchclass = utils.istorchclass
local Module = torch.getmetatable('nn.Module')
function Module:__call__(...)
local nArgs = select("#", ...)
- assert(nArgs <= 1, 'Use {input1, input2} to pass multiple inputs.')
-
- local input = ...
- if nArgs == 1 and input == nil then
- error('what is this in the input? nil')
+ assert(nArgs <= 2, 'Usage __call(input|{input1,input2,...} [, name])')
+
+ local input, name = nil, nil
+ if nArgs == 2 then
+ input = ({...})[1]
+ name = ({...})[2]
+ assert(type(name) == 'string', 'The second argument can be string only (used for name)')
+ elseif nArgs == 1 then
+ input = ...
+ if type(input) == 'string' then
+ name = input
+ input = nil
+ nArgs = 0
+ elseif input == nil then
+ error('what is this in the input? nil')
+ end
end
+
if not istable(input) then
input = {input}
end
- local mnode = nngraph.Node({module=self})
+
+ local mnode = nngraph.Node({module=self, _name = name})
for i,dnode in ipairs(input) do
if torch.typename(dnode) ~= 'nngraph.Node' then
diff --git a/node.lua b/node.lua
index 51ee300..76f74f4 100644
--- a/node.lua
+++ b/node.lua
@@ -12,6 +12,13 @@ function nnNode:__init(data)
self.data.mapindex = self.data.mapindex or {}
end
+function nnNode:name(name)
+ if self.data and istable(self.data) then
+ self.data._name = name
+ end
+ return self
+end
+
-- domap ensures that this node will keep track of the order its children are added.
-- mapindex is a forward/backward list
-- index = self.data.mapindex[child.data]
@@ -47,7 +54,7 @@ end
function nnNode:label()
local lbl = {}
-
+
local function getstr(data)
if not data then return '' end
if istensor(data) then