Welcome to mirror list, hosted at ThFree Co, Russian Federation.

node.lua - github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
blob: 76f74f43a56e94578ae81e03a1212ae47e0704b1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96

local utils = paths.dofile('utils.lua')
local istensor = utils.istensor
local istable = utils.istable
local istorchclass = utils.istorchclass


-- Declare the nngraph.Node class as a subclass of graph.Node. `nnNode` is the
-- class table the methods below are defined on; `parent` is the superclass,
-- used below to call the inherited __init and add.
local nnNode,parent = torch.class('nngraph.Node','graph.Node')

-- Construct a node wrapping `data` via the graph.Node constructor, then make
-- sure the node's data table carries a `mapindex` table, which add(child, true)
-- uses to record the order in which children were attached.
function nnNode:__init(data)
	parent.__init(self, data)
	local nodedata = self.data
	if not nodedata.mapindex then
		nodedata.mapindex = {}
	end
end

-- Attach a display name to this node by storing it in data._name.
-- Nothing is stored when the node's data is missing or not a table.
-- Always returns the node itself so calls can be chained.
function nnNode:name(name)
	local d = self.data
	if not (d and istable(d)) then
		return self
	end
	d._name = name
	return self
end

-- Add `child` as a child of this node (delegating to graph.Node.add).
-- When `domap` is true, also record the order in which children were added in
-- self.data.mapindex, which is kept as a two-way lookup:
--   index      = self.data.mapindex[child.data]
--   child.data = self.data.mapindex[index]
-- Adding the same child data twice with `domap` raises an error. That check is
-- made before the edge is added, so a rejected call leaves the graph unchanged.
function nnNode:add(child,domap)
	if not domap then
		parent.add(self,child)
		return
	end

	local mapindex = self.data.mapindex
	local data = child.data
	-- Validate first: asserting after parent.add would leave a dangling edge.
	assert(not mapindex[data], "Don't pass the same input twice.")

	parent.add(self,child)

	local index = #mapindex + 1
	mapindex[index] = data
	mapindex[data] = index
end

-- Split this node's output into `noutput` separate nodes, each of which takes
-- a single component of this node's output, in the order they are returned.
-- An intermediate node with data.nSplitOutputs = noutput is fed by this node;
-- then `noutput` nodes with data.selectindex = 1..noutput are fed by it:
--   local a, b = node:split(2)   -- a.data.selectindex == 1, b's is 2
-- Raises an error if `noutput` is not an integer >= 2.
function nnNode:split(noutput)
	-- Reject nil/non-numeric/fractional counts with a clear message instead of
	-- a "compare nil with number" error or a fractional nSplitOutputs.
	assert(type(noutput) == 'number' and noutput == math.floor(noutput),
		"noutput must be an integer")
	assert(noutput >= 2, "splitting to one output is not supported")

	local mnode = nngraph.Node({nSplitOutputs=noutput})
	mnode:add(self,true)

	local selectnodes = {}
	for i=1,noutput do
		local node = nngraph.Node({selectindex=i,input={}})
		node:add(mnode,true)
		selectnodes[i] = node
	end
	-- `unpack` is a global in Lua 5.1/LuaJIT but lives in table.unpack in 5.2+.
	local unpackfn = table.unpack or unpack
	return unpackfn(selectnodes)
end

-- Render one value for a node label: nil/false as '', tensors as their size
-- (e.g. Tensor[3x4]), array-like tables recursively as {a,b,...}, anything
-- else via tostring with each newline replaced by the two characters `\l`.
-- Always returns exactly one string.
local function labelstr(data)
	if not data then return '' end
	if istensor(data) then
		return 'Tensor[' .. table.concat(data:size():totable(),'x') .. ']'
	elseif istable(data) then
		local parts = {}
		for _,v in ipairs(data) do
			parts[#parts+1] = labelstr(v)
		end
		return '{' .. table.concat(parts,',') .. '}'
	end
	-- gsub returns (string, count). Keep only the string: forwarding both into
	-- table.insert turned it into the 3-argument positional form, which
	-- misplaced numeric elements and raised an error for string elements.
	local str = tostring(data):gsub('\n','\\l')
	return str
end

-- Render a mapindex list as {NodeX,NodeY,...} using each entry's
-- forwardNodeId (an entry without one renders as plain "Node").
local function mapindexstr(mapindex)
	local ids = {}
	for i,data in ipairs(mapindex) do
		ids[i] = 'Node' .. (data.forwardNodeId or '')
	end
	return '{' .. table.concat(ids,',') .. '}'
end

-- Build the text label for this node: one "key = value" entry per field of
-- self.data, joined with the two characters `\l`. Entries are emitted in
-- sorted key order so the label is deterministic. `mapindex` is shown only
-- when it holds more than one entry; `forwardNodeId` is never shown.
function nnNode:label()
	local data = self.data

	-- pairs() order is unspecified; sort by string form (so mixed key types
	-- cannot make table.sort raise) to produce a stable label.
	local keys = {}
	for k in pairs(data) do
		keys[#keys+1] = k
	end
	table.sort(keys, function(a,b) return tostring(a) < tostring(b) end)

	local lbl = {}
	for _,k in ipairs(keys) do
		local v = data[k]
		if k == 'mapindex' then
			if #v > 1 then
				lbl[#lbl+1] = k .. ' = ' .. mapindexstr(v)
			end
		elseif k ~= 'forwardNodeId' then
			lbl[#lbl+1] = k .. ' = ' .. labelstr(v)
		end
	end
	return table.concat(lbl,"\\l")
end