Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/argcheck.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2014-04-11 19:07:15 +0400
committerRonan Collobert <locronan@fb.com>2014-11-07 05:38:42 +0300
commit35aaadf09878c868714ba0892fb1da9dedb07a33 (patch)
treec0a47f23ae9085aaba9044bc6e8d477a82966c1f
parentb30b0aa0cc3f14eb1855e191e04cb5551a7f79e9 (diff)
basic code for graph generation
-rw-r--r--init.lua125
1 files changed, 124 insertions, 1 deletions
diff --git a/init.lua b/init.lua
index be31a2c..8e50fc0 100644
--- a/init.lua
+++ b/init.lua
@@ -1,6 +1,116 @@
local env = require 'argcheck.env'
local utils = require 'argcheck.utils'
local doc = require 'argcheck.doc'
+local ffi = require 'ffi'
+
+ffi.cdef[[
+
+void free(void *ptr);
+void *malloc(size_t size);
+void *realloc(void *ptr, size_t size);
+
+typedef struct argcheck_node_ {
+ char *type;
+ int checkidx;
+ int outidx;
+ int n; /* # of next */
+ struct argcheck_node_ **next;
+} argcheck_node;
+
+]]
+
+local ACN = {}
+ACN.__index = ACN
+
+function ACN.new(typename, checkidx, outidx)
+ assert(typename)
+ local self = ffi.cast('argcheck_node*', ffi.C.malloc(ffi.sizeof('argcheck_node')))
+ self.type = ffi.cast('char*', ffi.C.malloc(#typename+1))
+ ffi.copy(self.type, typename, #typename)
+ self.type[#typename] = 0
+ self.checkidx = checkidx or 0
+ self.outidx = outidx or 0
+ self.next = nil
+ self.n = 0
+ return self
+end
+
+function ACN:add(node)
+ assert(node ~= nil)
+ if self.n == 0 then
+ self.next = ffi.cast('argcheck_node**', ffi.C.malloc(ffi.sizeof('argcheck_node*')))
+ else
+ self.next = ffi.cast('argcheck_node**', ffi.C.realloc(self.next, ffi.sizeof('argcheck_node*')*(self.n+1)))
+ end
+ self.next[self.n] = node
+ self.n = self.n + 1
+end
+
+function ACN:free()
+ for n = 0,self.n-1 do
+ self.next[n]:free()
+ end
+ if self.next ~= nil then
+ ffi.C.free(self.next)
+ end
+ ffi.C.free(self.type)
+ ffi.C.free(self)
+end
+
+function ACN:match(tbl)
+ local head = self
+ local nmatched = 0
+ for idx,arg in ipairs(tbl) do
+ local matched = false
+ for n=0,head.n-1 do
+ if ffi.string(head.next[n].type) == arg.type and head.next[n].checkidx == arg.checkidx then
+ head = head.next[n]
+ nmatched = nmatched + 1
+ matched = true
+ break
+ end
+ end
+ if not matched then
+ break
+ end
+ end
+ return head, nmatched
+end
+
+function ACN:addpath(tbl, outidx)
+ local head, n = self:match(tbl)
+ for n=n+1,#tbl do
+ local node = ACN.new(tbl[n].type, tbl[n].checkidx, n == #tbl and outidx or 0)
+ head:add(node)
+ head = node
+ end
+end
+
+function ACN:print(txt)
+ local isroot = not txt
+ txt = txt or {'digraph ACN {'}
+ table.insert(txt, string.format('id%d [label="%s%s" style=filled fillcolor=%s];',
+ tonumber(ffi.cast('intptr_t', self)),
+ ffi.string(self.type),
+ self.checkidx > 0 and string.format('+%d', self.checkidx) or '',
+ self.outidx > 0 and 'red' or 'blue'))
+
+ for n=0,self.n-1 do
+ local next = self.next[n]
+ next:print(txt) -- make sure its id is defined
+ table.insert(txt, string.format('id%d -> id%d;',
+ tonumber(ffi.cast('intptr_t', self)),
+ tonumber(ffi.cast('intptr_t', next))))
+ end
+
+ if isroot then
+ table.insert(txt, '}')
+ txt = table.concat(txt, '\n')
+ return txt
+ end
+end
+
+ffi.metatype('struct argcheck_node_', ACN)
local setupvalue = utils.setupvalue
local getupvalue = utils.getupvalue
@@ -157,12 +267,15 @@ local function generaterules(rules, named, hasordered)
indent = ' '
end
+ local root = ACN.new('ROOT')
+
for optmask=0,2^nopt-1 do
local ruletxt = {}
local assntxt = {}
local defatxt = {}
table.insert(ruletxt, string.format('%s%sif narg == %d', indent, optmask == 0 and '' or 'else', nrule-nopt+countbits(optmask)))
+ local acn_path = {}
local narg = nrule-nopt+countbits(optmask)
local ridx = 1
local aidx = 1
@@ -174,7 +287,7 @@ local function generaterules(rules, named, hasordered)
if rule.default ~= nil or rule.defaulta or rule.defaultf or rule.opt then
optidx = optidx + 1
if bit.band(2^(optidx-1), optmask) == 0 then
- if rule.defaulta then
+ if rule.defaulta then -- this is a special case (must be done after all other initializations)
table.insert(defatxt, string.format('%s arg%d = arg%d', indent, ridx, argname2idx(rules, rule.defaulta)))
end
skiprule = true
@@ -182,6 +295,7 @@ local function generaterules(rules, named, hasordered)
end
if not skiprule then
+ table.insert(acn_path, {type=rule.type or '', checkidx=rule.check and ridx or 0})
local checktxt
if rule.opt and rule.type then
checktxt = string.format('(istype(%s, "%s") or istype(%s, "nil"))', rule2arg(rule, aidx, named), rule.type, rule2arg(rule, aidx, named))
@@ -208,6 +322,7 @@ local function generaterules(rules, named, hasordered)
ridx = ridx + 1
end
table.insert(txt, table.concat(ruletxt, ' and ') .. ' then')
+ root:addpath(acn_path, 1)
if #assntxt > 0 then
table.insert(txt, table.concat(assntxt, '\n'))
end
@@ -215,7 +330,15 @@ local function generaterules(rules, named, hasordered)
table.insert(txt, table.concat(defatxt, '\n'))
end
end
+
+ local stuff = root:print()
+ f = io.open('zozo.dot', 'w')
+ f:write(stuff)
+ f:close()
+ print(stuff)
+
return table.concat(txt, '\n')
+
end
local function argcheck(rules)