Welcome to mirror list, hosted at ThFree Co, Russian Federation.

node.lua - github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: b620456642ebbb51ab4e4f13061039a014ee3abd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166

local utils = paths.dofile('utils.lua')
local istensor = utils.istensor
local istable = utils.istable
local istorchclass = utils.istorchclass
require 'debug'


-- Declare the class 'nngraph.Node' as a subclass of 'graph.Node'.
-- `parent` is the base class; the methods below call parent.__init and
-- parent.add to run the inherited behavior before adding nngraph-specific data.
local nnNode,parent = torch.class('nngraph.Node','graph.Node')

--- Construct a node wrapping the payload table `data`.
-- Guarantees that `data.annotations` and `data.mapindex` exist. When no
-- `_debugLabel` annotation was supplied, a "[file]:line" label is derived
-- from the call stack via _makeDebugLabel.
-- @param data payload table stored as self.data by the parent constructor
function nnNode:__init(data)
	parent.__init(self, data)
	local nodeData = self.data
	if not nodeData.annotations then
		nodeData.annotations = {}
	end
	if not nodeData.mapindex then
		nodeData.mapindex = {}
	end
	if nodeData.annotations._debugLabel then
		return
	end
	-- Stack level 6 is counted from this function's frame, so the getinfo
	-- call must stay directly in __init. Presumably it lands on the user code
	-- that created the node, through the class-constructor frames -- confirm.
	local callerInfo = debug.getinfo(6, 'Sl')
	self:_makeDebugLabel(callerInfo)
end


--[[ Build a string label which will be used as a tooltip when
  making a graph.
  @param dinfo table as returned by debug.getinfo(level, 'Sl'), or nil when
    the requested stack level does not exist (in that case nothing is stored).
  On success stores "[<short_src>]:<currentline>" in
  self.data.annotations._debugLabel.]]
function nnNode:_makeDebugLabel(dinfo)
	if not dinfo then
		return
	end
	-- Only the 'S' (short_src) and 'l' (currentline) fields are part of the
	-- label. Fall back to placeholders so a dinfo obtained with other options
	-- cannot make string.format fail; -1 is Lua's own "no line info" value.
	self.data.annotations._debugLabel = string.format('[%s]:%d',
		dinfo.short_src or '?', dinfo.currentline or -1)
end


-- Attach `child` as a child of this node.
-- When `domap` is truthy, child.data is also recorded in self.data.mapindex
-- so the order in which children were added can be recovered. mapindex is a
-- two-way lookup:
--   index      = self.data.mapindex[child.data]
--   child.data = self.data.mapindex[index]
-- Mapping the same child data twice raises an error. Note that the error is
-- raised after parent.add has already linked the child.
function nnNode:add(child, domap)
	parent.add(self, child)
	if not domap then
		return
	end
	local lookup = self.data.mapindex
	local childData = child.data
	assert(not lookup[childData], "Don't pass the same input twice.")
	local position = #lookup + 1
	lookup[position] = childData
	lookup[childData] = position
end

-- Split this node's output into `noutput` separate nodes.
-- An intermediate node (data.nSplitOutputs = noutput) is attached to this
-- node, and `noutput` selector nodes (data.selectindex = 1..noutput) are
-- attached to that intermediate node. Each selector takes a single component
-- of this node's output, and they are returned in index order.
-- @param noutput number of outputs; must be a number >= 2
-- @return noutput new nngraph.Node objects (multiple return values)
function nnNode:split(noutput)
	assert(type(noutput) == 'number',
		"split expects the number of outputs as a number")
	assert(noutput >= 2, "splitting to one output is not supported")
	-- _debugLabel is absent when debug.getinfo found no frame while this node
	-- was constructed; fall back to '' so the concatenations below cannot fail.
	local debugLabel = self.data.annotations._debugLabel or ''
	local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. '-mnode'}})
	mnode:add(self, true)

	local selectnodes = {}
	for i = 1, noutput do
		local node = nngraph.Node({selectindex=i, input={}, annotations={_debugLabel=debugLabel .. '-' .. i}})
		node:add(mnode, true)
		selectnodes[#selectnodes + 1] = node
	end
	-- `unpack` is a global only in Lua 5.1 / LuaJIT (or 5.2 compat builds);
	-- Lua 5.2+ provides it as table.unpack.
	local unpackfn = table.unpack or unpack
	return unpackfn(selectnodes)
end


-- Copy every key/value pair of `annotations` into this node's annotations
-- table, overwriting existing keys, and return the node so calls can be
-- chained.
function nnNode:annotate(annotations)
  local target = self.data.annotations
  for key, value in pairs(annotations) do
    target[key] = value
  end
  return self
end


-- Name used for this node in graph dumps: "<annotations.name> (<id>)" when a
-- name annotation is set, otherwise "Node<id>".
function nnNode:graphNodeName()
  local name = self.data.annotations.name
  if not name then
    return 'Node' .. self.id
  end
  return name .. ' (' .. self.id .. ')'
end


-- Return the attribute table stored in annotations.graphAttributes, creating
-- an empty one if it is missing. If no tooltip attribute is set, the node's
-- debug label (annotations._debugLabel) is used as the tooltip.
function nnNode:graphNodeAttributes()
  local annotations = self.data.annotations
  local attributes = annotations.graphAttributes
  if not attributes then
    attributes = {}
    annotations.graphAttributes = attributes
  end
  attributes.tooltip = attributes.tooltip or annotations._debugLabel
  return attributes
end


-- Return a short marker describing non-finite values in tensor `data`:
-- 'NaN' if any element is NaN, otherwise 'inf' if the maximum is +inf,
-- otherwise '-inf' if the minimum is -inf, otherwise ''. Empty tensors
-- always give ''.
local function getNanFlag(data)
	if data:nElement() ~= 0 then
		-- NaN is the only value that compares unequal to itself, so the sum
		-- of the element-wise "ne" mask counts the NaN entries.
		local nanCount = data:ne(data):sum()
		if nanCount > 0 then
			return 'NaN'
		elseif data:max() == math.huge then
			return 'inf'
		elseif data:min() == -math.huge then
			return '-inf'
		end
	end
	return ''
end

-- Build the text label shown for this node in graph dumps.
-- Every entry of self.data except 'forwardNodeId' and 'annotations' is
-- rendered as "key = value":
--   * tensors as their type (omitted when it is the default type) and size,
--     followed by a NaN/inf marker;
--   * tables recursively as "{a,b,...}" over their array part;
--   * other values via tostring, with newlines replaced by '\l'.
-- 'mapindex' is shown, as the forward node ids of the mapped inputs, only when
-- it holds more than one entry. If annotations.description is set, it is
-- prefixed as "desc = ...". Entries are joined with '\l'.
-- NOTE(review): '\l' / '\n' look like Graphviz label line breaks — confirm.
function nnNode:label()

	local lbl = {}

	local function getstr(data)
		if not data then return '' end
		if istensor(data) then
			local nanFlag = getNanFlag(data)
			local tensorType = 'Tensor'
			if data:type() ~= torch.Tensor():type() then
				tensorType = data:type()
			end
			return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag
		elseif istable(data) then
			local tstr = {}
			for _, v in ipairs(data) do
				tstr[#tstr + 1] = getstr(v)
			end
			return '{' .. table.concat(tstr,',') .. '}'
		else
			-- string.gsub returns (string, count). The parentheses keep only
			-- the string, so a caller such as table.insert(t, getstr(v)) never
			-- receives the count as an extra "position" argument.
			return (tostring(data):gsub('\n','\\l'))
		end
	end

	local function getmapindexstr(mapindex)
		local tstr = {}
		for _, data in ipairs(mapindex) do
			tstr[#tstr + 1] = 'Node' .. (data.forwardNodeId or '')
		end
		return '{' .. table.concat(tstr,',') .. '}'
	end

	for k,v in pairs(self.data) do
		if k == 'mapindex' then
			-- The input order is only informative when there are several inputs.
			if #v > 1 then
				table.insert(lbl, k .. ' = ' .. getmapindexstr(v))
			end
		elseif k ~= 'forwardNodeId' and k ~= 'annotations' then
			-- forwardNodeId and annotations are not displayed in the label.
			table.insert(lbl, k .. ' = ' .. getstr(v))
		end
	end

	local desc = ''
	if self.data.annotations.description then
		desc = 'desc = ' .. self.data.annotations.description .. '\\n'
	end
	return desc .. table.concat(lbl,"\\l")
end