Welcome to mirror list, hosted at ThFree Co, Russian Federation.

graphviz.lua - github.com/torch/graph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: a30fff6b470cdfbcaa1feda95ad1e12069801d48 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
require 'torch'

-- Optional graphviz bindings. Laying out graphs needs LuaJIT's FFI plus the
-- shared libraries libgvc and libcgraph; each is loaded under pcall so this
-- file still loads when any of them is missing. A `*Ok` flag is true only if
-- the matching load succeeded; otherwise the paired variable holds the error
-- value returned by pcall (or nil when the load was never attempted).
local ffiOk, ffi = pcall(require, 'ffi')
local graphvizOk, graphviz = false, nil
local cgraphOk, cgraph = false, nil

if ffiOk then
    -- C declarations for the subset of the cgraph / gvc API used in this file.
    ffi.cdef[[
typedef struct FILE FILE;

typedef struct Agraph_s Agraph_t;
typedef struct Agnode_s Agnode_t;

extern Agraph_t *agmemread(const char *cp);
extern char *agget(void *obj, char *name);
extern int agclose(Agraph_t * g);
extern Agnode_t *agfstnode(Agraph_t * g);
extern Agnode_t *agnxtnode(Agraph_t * g, Agnode_t * n);
extern Agnode_t *aglstnode(Agraph_t * g);
extern Agnode_t *agprvnode(Agraph_t * g, Agnode_t * n);

typedef struct Agraph_s graph_t;
typedef struct GVJ_s GVJ_t;
typedef struct GVG_s GVG_t;
typedef struct GVC_s GVC_t;
extern GVC_t *gvContext(void);
extern int gvLayout(GVC_t *context, graph_t *g, const char *engine);
extern int gvRender(GVC_t *context, graph_t *g, const char *format, FILE *out);
extern int gvFreeLayout(GVC_t *context, graph_t *g);
extern int gvFreeContext(GVC_t *context);
]]
    graphvizOk, graphviz = pcall(ffi.load, 'libgvc')
    cgraphOk, cgraph = pcall(ffi.load, 'libcgraph')
end


--- Read a string attribute from a graphviz object (graph or node).
-- Raises an error when cgraph's agget returns a NULL pointer.
-- @param obj graphviz object pointer (cdata)
-- @tparam string name attribute name, e.g. 'pos', 'bb' or 'label'
-- @treturn string the attribute value copied into a Lua string
local function getAttribute(obj, name)
	local cName = ffi.cast("char*", name)
	local raw = cgraph.agget(obj, cName)
	local isNull = raw == ffi.cast("char*", nil)
	assert(not isNull, 'could not get attr ' .. name)
	return ffi.string(raw)
end
--- Iterate through the nodes of a graphviz graph.
-- Usage: `for node in nodeIterator(g) do ... end`.
-- @param graph cgraph graph pointer (cdata)
-- @return a stateful iterator yielding node pointers; it yields nothing once
--   agnxtnode reports NULL (agnxtnode already returns NULL after the last
--   node, so no separate aglstnode check is needed).
local function nodeIterator(graph)
	local node = cgraph.agfstnode(graph)
	return function()
		-- A NULL cdata pointer is truthy in Lua, so it must be compared with
		-- nil explicitly; returning nothing ends the generic `for` loop.
		if node == nil then return end
		local current = node
		node = cgraph.agnxtnode(graph, current)
		return current
	end
end
--- Convert a string of separator-delimited numbers to actual numbers,
-- e.g. graphviz's "x,y" ('pos') or "x0,y0,x1,y1" ('bb') attributes.
-- @tparam number n expected count of numbers
-- @tparam string attr the string to parse
-- @tparam[opt=","] string sep set of separator characters; every character
--   of `sep` separates fields
-- @return the `n` parsed numbers as multiple return values
-- @raise "attribute is not of expected form" when a field is not a number or
--   the field count differs from `n`
local function extractNumbers(n, attr, sep)
	sep = sep or ","
	assert(type(sep) == "string" and #sep > 0, "separator must be a non-empty string")
	-- Escape non-alphanumeric characters so they are literal inside the set
	-- (with the default this yields the pattern "[^%,]+").
	local pattern = "[^" .. sep:gsub("%W", "%%%0") .. "]+"
	local values = {}
	for field in string.gmatch(attr, pattern) do
		local value = tonumber(field)
		assert(value ~= nil, "attribute is not of expected form")
		values[#values + 1] = value
	end
	assert(#values == n, "attribute is not of expected form")
	-- `unpack` is a global only up to Lua 5.1/LuaJIT; 5.2+ moved it to table.
	return (table.unpack or unpack)(values, 1, n)
end
--- Map a node's graphviz 'pos' attribute into the unit square.
-- @param node cgraph node pointer (cdata)
-- @param bbox table `{x0, y0, w, h}`: bounding-box origin and size
-- @return x, y coordinates relative to the bounding box, each asserted to
--   lie in [0, 1]
local function getRelativePosition(node, bbox)
	local originX, originY = bbox[1], bbox[2]
	local width, height = bbox[3], bbox[4]
	local px, py = extractNumbers(2, getAttribute(node, 'pos'))
	local relX, relY = (px - originX) / width, (py - originY) / height
	assert(relX >= 0 and relX <= 1, "bad x coordinate")
	assert(relY >= 0 and relY <= 1, "bad y coordinate")
	return relX, relY
end
--- Retrieve a node's numeric ID from its graphviz 'label' attribute.
-- Two label forms are accepted: a label starting with "Node<digits>", or a
-- label containing "(<digits>)" immediately followed by the two characters
-- `\` and `n` (a literal backslash-n escape in the dot label text).
-- @param node cgraph node pointer (cdata)
-- @treturn number the parsed ID
local function getID(node)
	local label = getAttribute(node, 'label')
	-- string.match returns nil when there is no match, so `or` really falls
	-- back to the second form (a table constructor would always be truthy).
	local id = string.match(label, "^Node(%d+)")
		or string.match(label, "%((%d+)%)\\n")
	assert(id ~= nil, "could not get ID from node label : <" .. tostring(label) .. ">")
	return tonumber(id)
end

--[[ Lay out a graph and return the positions of the nodes.

Args:
* `g` - graph to lay out.
* `algorithm` - name of the graphviz layout engine to use. (default: "dot")

Returns:
* `torch.Tensor(n, 2)` containing the resulting positions of the nodes.
where `n` is the number of nodes in the graph. Row `i` holds the position of
the node whose label carries ID `i`.

Coordinates are in the interval [0, 1].

Raises an error if the graphviz libraries are unavailable, the dot text cannot
be parsed, layout/render fails, or a node label/position is malformed. All
graphviz resources are released before the error propagates.
]]
function graph.graphvizLayout(g, algorithm)
	if not graphvizOk or not cgraphOk then
		error("graphviz library could not be loaded.")
	end
	algorithm = algorithm or "dot"
	local nNodes = #g.nodes
	local context = graphviz.gvContext()
	local graphvizGraph = cgraph.agmemread(g:todot())
	if graphvizGraph == nil then
		graphviz.gvFreeContext(context)
		error("graphviz could not parse the graph description")
	end

	-- Everything that may raise runs under pcall so the cleanup below always happens.
	local ok, result = pcall(function()
		assert(0 == graphviz.gvLayout(context, graphvizGraph, algorithm),
		       "graphviz layout failed")
		-- Render to the "dot" format with no output file: this attaches the
		-- layout results ('bb', 'pos') as attributes. The layout engine name
		-- (e.g. "neato") is not a render format, so it must not be passed here.
		assert(0 == graphviz.gvRender(context, graphvizGraph, "dot", nil),
		       "graphviz render failed")

		-- Extract bounding box as {x0, y0, width, height}.
		local x0, y0, x1, y1 = extractNumbers(4, getAttribute(graphvizGraph, 'bb'))
		local bbox = { x0, y0, x1 - x0, y1 - y0 }

		-- Extract node positions.
		local positions = torch.zeros(nNodes, 2)
		for node in nodeIterator(graphvizGraph) do
			local id = getID(node)
			assert(id >= 1 and id <= nNodes, "node ID out of range: " .. tostring(id))
			local x, y = getRelativePosition(node, bbox)
			positions[id][1] = x
			positions[id][2] = y
		end
		return positions
	end)

	-- Clean up (gvFreeLayout is a no-op when no layout was attached).
	graphviz.gvFreeLayout(context, graphvizGraph)
	cgraph.agclose(graphvizGraph)
	graphviz.gvFreeContext(context)
	if not ok then
		error(result, 0)
	end
	return result
end

--[[ Lay out a graph with graphviz and write the rendering to a file.

Args:
* `g` - graph to render (rendered from `g:todot()`).
* `algorithm` - graphviz layout engine to use. (default: "dot")
* `fname` - output path; its extension selects the graphviz output format,
  e.g. "out.svg" renders SVG and "out.dot" writes dot.

Raises an error if the graphviz libraries are unavailable, `fname` has no
extension, the file cannot be opened, or layout/render fails. The output file
and all graphviz resources are released before the error propagates.
]]
function graph.graphvizFile(g, algorithm, fname)
	if not graphvizOk or not cgraphOk then
		error("graphviz library could not be loaded.")
	end
	algorithm = algorithm or 'dot'
	assert(type(fname) == 'string', "graphvizFile: fname must be a string")
	local rendertype = fname:match('%.(%w+)$')
	assert(rendertype ~= nil,
	       "could not determine output format from file name: " .. fname)

	local context = graphviz.gvContext()
	local graphvizGraph = cgraph.agmemread(g:todot())
	if graphvizGraph == nil then
		graphviz.gvFreeContext(context)
		error("graphviz could not parse the graph description")
	end

	local outFile
	local ok, err = pcall(function()
		assert(0 == graphviz.gvLayout(context, graphvizGraph, algorithm),
		       "graphviz layout failed")
		-- Binary mode so formats like png are not mangled on Windows.
		local openErr
		outFile, openErr = io.open(fname, 'wb')
		assert(outFile, openErr)
		-- LuaJIT passes an io.* file handle to a FILE* parameter as its C FILE*.
		assert(0 == graphviz.gvRender(context, graphvizGraph, rendertype, outFile),
		       "graphviz render failed")
	end)

	-- Closing the Lua handle flushes what graphviz wrote through the FILE*.
	if outFile then outFile:close() end
	graphviz.gvFreeLayout(context, graphvizGraph)
	cgraph.agclose(graphvizGraph)
	graphviz.gvFreeContext(context)
	if not ok then
		error(err, 0)
	end
end

--[[
Given a graph, dump it as SVG and dot files or display it using graphviz.

Args:
* `g` - graph to display
* `title` - Title to display in the graph
   (NOTE(review): this function currently never reads `title`; it is not
   forwarded to the renderer)
* `fname` - [optional] if given it should contain a file name without an extension;
   the graph is saved on disk as fname.svg and fname.dot and no display is shown.
   If not given the graph is shown on a qt display (you need to have qtsvg
   installed and be running qlua)

Returns:
* `qs` - the window handle for the qt display when `fname` is not given;
   nothing otherwise
]]
function graph.dot(g,title,fname)
	local qt_display = fname == nil
	local tmpBase
	if qt_display then
		-- os.tmpname() also creates the file it names on POSIX systems; keep
		-- it so it can be removed together with the derived .svg file.
		tmpBase = os.tmpname()
		fname = tmpBase
	end
	local fnsvg = fname .. '.svg'
	graph.graphvizFile(g, 'dot', fnsvg)
	if not qt_display then
		graph.graphvizFile(g, 'dot', fname .. '.dot')
		return
	end
	require 'qtsvg'
	-- QSvgWidget loads the file in its constructor, so the temporary files
	-- can be removed right after the widget is created.
	local qs = qt.QSvgWidget(fnsvg)
	qs:show()
	os.remove(fnsvg)
	os.remove(tmpBase)
	return qs
end