Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <cfarabet@twitter.com>2015-09-04 23:31:01 +0300
committerClement Farabet <cfarabet@twitter.com>2015-09-04 23:31:01 +0300
commitba7f3ec6ffe7e60e5e07b2886178aec54e6305e5 (patch)
tree23aaf8f288fdce23b7361a1e7252254a690a6fb9 /node.lua
parent72f74d39257a344a2c3237c83f8a828b916817e4 (diff)
Whitespace cleanup.
Diffstat (limited to 'node.lua')
-rw-r--r--node.lua235
1 files changed, 114 insertions, 121 deletions
diff --git a/node.lua b/node.lua
index b620456..7ef55be 100644
--- a/node.lua
+++ b/node.lua
@@ -5,162 +5,155 @@ local istable = utils.istable
local istorchclass = utils.istorchclass
require 'debug'
-
local nnNode,parent = torch.class('nngraph.Node','graph.Node')
function nnNode:__init(data)
- parent.__init(self,data)
- self.data.annotations = self.data.annotations or {}
- self.data.mapindex = self.data.mapindex or {}
- if not self.data.annotations._debugLabel then
- self:_makeDebugLabel(debug.getinfo(6, 'Sl'))
- end
+ parent.__init(self,data)
+ self.data.annotations = self.data.annotations or {}
+ self.data.mapindex = self.data.mapindex or {}
+ if not self.data.annotations._debugLabel then
+ self:_makeDebugLabel(debug.getinfo(6, 'Sl'))
+ end
end
-
--[[ Build a string label which will be used a tooltip when
- making a graph.]]
+making a graph.]]
function nnNode:_makeDebugLabel(dinfo)
- if dinfo then
- self.data.annotations._debugLabel = string.format('[%s]:%d',
- dinfo.short_src, dinfo.currentline, dinfo.name)
- end
+ if dinfo then
+ self.data.annotations._debugLabel = string.format('[%s]:%d',
+ dinfo.short_src, dinfo.currentline, dinfo.name)
+ end
end
-
-- domap ensures that this node will keep track of the order its children are added.
-- mapindex is a forward/backward list
-- index = self.data.mapindex[child.data]
-- child.data = self.data.mapindex[index]
function nnNode:add(child,domap)
- parent.add(self,child)
- if domap then
- local mapindex = self.data.mapindex
- local data = child.data
- assert(not mapindex[data], "Don't pass the same input twice.")
- table.insert(mapindex,data)
- mapindex[data] = #mapindex
- end
+ parent.add(self,child)
+ if domap then
+ local mapindex = self.data.mapindex
+ local data = child.data
+ assert(not mapindex[data], "Don't pass the same input twice.")
+ table.insert(mapindex,data)
+ mapindex[data] = #mapindex
+ end
end
-- this function returns noutput number of new nodes
-- that each take a single component of the output of this
-- node in the order they are returned.
function nnNode:split(noutput)
- assert(noutput >= 2, "splitting to one output is not supported")
- local debugLabel = self.data.annotations._debugLabel
- local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. '-mnode'}})
- mnode:add(self,true)
-
- local selectnodes = {}
- for i=1,noutput do
- local node = nngraph.Node({selectindex=i,input={}, annotations={_debugLabel=debugLabel .. '-' .. i}})
- node:add(mnode,true)
- table.insert(selectnodes,node)
- end
- return unpack(selectnodes)
+ assert(noutput >= 2, "splitting to one output is not supported")
+ local debugLabel = self.data.annotations._debugLabel
+ local mnode = nngraph.Node({nSplitOutputs=noutput, annotations={_debugLabel=debugLabel .. '-mnode'}})
+ mnode:add(self,true)
+
+ local selectnodes = {}
+ for i=1,noutput do
+ local node = nngraph.Node({selectindex=i,input={}, annotations={_debugLabel=debugLabel .. '-' .. i}})
+ node:add(mnode,true)
+ table.insert(selectnodes,node)
+ end
+ return unpack(selectnodes)
end
-
function nnNode:annotate(annotations)
- for k, v in pairs(annotations) do
- self.data.annotations[k] = v
- end
+ for k, v in pairs(annotations) do
+ self.data.annotations[k] = v
+ end
- return self
+ return self
end
-
function nnNode:graphNodeName()
- if self.data.annotations.name then
- return self.data.annotations.name .. ' (' .. self.id .. ')'
- else
- return 'Node' .. self.id
- end
+ if self.data.annotations.name then
+ return self.data.annotations.name .. ' (' .. self.id .. ')'
+ else
+ return 'Node' .. self.id
+ end
end
-
function nnNode:graphNodeAttributes()
- self.data.annotations.graphAttributes =
- self.data.annotations.graphAttributes or {}
- if not self.data.annotations.graphAttributes.tooltip then
- self.data.annotations.graphAttributes.tooltip =
- self.data.annotations._debugLabel
- end
-
- return self.data.annotations.graphAttributes
+ self.data.annotations.graphAttributes =
+ self.data.annotations.graphAttributes or {}
+ if not self.data.annotations.graphAttributes.tooltip then
+ self.data.annotations.graphAttributes.tooltip =
+ self.data.annotations._debugLabel
+ end
+
+ return self.data.annotations.graphAttributes
end
-
local function getNanFlag(data)
- if data:nElement() == 0 then
- return ''
- end
- local isNan = (data:ne(data):sum() > 0)
- if isNan then
- return 'NaN'
- end
- if data:max() == math.huge then
- return 'inf'
- end
- if data:min() == -math.huge then
- return '-inf'
- end
- return ''
+ if data:nElement() == 0 then
+ return ''
+ end
+ local isNan = (data:ne(data):sum() > 0)
+ if isNan then
+ return 'NaN'
+ end
+ if data:max() == math.huge then
+ return 'inf'
+ end
+ if data:min() == -math.huge then
+ return '-inf'
+ end
+ return ''
end
function nnNode:label()
- local lbl = {}
-
- local function getstr(data)
- if not data then return '' end
- if istensor(data) then
- local nanFlag = getNanFlag(data)
- local tensorType = 'Tensor'
- if data:type() ~= torch.Tensor():type() then
- tensorType = data:type()
- end
- return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag
- elseif istable(data) then
- local tstr = {}
- for i,v in ipairs(data) do
- table.insert(tstr, getstr(v))
- end
- return '{' .. table.concat(tstr,',') .. '}'
- else
- return tostring(data):gsub('\n','\\l')
- end
- end
- local function getmapindexstr(mapindex)
- local tstr = {}
- for i,data in ipairs(mapindex) do
- local inputId = 'Node' .. (data.forwardNodeId or '')
- table.insert(tstr, inputId)
- end
- return '{' .. table.concat(tstr,',') .. '}'
- end
-
- for k,v in pairs(self.data) do
- local vstr = ''
- if k== 'mapindex' then
- if #v > 1 then
- vstr = getmapindexstr(v)
- table.insert(lbl, k .. ' = ' .. vstr)
- end
- elseif k== 'forwardNodeId' or k== 'annotations' then
- -- the forwardNodeId is not displayed in the label.
- else
- vstr = getstr(v)
- table.insert(lbl, k .. ' = ' .. vstr)
- end
- end
-
- local desc
- if self.data.annotations.description then
- desc = 'desc = ' .. self.data.annotations.description .. '\\n'
- else
- desc = ''
- end
- return desc .. table.concat(lbl,"\\l")
+ local lbl = {}
+
+ local function getstr(data)
+ if not data then return '' end
+ if istensor(data) then
+ local nanFlag = getNanFlag(data)
+ local tensorType = 'Tensor'
+ if data:type() ~= torch.Tensor():type() then
+ tensorType = data:type()
+ end
+ return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag
+ elseif istable(data) then
+ local tstr = {}
+ for i,v in ipairs(data) do
+ table.insert(tstr, getstr(v))
+ end
+ return '{' .. table.concat(tstr,',') .. '}'
+ else
+ return tostring(data):gsub('\n','\\l')
+ end
+ end
+ local function getmapindexstr(mapindex)
+ local tstr = {}
+ for i,data in ipairs(mapindex) do
+ local inputId = 'Node' .. (data.forwardNodeId or '')
+ table.insert(tstr, inputId)
+ end
+ return '{' .. table.concat(tstr,',') .. '}'
+ end
+
+ for k,v in pairs(self.data) do
+ local vstr = ''
+ if k== 'mapindex' then
+ if #v > 1 then
+ vstr = getmapindexstr(v)
+ table.insert(lbl, k .. ' = ' .. vstr)
+ end
+ elseif k== 'forwardNodeId' or k== 'annotations' then
+ -- the forwardNodeId is not displayed in the label.
+ else
+ vstr = getstr(v)
+ table.insert(lbl, k .. ' = ' .. vstr)
+ end
+ end
+
+ local desc
+ if self.data.annotations.description then
+ desc = 'desc = ' .. self.data.annotations.description .. '\\n'
+ else
+ desc = ''
+ end
+ return desc .. table.concat(lbl,"\\l")
end