Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/argcheck.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2014-11-10 07:44:12 +0300
committerRonan Collobert <ronan@collobert.com>2014-11-10 07:44:12 +0300
commitdc3943a1cabb835d0b5103f4f3edda5599a93974 (patch)
treea1543d597520db4ed6ce6ea91be501f4a20eba83
parentb137e9f58c54666e9a8db68b63a0bd65a9f3f881 (diff)
more robust/predictible tree generation
-rw-r--r--graph.lua242
-rw-r--r--init.lua23
-rw-r--r--utils.lua8
3 files changed, 162 insertions, 111 deletions
diff --git a/graph.lua b/graph.lua
index a8f1371..5e4d9bd 100644
--- a/graph.lua
+++ b/graph.lua
@@ -1,4 +1,5 @@
local usage = require 'argcheck.usage'
+local utils = require 'argcheck.utils'
local function argname2idx(rules, name)
for idx, rule in ipairs(rules) do
@@ -19,9 +20,50 @@ local function func2id(func)
return tostring(func):match('0x([^%s]+)')
end
+local function rules2maskedrules(rules, rulesmask, rulestype)
+ local maskedrules = {}
+ for ridx=1,#rulesmask do
+ local rule = utils.duptable(rules[ridx])
+ rule.__ridx = ridx
+ if rulestype == 'O' then
+ rule.name = nil
+ end
+
+ local rulemask = rulesmask:sub(ridx,ridx)
+ if rulemask == '1' then
+ table.insert(maskedrules, rule)
+ elseif rulemask == '2' then
+ elseif rulemask == '3' then
+ rule.type = 'nil'
+ rule.check = nil
+ table.insert(maskedrules, rule)
+ end
+ end
+ return maskedrules
+end
+
+local function rules2defaultrules(rules, rulesmask, rulestype)
+ local defaultrules = {}
+ for ridx=1,#rulesmask do
+ local rule = utils.duptable(rules[ridx])
+ rule.__ridx = ridx
+ if rulestype == 'O' then
+ rule.name = nil
+ end
+
+ local rulemask = rulesmask:sub(ridx,ridx)
+ if rulemask == '1' then
+ elseif rulemask == '2' then
+ table.insert(defaultrules, rule)
+ elseif rulemask == '3' then
+ end
+ end
+ return defaultrules
+end
+
local ACN = {}
-function ACN.new(typename, name, check, rules, rulemask)
+function ACN.new(typename, name, check, rules, rulesmask, rulestype)
assert(typename)
local self = {}
setmetatable(self, {__index=ACN})
@@ -29,7 +71,8 @@ function ACN.new(typename, name, check, rules, rulemask)
self.name = name
self.check = check
self.rules = rules
- self.rulemask = rulemask
+ self.rulesmask = rulesmask
+ self.rulestype = rulestype
self.next = {}
self.n = 0
return self
@@ -40,80 +83,70 @@ function ACN:add(node)
self.n = self.n + 1
end
-function ACN:match(rules, rulemask, named)
+function ACN:match(rules)
local head = self
- local nmatched = 0
- for _,idx in ipairs(rulemask) do
- local isnil
- if idx < 0 then
- idx = -idx
- isnil = true
- end
+ for idx=1,#rules do
local rule = rules[idx]
local matched = false
for n=1,head.n do
- if head.next[n].type == (isnil and 'nil' or rule.type)
- and head.next[n].check == rule.check
- and (not named or (named and head.next[n].name == rule.name)) then
+ if head.next[n].type == rule.type
+ and head.next[n].check == rule.check
+ and head.next[n].name == rule.name then
head = head.next[n]
- nmatched = nmatched + 1
matched = true
break
end
end
if not matched then
- break
+ return head, idx-1
end
end
- return head, nmatched
+ return head, #rules
end
-function ACN:addpath(rules, rulemask, named)
-
- -- check the corner case where one has named
- -- and ordered arguments, and ordered
- -- can take a single table
- if not rules.force and named then -- named
- local noordered = true -- do we have ordered?
- self:apply(function(rules)
- if not rules.noordered then
- noordered = false
- end
- end)
-
- if not noordered then -- if yes then beware
- for n=1,self.n do
- if self.next[n].type == 'table'
- and not self.next[n].check
- and self.next[n].rules then
- error('argcheck rules led to ambiguous situations')
- end
- end
- end
+function ACN:hasruletype(ruletype)
+ local hasruletype
+ self:apply(function(self)
+ if self.rulestype == ruletype then
+ hasruletype = true
+ end
+ end)
+ return hasruletype
+end
+
+function ACN:addpath(rules, rulesmask, rulestype) -- 'O', 'N', 'M'
+ -- DEBUG: on peut aussi imaginer avoir d'abord mis
+ -- les no-named, et ensuite les named!!
+
+ assert(rules)
+ assert(rulesmask)
+ assert(rulestype)
+
+ local maskedrules = rules2maskedrules(rules, rulesmask, rulestype)
+
+ if rulestype == 'N' then
+ table.insert(maskedrules, 1, {type='table'})
end
- local head, n = self:match(rules, rulemask, named)
- if n == #rulemask then
+ local head, idx = self:match(maskedrules)
+
+ if idx == #maskedrules then
-- check we are not overwriting something here
if not rules.force and head.rules and rules ~= head.rules then
error('argcheck rules led to ambiguous situations')
end
head.rules = rules
- head.rulemask = rulemask
+ head.rulesmask = rulesmask
+ head.rulestype = rulestype
end
- for n=n+1,#rulemask do
- local idx = rulemask[n]
- local isnil
- if idx < 0 then
- idx = -idx
- isnil = true
- end
- local rule = rules[idx]
- local node = ACN.new(isnil and 'nil' or rule.type,
- named and rule.name or nil,
- (not isnil) and rule.check or nil, -- nil -> no check at all (beware: not...)
- n == #rulemask and rules or nil,
- n == #rulemask and rulemask or nil)
+ for idx=idx+1,#maskedrules do
+ local rule = maskedrules[idx]
+ local node = ACN.new(rule.type,
+ rule.name,
+ rule.check,
+ idx == #maskedrules and rules or nil,
+ idx == #maskedrules and rulesmask or nil,
+ idx == #maskedrules and rulestype or nil)
head:add(node)
head = node
end
@@ -149,11 +182,15 @@ function ACN:print(txt)
end
end
-function ACN:generate_ordered_or_named(code, upvalues, named, depth)
+function ACN:generate_ordered_or_named(code, upvalues, rulestype, depth)
depth = depth or 0
+ if not self:hasruletype(rulestype) then
+ return
+ end
+
if depth == 0 then
- if named then
+ if rulestype == 'N' then
table.insert(code, ' if narg == 1 and istype(select(1, ...), "table") then')
table.insert(code, ' local args = select(1, ...)')
table.insert(code, ' local narg = 0')
@@ -162,7 +199,7 @@ function ACN:generate_ordered_or_named(code, upvalues, named, depth)
table.insert(code, ' end')
end
else
- local argname = named and string.format('args.%s', self.name) or string.format('select(%d, ...)', depth)
+ local argname = rulestype == 'N' and string.format('args.%s', self.name) or string.format('select(%d, ...)', depth)
if self.check then
upvalues[string.format('check%s', func2id(self.check))] = self.check
end
@@ -174,45 +211,52 @@ function ACN:generate_ordered_or_named(code, upvalues, named, depth)
self.check and string.format(' and check%s(%s)', func2id(self.check), argname) or ''))
end
- if self.rules then
+ if self.rules and self.rulestype == rulestype then
local rules = self.rules
local id = table2id(rules)
table.insert(code, string.format(' %sif narg == %d then', string.rep(' ', depth), depth))
+
+ -- func args
local argcode = {}
- local defacode = {}
for ridx, rule in ipairs(rules) do
if rules.pack then
table.insert(argcode, string.format('%s=arg%d', rule.name, ridx))
else
table.insert(argcode, string.format('arg%d', ridx))
end
+ end
- local argidx
- for i=1,#self.rulemask do -- DEBUG: bourrin
- if ridx == math.abs(self.rulemask[i]) then
- argidx = i
- break
- end
- end
- if argidx then
- table.insert(code, string.format(' %slocal arg%d = %s',
- string.rep(' ', depth),
- ridx,
- named and string.format('args.%s', rules[ridx].name) or string.format('select(%d, ...)', argidx)))
- else
- if rule.default ~= nil then
- table.insert(code, string.format(' %slocal arg%d = arg%s_%dd', string.rep(' ', depth), ridx, id, ridx))
- upvalues[string.format('arg%s_%dd', id, ridx)] = rule.default
- elseif rule.defaultf then
- table.insert(code, string.format(' %slocal arg%d = arg%s_%df()', string.rep(' ', depth), ridx, id, ridx))
- upvalues[string.format('arg%s_%df', id, ridx)] = rule.defaultf
- elseif rule.opt then
- table.insert(code, string.format(' %slocal arg%d', string.rep(' ', depth), ridx))
- elseif rule.defaulta then
- table.insert(defacode, string.format(' %slocal arg%d = arg%d', string.rep(' ', depth), ridx, argname2idx(rules, rule.defaulta)))
- end
+ -- passed arguments
+ local maskedrules = rules2maskedrules(rules, self.rulesmask, self.rulestype)
+ for argidx, rule in ipairs(maskedrules) do
+
+ local argname = rulestype == 'N'
+ and string.format('args.%s', rule.name)
+ or string.format('select(%d, ...)', argidx)
+
+ table.insert(code, string.format(' %slocal arg%d = %s',
+ string.rep(' ', depth),
+ rule.__ridx,
+ argname))
+ end
+
+ -- default arguments
+ local defaultrules = rules2defaultrules(rules, self.rulesmask, self.rulestype)
+ local defacode = {}
+ for _, rule in ipairs(defaultrules) do
+ if rule.default ~= nil then
+ table.insert(code, string.format(' %slocal arg%d = arg%s_%dd', string.rep(' ', depth), rule.__ridx, id, rule.__ridx))
+ upvalues[string.format('arg%s_%dd', id, rule.__ridx)] = rule.default
+ elseif rule.defaultf then
+ table.insert(code, string.format(' %slocal arg%d = arg%s_%df()', string.rep(' ', depth), rule.__ridx, id, rule.__ridx))
+ upvalues[string.format('arg%s_%df', id, rule.__ridx)] = rule.defaultf
+ elseif rule.opt then
+ table.insert(code, string.format(' %slocal arg%d', string.rep(' ', depth), rule.__ridx))
+ elseif rule.defaulta then
+ table.insert(defacode, string.format(' %slocal arg%d = arg%d', string.rep(' ', depth), rule.__ridx, argname2idx(rules, rule.defaulta)))
end
end
+
if #defacode > 0 then
table.insert(code, table.concat(defacode, '\n'))
end
@@ -237,14 +281,11 @@ function ACN:generate_ordered_or_named(code, upvalues, named, depth)
end
for i=1,self.n do
- if (named and self.next[i].name)
- or (not named and not self.next[i].name) then
- self.next[i]:generate_ordered_or_named(code, upvalues, named, depth+1)
- end
+ self.next[i]:generate_ordered_or_named(code, upvalues, rulestype, depth+1)
end
if depth == 0 then
- if named then
+ if rulestype == 'N' then
table.insert(code, ' end')
end
else
@@ -254,9 +295,7 @@ function ACN:generate_ordered_or_named(code, upvalues, named, depth)
end
function ACN:apply(func)
- if self.rules then
- func(self.rules)
- end
+ func(self)
for i=1,self.n do
self.next[i]:apply(func)
end
@@ -266,10 +305,10 @@ function ACN:usage()
local txt = {}
local history = {}
self:apply(
- function(rules)
- if not history[rules] then
- history[rules] = true
- table.insert(txt, usage(rules))
+ function(self)
+ if self.rules and not history[self.rules] then
+ history[self.rules] = true
+ table.insert(txt, usage(self.rules))
end
end)
return table.concat(txt, '\n\nor\n\n')
@@ -280,8 +319,13 @@ function ACN:generate(upvalues)
local code = {}
table.insert(code, 'return function(...)')
table.insert(code, ' local narg = select("#", ...)')
- self:generate_ordered_or_named(code, upvalues, false)
- self:generate_ordered_or_named(code, upvalues, true)
+ self:generate_ordered_or_named(code, upvalues, 'O')
+
+ local selfnamed = self:match({{type='table'}})
+ if selfnamed ~= self then -- is there any named?
+ selfnamed:generate_ordered_or_named(code, upvalues, 'N')
+ end
+
for upvaluename, upvalue in pairs(upvalues) do
table.insert(code, 1, string.format('local %s', upvaluename))
end
@@ -290,8 +334,8 @@ function ACN:generate(upvalues)
local quiet = true
self:apply(
- function(rules)
- if not rules.quiet then
+ function(self)
+ if self.rules and not self.rules.quiet then
quiet = false
end
end
diff --git a/init.lua b/init.lua
index 73d100d..b6fe006 100644
--- a/init.lua
+++ b/init.lua
@@ -8,7 +8,6 @@ local getupvalue = utils.getupvalue
local loadstring = loadstring or load
local function generaterules(rules)
-
local graph
if rules.chain or rules.overload then
local status
@@ -39,27 +38,27 @@ local function generaterules(rules)
nvariant = nvariant * optperrule[ridx]
end
+ -- note: we keep the original rules (id) for all path variants
+ -- hence, the mask.
for variant=1,nvariant do
local r = variant
- local rulemask = {}
+ local rulemask = {} -- 1/2/3 means present/not present/opt
for ridx=1,#rules do
- local f = math.floor((r-1)/optperrulestride[ridx]) + 1
- if f == 1 then -- here
- table.insert(rulemask, ridx)
- elseif f == 2 then -- not here
- elseif f == 3 then -- opt
- table.insert(rulemask, -ridx)
- end
+ table.insert(rulemask, math.floor((r-1)/optperrulestride[ridx]) + 1)
r = (r-1) % optperrulestride[ridx] + 1
- local rule = rules[ridx]
end
+ rulemask = table.concat(rulemask)
if not rules.noordered then
- graph:addpath(rules, rulemask)
+ graph:addpath(rules, rulemask, 'O')
end
if not rules.nonamed then
- graph:addpath(rules, rulemask, true)
+ if rules[1].name == 'self' then
+ graph:addpath(rules, rulemask, 'M')
+ else
+ graph:addpath(rules, rulemask, 'N')
+ end
end
end
diff --git a/utils.lua b/utils.lua
index 2a7c2fd..e0a288c 100644
--- a/utils.lua
+++ b/utils.lua
@@ -29,4 +29,12 @@ function utils.getupvalue(func, name, quiet)
end
end
+function utils.duptable(tbl)
+ local dup = {}
+ for k,v in pairs(tbl) do
+ dup[k] = v
+ end
+ return dup
+end
+
return utils