Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/argcheck.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2014-11-10 09:54:12 +0300
committerRonan Collobert <ronan@collobert.com>2014-11-10 09:54:12 +0300
commitc2a1f10a270a0a2e738a033f5cbd3704891dedf7 (patch)
tree47cf398bbe09c4a6255db817e1efeab51e6ea82a
parentdc3943a1cabb835d0b5103f4f3edda5599a93974 (diff)
added support for named method calls self:func{args} (i.e. func(self, {args}))
-rw-r--r--graph.lua109
-rw-r--r--test/test.lua17
2 files changed, 102 insertions, 24 deletions
diff --git a/graph.lua b/graph.lua
index 5e4d9bd..3d23eda 100644
--- a/graph.lua
+++ b/graph.lua
@@ -27,6 +27,8 @@ local function rules2maskedrules(rules, rulesmask, rulestype)
rule.__ridx = ridx
if rulestype == 'O' then
rule.name = nil
+ elseif rulestype == 'M' and ridx == 1 then -- self?
+ rule.name = nil
end
local rulemask = rulesmask:sub(ridx,ridx)
@@ -49,6 +51,8 @@ local function rules2defaultrules(rules, rulesmask, rulestype)
rule.__ridx = ridx
if rulestype == 'O' then
rule.name = nil
+ elseif rulestype == 'M' and ridx == 1 then -- self?
+ rule.name = nil
end
local rulemask = rulesmask:sub(ridx,ridx)
@@ -128,6 +132,10 @@ function ACN:addpath(rules, rulesmask, rulestype) -- 'O', 'N', 'M'
table.insert(maskedrules, 1, {type='table'})
end
+ if rulestype == 'M' then
+ table.insert(maskedrules, 2, {type='table'})
+ end
+
local head, idx = self:match(maskedrules)
if idx == #maskedrules then
@@ -150,6 +158,14 @@ function ACN:addpath(rules, rulesmask, rulestype) -- 'O', 'N', 'M'
head:add(node)
head = node
end
+
+ -- special trick: mark self
+ if rulestype == 'M' then
+ local head, idx = self:match({maskedrules[1]}) -- find self
+ assert(idx == 1, 'internal bug, please report')
+ head.isself = true
+ end
+
end
function ACN:id()
@@ -160,9 +176,10 @@ function ACN:print(txt)
local isroot = not txt
txt = txt or {'digraph ACN {'}
table.insert(txt, 'edge [penwidth=.3 arrowsize=0.8];')
- table.insert(txt, string.format('id%s [label="%s%s%s" penwidth=.1 fontsize=10 style=filled fillcolor="%s"];',
+ table.insert(txt, string.format('id%s [label="%s%s%s%s" penwidth=.1 fontsize=10 style=filled fillcolor="%s"];',
self:id(),
self.type,
+ self.isself and '*' or '',
self.check and ' <check>' or '',
self.name and string.format(' (%s)', self.name) or '',
self.rules and '#aaaaaa' or '#eeeeee'))
@@ -185,21 +202,17 @@ end
function ACN:generate_ordered_or_named(code, upvalues, rulestype, depth)
depth = depth or 0
+ -- no need to go deeper if no rules found later
if not self:hasruletype(rulestype) then
return
end
- if depth == 0 then
- if rulestype == 'N' then
- table.insert(code, ' if narg == 1 and istype(select(1, ...), "table") then')
- table.insert(code, ' local args = select(1, ...)')
- table.insert(code, ' local narg = 0')
- table.insert(code, ' for k,v in pairs(args) do')
- table.insert(code, ' narg = narg + 1')
- table.insert(code, ' end')
- end
- else
- local argname = rulestype == 'N' and string.format('args.%s', self.name) or string.format('select(%d, ...)', depth)
+ if depth > 0 then
+ local argname =
+ (rulestype == 'N' or rulestype == 'M')
+ and string.format('args.%s', self.name)
+ or string.format('select(%d, ...)', depth)
+
if self.check then
upvalues[string.format('check%s', func2id(self.check))] = self.check
end
@@ -213,9 +226,17 @@ function ACN:generate_ordered_or_named(code, upvalues, rulestype, depth)
if self.rules and self.rulestype == rulestype then
local rules = self.rules
+ local rulesmask = self.rulesmask
local id = table2id(rules)
table.insert(code, string.format(' %sif narg == %d then', string.rep(' ', depth), depth))
+ -- 'M' case (method: first arg is self)
+ if rulestype == 'M' then
+ rules = utils.duptable(self.rules)
+ table.remove(rules, 1) -- remove self
+ rulesmask = rulesmask:sub(2)
+ end
+
-- func args
local argcode = {}
for ridx, rule in ipairs(rules) do
@@ -227,10 +248,11 @@ function ACN:generate_ordered_or_named(code, upvalues, rulestype, depth)
end
-- passed arguments
- local maskedrules = rules2maskedrules(rules, self.rulesmask, self.rulestype)
+ local maskedrules = rules2maskedrules(rules, rulesmask)-- no, rulestype)
for argidx, rule in ipairs(maskedrules) do
- local argname = rulestype == 'N'
+ local argname =
+ (rulestype == 'N' or rulestype == 'M')
and string.format('args.%s', rule.name)
or string.format('select(%d, ...)', argidx)
@@ -241,7 +263,7 @@ function ACN:generate_ordered_or_named(code, upvalues, rulestype, depth)
end
-- default arguments
- local defaultrules = rules2defaultrules(rules, self.rulesmask, self.rulestype)
+ local defaultrules = rules2defaultrules(rules, rulesmask)--no, rulestype)
local defacode = {}
for _, rule in ipairs(defaultrules) do
if rule.default ~= nil then
@@ -260,10 +282,21 @@ function ACN:generate_ordered_or_named(code, upvalues, rulestype, depth)
if #defacode > 0 then
table.insert(code, table.concat(defacode, '\n'))
end
- argcode = table.concat(argcode, ', ')
+
if rules.pack then
- argcode = string.format('{%s}', argcode)
+ argcode = table.concat(argcode, ', ')
+ if rulestype == 'M' then
+ argcode = string.format('self, {%s}', argcode)
+ else
+ argcode = string.format('{%s}', argcode)
+ end
+ else
+ if rulestype == 'M' then
+ table.insert(argcode, 1, 'self')
+ end
+ argcode = table.concat(argcode, ', ')
end
+
if rules.call and not rules.quiet then
argcode = string.format('call%s(%s)', id, argcode)
upvalues[string.format('call%s', id)] = rules.call
@@ -284,11 +317,7 @@ function ACN:generate_ordered_or_named(code, upvalues, rulestype, depth)
self.next[i]:generate_ordered_or_named(code, upvalues, rulestype, depth+1)
end
- if depth == 0 then
- if rulestype == 'N' then
- table.insert(code, ' end')
- end
- else
+ if depth > 0 then
table.insert(code, string.format('%send', string.rep(' ', depth)))
end
@@ -321,9 +350,41 @@ function ACN:generate(upvalues)
table.insert(code, ' local narg = select("#", ...)')
self:generate_ordered_or_named(code, upvalues, 'O')
- local selfnamed = self:match({{type='table'}})
- if selfnamed ~= self then -- is there any named?
+ if self:hasruletype('N') then -- is there any named?
+ local selfnamed = self:match({{type='table'}})
+ assert(selfnamed ~= self, 'internal bug, please report')
+ table.insert(code, ' if narg == 1 and istype(select(1, ...), "table") then')
+ table.insert(code, ' local args = select(1, ...)')
+ table.insert(code, ' local narg = 0')
+ table.insert(code, ' for k,v in pairs(args) do')
+ table.insert(code, ' narg = narg + 1')
+ table.insert(code, ' end')
selfnamed:generate_ordered_or_named(code, upvalues, 'N')
+ table.insert(code, ' end')
+ end
+
+ for _,head in ipairs(self.next) do
+ if head.isself then -- named self method
+ local selfnamed = head:match({{type='table'}})
+ assert(selfnamed ~= head, 'internal bug, please report')
+
+ if selfnamed.check then
+ upvalues[string.format('check%s', func2id(selfnamed.check))] = selfnamed.check
+ end
+ table.insert(code,
+ string.format(' if narg == 2 and istype(select(2, ...), "table") and istype(select(1, ...), "%s")%s then',
+ selfnamed.type,
+ selfnamed.check and string.format(' and check%s(select(1, ...))', func2id(self.check)) or '')
+ )
+ table.insert(code, ' local self = select(1, ...)')
+ table.insert(code, ' local args = select(2, ...)')
+ table.insert(code, ' local narg = 0')
+ table.insert(code, ' for k,v in pairs(args) do')
+ table.insert(code, ' narg = narg + 1')
+ table.insert(code, ' end')
+ selfnamed:generate_ordered_or_named(code, upvalues, 'M')
+ table.insert(code, ' end')
+ end
end
for upvaluename, upvalue in pairs(upvalues) do
diff --git a/test/test.lua b/test/test.lua
index 0c546eb..02fe3c9 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -241,4 +241,21 @@ addfive = argcheck{
assert(addfive(5, 'hello') == '5.000000 + 5 = 10.000000 [msg = hello]')
assert(addfive(5) == '5.000000 + 5 = 10.000000 [msg = i know what i am doing]')
+local foobar = {checksum=1234567}
+foobar.addfive = argcheck{
+ {name="self", type="table"},
+ {name="x", type="number"},
+ {name="msg", type="string", default="i know what i am doing"},
+ call =
+ function(self, x, msg) -- called in case of success
+ return string.format('%f + 5 = %f [msg = %s] [self.checksum=%s]', x, x+5, msg, self.checksum)
+ end
+}
+
+assert(foobar:addfive(5, 'paf') == '5.000000 + 5 = 10.000000 [msg = paf] [self.checksum=1234567]')
+assert(foobar:addfive{x=5, msg='paf'} == '5.000000 + 5 = 10.000000 [msg = paf] [self.checksum=1234567]')
+
+assert(foobar:addfive(5) == '5.000000 + 5 = 10.000000 [msg = i know what i am doing] [self.checksum=1234567]')
+assert(foobar:addfive{x=5} == '5.000000 + 5 = 10.000000 [msg = i know what i am doing] [self.checksum=1234567]')
+
print('PASSED')