From dc3943a1cabb835d0b5103f4f3edda5599a93974 Mon Sep 17 00:00:00 2001 From: Ronan Collobert Date: Sun, 9 Nov 2014 20:44:12 -0800 Subject: more robust/predictible tree generation --- graph.lua | 242 +++++++++++++++++++++++++++++++++++++------------------------- init.lua | 23 +++--- utils.lua | 8 +++ 3 files changed, 162 insertions(+), 111 deletions(-) diff --git a/graph.lua b/graph.lua index a8f1371..5e4d9bd 100644 --- a/graph.lua +++ b/graph.lua @@ -1,4 +1,5 @@ local usage = require 'argcheck.usage' +local utils = require 'argcheck.utils' local function argname2idx(rules, name) for idx, rule in ipairs(rules) do @@ -19,9 +20,50 @@ local function func2id(func) return tostring(func):match('0x([^%s]+)') end +local function rules2maskedrules(rules, rulesmask, rulestype) + local maskedrules = {} + for ridx=1,#rulesmask do + local rule = utils.duptable(rules[ridx]) + rule.__ridx = ridx + if rulestype == 'O' then + rule.name = nil + end + + local rulemask = rulesmask:sub(ridx,ridx) + if rulemask == '1' then + table.insert(maskedrules, rule) + elseif rulemask == '2' then + elseif rulemask == '3' then + rule.type = 'nil' + rule.check = nil + table.insert(maskedrules, rule) + end + end + return maskedrules +end + +local function rules2defaultrules(rules, rulesmask, rulestype) + local defaultrules = {} + for ridx=1,#rulesmask do + local rule = utils.duptable(rules[ridx]) + rule.__ridx = ridx + if rulestype == 'O' then + rule.name = nil + end + + local rulemask = rulesmask:sub(ridx,ridx) + if rulemask == '1' then + elseif rulemask == '2' then + table.insert(defaultrules, rule) + elseif rulemask == '3' then + end + end + return defaultrules +end + local ACN = {} -function ACN.new(typename, name, check, rules, rulemask) +function ACN.new(typename, name, check, rules, rulesmask, rulestype) assert(typename) local self = {} setmetatable(self, {__index=ACN}) @@ -29,7 +71,8 @@ function ACN.new(typename, name, check, rules, rulemask) self.name = name self.check = check self.rules = rules - self.rulemask = rulemask + self.rulesmask = rulesmask + self.rulestype = rulestype self.next = {} self.n = 0 return self @@ -40,80 +83,70 @@ function ACN:add(node) self.n = self.n + 1 end -function ACN:match(rules, rulemask, named) +function ACN:match(rules) local head = self - local nmatched = 0 - for _,idx in ipairs(rulemask) do - local isnil - if idx < 0 then - idx = -idx - isnil = true - end + for idx=1,#rules do local rule = rules[idx] local matched = false for n=1,head.n do - if head.next[n].type == (isnil and 'nil' or rule.type) - and head.next[n].check == rule.check - and (not named or (named and head.next[n].name == rule.name)) then + if head.next[n].type == rule.type + and head.next[n].check == rule.check + and head.next[n].name == rule.name then head = head.next[n] - nmatched = nmatched + 1 matched = true break end end if not matched then - break + return head, idx-1 end end - return head, nmatched + return head, #rules end -function ACN:addpath(rules, rulemask, named) - - -- check the corner case where one has named - -- and ordered arguments, and ordered - -- can take a single table - if not rules.force and named then -- named - local noordered = true -- do we have ordered? - self:apply(function(rules) - if not rules.noordered then - noordered = false - end - end) - - if not noordered then -- if yes then beware - for n=1,self.n do - if self.next[n].type == 'table' - and not self.next[n].check - and self.next[n].rules then - error('argcheck rules led to ambiguous situations') - end - end - end +function ACN:hasruletype(ruletype) + local hasruletype + self:apply(function(self) + if self.rulestype == ruletype then + hasruletype = true + end + end) + return hasruletype +end + +function ACN:addpath(rules, rulesmask, rulestype) -- 'O', 'N', 'M' + -- DEBUG: on peut aussi imaginer avoir d'abord mis + -- les no-named, et ensuite les named!! + + assert(rules) + assert(rulesmask) + assert(rulestype) + + local maskedrules = rules2maskedrules(rules, rulesmask, rulestype) + + if rulestype == 'N' then + table.insert(maskedrules, 1, {type='table'}) end - local head, n = self:match(rules, rulemask, named) - if n == #rulemask then + local head, idx = self:match(maskedrules) + + if idx == #maskedrules then -- check we are not overwriting something here if not rules.force and head.rules and rules ~= head.rules then error('argcheck rules led to ambiguous situations') end head.rules = rules - head.rulemask = rulemask + head.rulesmask = rulesmask + head.rulestype = rulestype end - for n=n+1,#rulemask do - local idx = rulemask[n] - local isnil - if idx < 0 then - idx = -idx - isnil = true - end - local rule = rules[idx] - local node = ACN.new(isnil and 'nil' or rule.type, - named and rule.name or nil, - (not isnil) and rule.check or nil, -- nil -> no check at all (beware: not...) - n == #rulemask and rules or nil, - n == #rulemask and rulemask or nil) + for idx=idx+1,#maskedrules do + local rule = maskedrules[idx] + local node = ACN.new(rule.type, + rule.name, + rule.check, + idx == #maskedrules and rules or nil, + idx == #maskedrules and rulesmask or nil, + idx == #maskedrules and rulestype or nil) head:add(node) head = node end @@ -149,11 +182,15 @@ function ACN:print(txt) end end -function ACN:generate_ordered_or_named(code, upvalues, named, depth) +function ACN:generate_ordered_or_named(code, upvalues, rulestype, depth) depth = depth or 0 + if not self:hasruletype(rulestype) then + return + end + if depth == 0 then - if named then + if rulestype == 'N' then table.insert(code, ' if narg == 1 and istype(select(1, ...), "table") then') table.insert(code, ' local args = select(1, ...)') table.insert(code, ' local narg = 0') @@ -162,7 +199,7 @@ function ACN:generate_ordered_or_named(code, upvalues, named, depth) table.insert(code, ' end') end else - local argname = named and string.format('args.%s', self.name) or string.format('select(%d, ...)', depth) + local argname = rulestype == 'N' and string.format('args.%s', self.name) or string.format('select(%d, ...)', depth) if self.check then upvalues[string.format('check%s', func2id(self.check))] = self.check end @@ -174,45 +211,52 @@ function ACN:generate_ordered_or_named(code, upvalues, named, depth) self.check and string.format(' and check%s(%s)', func2id(self.check), argname) or '')) end - if self.rules then + if self.rules and self.rulestype == rulestype then local rules = self.rules local id = table2id(rules) table.insert(code, string.format(' %sif narg == %d then', string.rep(' ', depth), depth)) + + -- func args local argcode = {} - local defacode = {} for ridx, rule in ipairs(rules) do if rules.pack then table.insert(argcode, string.format('%s=arg%d', rule.name, ridx)) else table.insert(argcode, string.format('arg%d', ridx)) end + end - local argidx - for i=1,#self.rulemask do -- DEBUG: bourrin - if ridx == math.abs(self.rulemask[i]) then - argidx = i - break - end - end - if argidx then - table.insert(code, string.format(' %slocal arg%d = %s', - string.rep(' ', depth), - ridx, - named and string.format('args.%s', rules[ridx].name) or string.format('select(%d, ...)', argidx))) - else - if rule.default ~= nil then - table.insert(code, string.format(' %slocal arg%d = arg%s_%dd', string.rep(' ', depth), ridx, id, ridx)) - upvalues[string.format('arg%s_%dd', id, ridx)] = rule.default - elseif rule.defaultf then - table.insert(code, string.format(' %slocal arg%d = arg%s_%df()', string.rep(' ', depth), ridx, id, ridx)) - upvalues[string.format('arg%s_%df', id, ridx)] = rule.defaultf - elseif rule.opt then - table.insert(code, string.format(' %slocal arg%d', string.rep(' ', depth), ridx)) - elseif rule.defaulta then - table.insert(defacode, string.format(' %slocal arg%d = arg%d', string.rep(' ', depth), ridx, argname2idx(rules, rule.defaulta))) - end + -- passed arguments + local maskedrules = rules2maskedrules(rules, self.rulesmask, self.rulestype) + for argidx, rule in ipairs(maskedrules) do + + local argname = rulestype == 'N' + and string.format('args.%s', rule.name) + or string.format('select(%d, ...)', argidx) + + table.insert(code, string.format(' %slocal arg%d = %s', + string.rep(' ', depth), + rule.__ridx, + argname)) + end + + -- default arguments + local defaultrules = rules2defaultrules(rules, self.rulesmask, self.rulestype) + local defacode = {} + for _, rule in ipairs(defaultrules) do + if rule.default ~= nil then + table.insert(code, string.format(' %slocal arg%d = arg%s_%dd', string.rep(' ', depth), rule.__ridx, id, rule.__ridx)) + upvalues[string.format('arg%s_%dd', id, rule.__ridx)] = rule.default + elseif rule.defaultf then + table.insert(code, string.format(' %slocal arg%d = arg%s_%df()', string.rep(' ', depth), rule.__ridx, id, rule.__ridx)) + upvalues[string.format('arg%s_%df', id, rule.__ridx)] = rule.defaultf + elseif rule.opt then + table.insert(code, string.format(' %slocal arg%d', string.rep(' ', depth), rule.__ridx)) + elseif rule.defaulta then + table.insert(defacode, string.format(' %slocal arg%d = arg%d', string.rep(' ', depth), rule.__ridx, argname2idx(rules, rule.defaulta))) end end + if #defacode > 0 then table.insert(code, table.concat(defacode, '\n')) end @@ -237,14 +281,11 @@ function ACN:generate_ordered_or_named(code, upvalues, named, depth) end for i=1,self.n do - if (named and self.next[i].name) - or (not named and not self.next[i].name) then - self.next[i]:generate_ordered_or_named(code, upvalues, named, depth+1) - end + self.next[i]:generate_ordered_or_named(code, upvalues, rulestype, depth+1) end if depth == 0 then - if named then + if rulestype == 'N' then table.insert(code, ' end') end else @@ -254,9 +295,7 @@ function ACN:generate_ordered_or_named(code, upvalues, named, depth) end function ACN:apply(func) - if self.rules then - func(self.rules) - end + func(self) for i=1,self.n do self.next[i]:apply(func) end @@ -266,10 +305,10 @@ function ACN:usage() local txt = {} local history = {} self:apply( - function(rules) - if not history[rules] then - history[rules] = true - table.insert(txt, usage(rules)) + function(self) + if self.rules and not history[self.rules] then + history[self.rules] = true + table.insert(txt, usage(self.rules)) end end) return table.concat(txt, '\n\nor\n\n') @@ -280,8 +319,13 @@ function ACN:generate(upvalues) local code = {} table.insert(code, 'return function(...)') table.insert(code, ' local narg = select("#", ...)') - self:generate_ordered_or_named(code, upvalues, false) - self:generate_ordered_or_named(code, upvalues, true) + self:generate_ordered_or_named(code, upvalues, 'O') + + local selfnamed = self:match({{type='table'}}) + if selfnamed ~= self then -- is there any named? + selfnamed:generate_ordered_or_named(code, upvalues, 'N') + end + for upvaluename, upvalue in pairs(upvalues) do table.insert(code, 1, string.format('local %s', upvaluename)) end @@ -290,8 +334,8 @@ function ACN:generate(upvalues) local quiet = true self:apply( - function(rules) - if not rules.quiet then + function(self) + if self.rules and not self.rules.quiet then quiet = false end end diff --git a/init.lua b/init.lua index 73d100d..b6fe006 100644 --- a/init.lua +++ b/init.lua @@ -8,7 +8,6 @@ local getupvalue = utils.getupvalue local loadstring = loadstring or load local function generaterules(rules) - local graph if rules.chain or rules.overload then local status @@ -39,27 +38,27 @@ local function generaterules(rules) nvariant = nvariant * optperrule[ridx] end + -- note: we keep the original rules (id) for all path variants + -- hence, the mask. for variant=1,nvariant do local r = variant - local rulemask = {} + local rulemask = {} -- 1/2/3 means present/not present/opt for ridx=1,#rules do - local f = math.floor((r-1)/optperrulestride[ridx]) + 1 - if f == 1 then -- here - table.insert(rulemask, ridx) - elseif f == 2 then -- not here - elseif f == 3 then -- opt - table.insert(rulemask, -ridx) - end + table.insert(rulemask, math.floor((r-1)/optperrulestride[ridx]) + 1) r = (r-1) % optperrulestride[ridx] + 1 - local rule = rules[ridx] end + rulemask = table.concat(rulemask) if not rules.noordered then - graph:addpath(rules, rulemask) + graph:addpath(rules, rulemask, 'O') end if not rules.nonamed then - graph:addpath(rules, rulemask, true) + if rules[1].name == 'self' then + graph:addpath(rules, rulemask, 'M') + else + graph:addpath(rules, rulemask, 'N') + end end end diff --git a/utils.lua b/utils.lua index 2a7c2fd..e0a288c 100644 --- a/utils.lua +++ b/utils.lua @@ -29,4 +29,12 @@ function utils.getupvalue(func, name, quiet) end end +function utils.duptable(tbl) + local dup = {} + for k,v in pairs(tbl) do + dup[k] = v + end + return dup +end + return utils -- cgit v1.2.3