Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cwrap.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2014-02-13 20:40:19 +0400
committerRonan Collobert <ronan@collobert.com>2014-02-14 13:00:20 +0400
commitb2b35a6ba9ce8245030510ffe4b8558d68d309fd (patch)
treec729eadedbc9351644ac74af1aa140c582f55a75
parent28c17710cde9fd9a22359f5e0c6fef47e12edd46 (diff)
repackaged wrap into a standalone cwrap module
-rw-r--r--cinterface.lua308
-rw-r--r--init.lua314
-rw-r--r--types.lua263
3 files changed, 320 insertions, 565 deletions
diff --git a/cinterface.lua b/cinterface.lua
new file mode 100644
index 0000000..d6c26bb
--- /dev/null
+++ b/cinterface.lua
@@ -0,0 +1,308 @@
+local CInterface = {}
+
+function CInterface.new()
+ self = {}
+ self.txt = {}
+ self.registry = {}
+ setmetatable(self, {__index=CInterface})
+ return self
+end
+
+function CInterface:luaname2wrapname(name)
+ return string.format("wrapper_%s", name)
+end
+
+function CInterface:print(str)
+ table.insert(self.txt, str)
+end
+
+function CInterface:wrap(luaname, ...)
+ local txt = self.txt
+ local varargs = {...}
+
+ assert(#varargs > 0 and #varargs % 2 == 0, 'must provide both the C function name and the corresponding arguments')
+
+ -- add function to the registry
+ table.insert(self.registry, {name=luaname, wrapname=self:luaname2wrapname(luaname)})
+
+ table.insert(txt, string.format("static int %s(lua_State *L)", self:luaname2wrapname(luaname)))
+ table.insert(txt, "{")
+ table.insert(txt, "int narg = lua_gettop(L);")
+
+ if #varargs == 2 then
+ local cfuncname = varargs[1]
+ local args = varargs[2]
+
+ local helpargs, cargs, argcreturned = self:__writeheaders(txt, args)
+ self:__writechecks(txt, args)
+
+ table.insert(txt, 'else')
+ table.insert(txt, string.format('luaL_error(L, "expected arguments: %s");', table.concat(helpargs, ' ')))
+
+ self:__writecall(txt, args, cfuncname, cargs, argcreturned)
+ else
+ local allcfuncname = {}
+ local allargs = {}
+ local allhelpargs = {}
+ local allcargs = {}
+ local allargcreturned = {}
+
+ table.insert(txt, "int argset = 0;")
+
+ for k=1,#varargs/2 do
+ allcfuncname[k] = varargs[(k-1)*2+1]
+ allargs[k] = varargs[(k-1)*2+2]
+ end
+
+ local argoffset = 0
+ for k=1,#varargs/2 do
+ allhelpargs[k], allcargs[k], allargcreturned[k] = self:__writeheaders(txt, allargs[k], argoffset)
+ argoffset = argoffset + #allargs[k]
+ end
+
+ for k=1,#varargs/2 do
+ self:__writechecks(txt, allargs[k], k)
+ end
+
+ table.insert(txt, 'else')
+ local allconcathelpargs = {}
+ for k=1,#varargs/2 do
+ table.insert(allconcathelpargs, table.concat(allhelpargs[k], ' '))
+ end
+ table.insert(txt, string.format('luaL_error(L, "expected arguments: %s");', table.concat(allconcathelpargs, ' | ')))
+
+ for k=1,#varargs/2 do
+ if k == 1 then
+ table.insert(txt, string.format('if(argset == %d)', k))
+ else
+ table.insert(txt, string.format('else if(argset == %d)', k))
+ end
+ table.insert(txt, '{')
+ self:__writecall(txt, allargs[k], allcfuncname[k], allcargs[k], allargcreturned[k])
+ table.insert(txt, '}')
+ end
+
+ table.insert(txt, 'return 0;')
+ end
+
+ table.insert(txt, '}')
+ table.insert(txt, '')
+end
+
+function CInterface:register(name)
+ local txt = self.txt
+ table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name))
+ for _,reg in ipairs(self.registry) do
+ table.insert(txt, string.format('{"%s", %s},', reg.name, reg.wrapname))
+ end
+ table.insert(txt, '{NULL, NULL}')
+ table.insert(txt, '};')
+ table.insert(txt, '')
+ self.registry = {}
+end
+
+function CInterface:clearhistory()
+ self.txt = {}
+ self.registry = {}
+end
+
+function CInterface:tostring()
+ return table.concat(self.txt, '\n')
+end
+
+function CInterface:tofile(filename)
+ local f = io.open(filename, 'w')
+ f:write(table.concat(self.txt, '\n'))
+ f:close()
+end
+
+local function bit(p)
+ return 2 ^ (p - 1) -- 1-based indexing
+end
+
+local function hasbit(x, p)
+ return x % (p + p) >= p
+end
+
+local function beautify(txt)
+ local indent = 0
+ for i=1,#txt do
+ if txt[i]:match('}') then
+ indent = indent - 2
+ end
+ if indent > 0 then
+ txt[i] = string.rep(' ', indent) .. txt[i]
+ end
+ if txt[i]:match('{') then
+ indent = indent + 2
+ end
+ end
+end
+
+local function tableinsertcheck(tbl, stuff)
+ if stuff and not stuff:match('^%s*$') then
+ table.insert(tbl, stuff)
+ end
+end
+
+function CInterface:__writeheaders(txt, args, argoffset)
+ local argtypes = self.argtypes
+ local helpargs = {}
+ local cargs = {}
+ local argcreturned
+ argoffset = argoffset or 0
+
+ for i,arg in ipairs(args) do
+ arg.i = i+argoffset
+ arg.args = args -- in case we want to do stuff depending on other args
+ assert(argtypes[arg.name], 'unknown type ' .. arg.name)
+ setmetatable(arg, {__index=argtypes[arg.name]})
+ arg.__metatable = argtypes[arg.name]
+ tableinsertcheck(txt, arg:declare())
+ local helpname = arg:helpname()
+ if arg.returned then
+ helpname = string.format('*%s*', helpname)
+ end
+ if arg.invisible and arg.default == nil then
+ error('Invisible arguments must have a default! How could I guess how to initialize it?')
+ end
+ if arg.default ~= nil then
+ if not arg.invisible then
+ table.insert(helpargs, string.format('[%s]', helpname))
+ end
+ elseif not arg.creturned then
+ table.insert(helpargs, helpname)
+ end
+ if arg.creturned then
+ if argcreturned then
+ error('A C function can only return one argument!')
+ end
+ if arg.default ~= nil then
+ error('Obviously, an "argument" returned by a C function cannot have a default value')
+ end
+ if arg.returned then
+ error('Options "returned" and "creturned" are incompatible')
+ end
+ argcreturned = arg
+ else
+ table.insert(cargs, arg:carg())
+ end
+ end
+
+ return helpargs, cargs, argcreturned
+end
+
+function CInterface:__writechecks(txt, args, argset)
+ local argtypes = self.argtypes
+
+ local multiargset = argset
+ argset = argset or 1
+
+ local nopt = 0
+ for i,arg in ipairs(args) do
+ if arg.default ~= nil and not arg.invisible then
+ nopt = nopt + 1
+ end
+ end
+
+ for variant=0,math.pow(2, nopt)-1 do
+ local opt = 0
+ local currentargs = {}
+ local optargs = {}
+ local hasvararg = false
+ for i,arg in ipairs(args) do
+ if arg.invisible then
+ table.insert(optargs, arg)
+ elseif arg.default ~= nil then
+ opt = opt + 1
+ if hasbit(variant, bit(opt)) then
+ table.insert(currentargs, arg)
+ else
+ table.insert(optargs, arg)
+ end
+ elseif not arg.creturned then
+ table.insert(currentargs, arg)
+ end
+ end
+
+ for _,arg in ipairs(args) do
+ if arg.vararg then
+ if hasvararg then
+ error('Only one argument can be a "vararg"!')
+ end
+ hasvararg = true
+ end
+ end
+
+ if hasvararg and not currentargs[#currentargs].vararg then
+ error('Only the last argument can be a "vararg"')
+ end
+
+ local compop
+ if hasvararg then
+ compop = '>='
+ else
+ compop = '=='
+ end
+
+ if variant == 0 and argset == 1 then
+ table.insert(txt, string.format('if(narg %s %d', compop, #currentargs))
+ else
+ table.insert(txt, string.format('else if(narg %s %d', compop, #currentargs))
+ end
+
+ for stackidx, arg in ipairs(currentargs) do
+ table.insert(txt, string.format("&& %s", arg:check(stackidx)))
+ end
+ table.insert(txt, ')')
+ table.insert(txt, '{')
+
+ if multiargset then
+ table.insert(txt, string.format('argset = %d;', argset))
+ end
+
+ for stackidx, arg in ipairs(currentargs) do
+ tableinsertcheck(txt, arg:read(stackidx))
+ end
+
+ for _,arg in ipairs(optargs) do
+ tableinsertcheck(txt, arg:init())
+ end
+
+ table.insert(txt, '}')
+
+ end
+end
+
+function CInterface:__writecall(txt, args, cfuncname, cargs, argcreturned)
+ local argtypes = self.argtypes
+
+ for _,arg in ipairs(args) do
+ tableinsertcheck(txt, arg:precall())
+ end
+
+ if argcreturned then
+ table.insert(txt, string.format('%s = %s(%s);', argtypes[argcreturned.name].creturn(argcreturned), cfuncname, table.concat(cargs, ',')))
+ else
+ table.insert(txt, string.format('%s(%s);', cfuncname, table.concat(cargs, ',')))
+ end
+
+ for _,arg in ipairs(args) do
+ tableinsertcheck(txt, arg:postcall())
+ end
+
+ local nret = 0
+ if argcreturned then
+ nret = nret + 1
+ end
+ for _,arg in ipairs(args) do
+ if arg.returned then
+ nret = nret + 1
+ end
+ end
+ table.insert(txt, string.format('return %d;', nret))
+end
+
+return CInterface
+
+
diff --git a/init.lua b/init.lua
index bd4199d..dabe086 100644
--- a/init.lua
+++ b/init.lua
@@ -1,311 +1,7 @@
-wrap = {}
+local cwrap = {}
-dofile(debug.getinfo(1).source:gsub('init%.lua$', 'types.lua'):gsub('^@', ''))
-
-local CInterface = {}
-wrap.CInterface = CInterface
-
-function CInterface.new()
- self = {}
- self.txt = {}
- self.registry = {}
- self.argtypes = wrap.argtypes
- setmetatable(self, {__index=CInterface})
- return self
-end
-
-function CInterface:luaname2wrapname(name)
- return string.format("wrapper_%s", name)
-end
-
-function CInterface:print(str)
- table.insert(self.txt, str)
-end
-
-function CInterface:wrap(luaname, ...)
- local txt = self.txt
- local varargs = {...}
-
- assert(#varargs > 0 and #varargs % 2 == 0, 'must provide both the C function name and the corresponding arguments')
-
- -- add function to the registry
- table.insert(self.registry, {name=luaname, wrapname=self:luaname2wrapname(luaname)})
-
- table.insert(txt, string.format("static int %s(lua_State *L)", self:luaname2wrapname(luaname)))
- table.insert(txt, "{")
- table.insert(txt, "int narg = lua_gettop(L);")
-
- if #varargs == 2 then
- local cfuncname = varargs[1]
- local args = varargs[2]
-
- local helpargs, cargs, argcreturned = self:__writeheaders(txt, args)
- self:__writechecks(txt, args)
-
- table.insert(txt, 'else')
- table.insert(txt, string.format('luaL_error(L, "expected arguments: %s");', table.concat(helpargs, ' ')))
-
- self:__writecall(txt, args, cfuncname, cargs, argcreturned)
- else
- local allcfuncname = {}
- local allargs = {}
- local allhelpargs = {}
- local allcargs = {}
- local allargcreturned = {}
-
- table.insert(txt, "int argset = 0;")
-
- for k=1,#varargs/2 do
- allcfuncname[k] = varargs[(k-1)*2+1]
- allargs[k] = varargs[(k-1)*2+2]
- end
-
- local argoffset = 0
- for k=1,#varargs/2 do
- allhelpargs[k], allcargs[k], allargcreturned[k] = self:__writeheaders(txt, allargs[k], argoffset)
- argoffset = argoffset + #allargs[k]
- end
-
- for k=1,#varargs/2 do
- self:__writechecks(txt, allargs[k], k)
- end
-
- table.insert(txt, 'else')
- local allconcathelpargs = {}
- for k=1,#varargs/2 do
- table.insert(allconcathelpargs, table.concat(allhelpargs[k], ' '))
- end
- table.insert(txt, string.format('luaL_error(L, "expected arguments: %s");', table.concat(allconcathelpargs, ' | ')))
-
- for k=1,#varargs/2 do
- if k == 1 then
- table.insert(txt, string.format('if(argset == %d)', k))
- else
- table.insert(txt, string.format('else if(argset == %d)', k))
- end
- table.insert(txt, '{')
- self:__writecall(txt, allargs[k], allcfuncname[k], allcargs[k], allargcreturned[k])
- table.insert(txt, '}')
- end
-
- table.insert(txt, 'return 0;')
- end
-
- table.insert(txt, '}')
- table.insert(txt, '')
-end
-
-function CInterface:register(name)
- local txt = self.txt
- table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name))
- for _,reg in ipairs(self.registry) do
- table.insert(txt, string.format('{"%s", %s},', reg.name, reg.wrapname))
- end
- table.insert(txt, '{NULL, NULL}')
- table.insert(txt, '};')
- table.insert(txt, '')
- self.registry = {}
-end
-
-function CInterface:clearhistory()
- self.txt = {}
- self.registry = {}
-end
-
-function CInterface:tostring()
- return table.concat(self.txt, '\n')
-end
-
-function CInterface:tofile(filename)
- local f = io.open(filename, 'w')
- f:write(table.concat(self.txt, '\n'))
- f:close()
-end
-
-local function bit(p)
- return 2 ^ (p - 1) -- 1-based indexing
-end
-
-local function hasbit(x, p)
- return x % (p + p) >= p
-end
-
-local function beautify(txt)
- local indent = 0
- for i=1,#txt do
- if txt[i]:match('}') then
- indent = indent - 2
- end
- if indent > 0 then
- txt[i] = string.rep(' ', indent) .. txt[i]
- end
- if txt[i]:match('{') then
- indent = indent + 2
- end
- end
-end
-
-local function tableinsertcheck(tbl, stuff)
- if stuff and not stuff:match('^%s*$') then
- table.insert(tbl, stuff)
- end
-end
-
-function CInterface:__writeheaders(txt, args, argoffset)
- local argtypes = self.argtypes
- local helpargs = {}
- local cargs = {}
- local argcreturned
- argoffset = argoffset or 0
-
- for i,arg in ipairs(args) do
- arg.i = i+argoffset
- arg.args = args -- in case we want to do stuff depending on other args
- assert(argtypes[arg.name], 'unknown type ' .. arg.name)
- setmetatable(arg, {__index=argtypes[arg.name]})
- arg.__metatable = argtypes[arg.name]
- tableinsertcheck(txt, arg:declare())
- local helpname = arg:helpname()
- if arg.returned then
- helpname = string.format('*%s*', helpname)
- end
- if arg.invisible and arg.default == nil then
- error('Invisible arguments must have a default! How could I guess how to initialize it?')
- end
- if arg.default ~= nil then
- if not arg.invisible then
- table.insert(helpargs, string.format('[%s]', helpname))
- end
- elseif not arg.creturned then
- table.insert(helpargs, helpname)
- end
- if arg.creturned then
- if argcreturned then
- error('A C function can only return one argument!')
- end
- if arg.default ~= nil then
- error('Obviously, an "argument" returned by a C function cannot have a default value')
- end
- if arg.returned then
- error('Options "returned" and "creturned" are incompatible')
- end
- argcreturned = arg
- else
- table.insert(cargs, arg:carg())
- end
- end
-
- return helpargs, cargs, argcreturned
-end
-
-function CInterface:__writechecks(txt, args, argset)
- local argtypes = self.argtypes
-
- local multiargset = argset
- argset = argset or 1
-
- local nopt = 0
- for i,arg in ipairs(args) do
- if arg.default ~= nil and not arg.invisible then
- nopt = nopt + 1
- end
- end
-
- for variant=0,math.pow(2, nopt)-1 do
- local opt = 0
- local currentargs = {}
- local optargs = {}
- local hasvararg = false
- for i,arg in ipairs(args) do
- if arg.invisible then
- table.insert(optargs, arg)
- elseif arg.default ~= nil then
- opt = opt + 1
- if hasbit(variant, bit(opt)) then
- table.insert(currentargs, arg)
- else
- table.insert(optargs, arg)
- end
- elseif not arg.creturned then
- table.insert(currentargs, arg)
- end
- end
-
- for _,arg in ipairs(args) do
- if arg.vararg then
- if hasvararg then
- error('Only one argument can be a "vararg"!')
- end
- hasvararg = true
- end
- end
-
- if hasvararg and not currentargs[#currentargs].vararg then
- error('Only the last argument can be a "vararg"')
- end
-
- local compop
- if hasvararg then
- compop = '>='
- else
- compop = '=='
- end
-
- if variant == 0 and argset == 1 then
- table.insert(txt, string.format('if(narg %s %d', compop, #currentargs))
- else
- table.insert(txt, string.format('else if(narg %s %d', compop, #currentargs))
- end
-
- for stackidx, arg in ipairs(currentargs) do
- table.insert(txt, string.format("&& %s", arg:check(stackidx)))
- end
- table.insert(txt, ')')
- table.insert(txt, '{')
-
- if multiargset then
- table.insert(txt, string.format('argset = %d;', argset))
- end
-
- for stackidx, arg in ipairs(currentargs) do
- tableinsertcheck(txt, arg:read(stackidx))
- end
-
- for _,arg in ipairs(optargs) do
- tableinsertcheck(txt, arg:init())
- end
-
- table.insert(txt, '}')
-
- end
-end
-
-function CInterface:__writecall(txt, args, cfuncname, cargs, argcreturned)
- local argtypes = self.argtypes
-
- for _,arg in ipairs(args) do
- tableinsertcheck(txt, arg:precall())
- end
-
- if argcreturned then
- table.insert(txt, string.format('%s = %s(%s);', argtypes[argcreturned.name].creturn(argcreturned), cfuncname, table.concat(cargs, ',')))
- else
- table.insert(txt, string.format('%s(%s);', cfuncname, table.concat(cargs, ',')))
- end
-
- for _,arg in ipairs(args) do
- tableinsertcheck(txt, arg:postcall())
- end
-
- local nret = 0
- if argcreturned then
- nret = nret + 1
- end
- for _,arg in ipairs(args) do
- if arg.returned then
- nret = nret + 1
- end
- end
- table.insert(txt, string.format('return %d;', nret))
-end
+cwrap.types = require 'cwrap.types'
+cwrap.CInterface = require 'cwrap.cinterface'
+cwrap.CInterface.argtypes = cwrap.types
+return cwrap
diff --git a/types.lua b/types.lua
index 059ec12..bc9a900 100644
--- a/types.lua
+++ b/types.lua
@@ -1,255 +1,4 @@
-wrap.argtypes = {}
-
-wrap.argtypes.Tensor = {
-
- helpname = function(arg)
- if arg.dim then
- return string.format("Tensor~%dD", arg.dim)
- else
- return "Tensor"
- end
- end,
-
- declare = function(arg)
- local txt = {}
- table.insert(txt, string.format("THTensor *arg%d = NULL;", arg.i))
- if arg.returned then
- table.insert(txt, string.format("int arg%d_idx = 0;", arg.i));
- end
- return table.concat(txt, '\n')
- end,
-
- check = function(arg, idx)
- if arg.dim then
- return string.format("(arg%d = luaT_toudata(L, %d, torch_Tensor)) && (arg%d->nDimension == %d)", arg.i, idx, arg.i, arg.dim)
- else
- return string.format("(arg%d = luaT_toudata(L, %d, torch_Tensor))", arg.i, idx)
- end
- end,
-
- read = function(arg, idx)
- if arg.returned then
- return string.format("arg%d_idx = %d;", arg.i, idx)
- end
- end,
-
- init = function(arg)
- if type(arg.default) == 'boolean' then
- return string.format('arg%d = THTensor_(new)();', arg.i)
- elseif type(arg.default) == 'number' then
- return string.format('arg%d = %s;', arg.i, arg.args[arg.default]:carg())
- else
- error('unknown default tensor type value')
- end
- end,
-
- carg = function(arg)
- return string.format('arg%d', arg.i)
- end,
-
- creturn = function(arg)
- return string.format('arg%d', arg.i)
- end,
-
- precall = function(arg)
- local txt = {}
- if arg.default and arg.returned then
- table.insert(txt, string.format('if(arg%d_idx)', arg.i)) -- means it was passed as arg
- table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
- table.insert(txt, string.format('else'))
- if type(arg.default) == 'boolean' then -- boolean: we did a new()
- table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_Tensor);', arg.i))
- else -- otherwise: point on default tensor --> retain
- table.insert(txt, string.format('{'))
- table.insert(txt, string.format('THTensor_(retain)(arg%d);', arg.i)) -- so we need a retain
- table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_Tensor);', arg.i))
- table.insert(txt, string.format('}'))
- end
- elseif arg.default then
- -- we would have to deallocate the beast later if we did a new
- -- unlikely anyways, so i do not support it for now
- if type(arg.default) == 'boolean' then
- error('a tensor cannot be optional if not returned')
- end
- elseif arg.returned then
- table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
- end
- return table.concat(txt, '\n')
- end,
-
- postcall = function(arg)
- local txt = {}
- if arg.creturned then
- -- this next line is actually debatable
- table.insert(txt, string.format('THTensor_(retain)(arg%d);', arg.i))
- table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_Tensor);', arg.i))
- end
- return table.concat(txt, '\n')
- end
-}
-
-wrap.argtypes.IndexTensor = {
-
- helpname = function(arg)
- return "LongTensor"
- end,
-
- declare = function(arg)
- local txt = {}
- table.insert(txt, string.format("THLongTensor *arg%d = NULL;", arg.i))
- if arg.returned then
- table.insert(txt, string.format("int arg%d_idx = 0;", arg.i));
- end
- return table.concat(txt, '\n')
- end,
-
- check = function(arg, idx)
- return string.format('(arg%d = luaT_toudata(L, %d, "torch.LongTensor"))', arg.i, idx)
- end,
-
- read = function(arg, idx)
- local txt = {}
- if not arg.noreadadd then
- table.insert(txt, string.format("THLongTensor_add(arg%d, arg%d, -1);", arg.i, arg.i));
- end
- if arg.returned then
- table.insert(txt, string.format("arg%d_idx = %d;", arg.i, idx))
- end
- return table.concat(txt, '\n')
- end,
-
- init = function(arg)
- return string.format('arg%d = THLongTensor_new();', arg.i)
- end,
-
- carg = function(arg)
- return string.format('arg%d', arg.i)
- end,
-
- creturn = function(arg)
- return string.format('arg%d', arg.i)
- end,
-
- precall = function(arg)
- local txt = {}
- if arg.default and arg.returned then
- table.insert(txt, string.format('if(arg%d_idx)', arg.i)) -- means it was passed as arg
- table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
- table.insert(txt, string.format('else')) -- means we did a new()
- table.insert(txt, string.format('luaT_pushudata(L, arg%d, "torch.LongTensor");', arg.i))
- elseif arg.default then
- error('a tensor cannot be optional if not returned')
- elseif arg.returned then
- table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
- end
- return table.concat(txt, '\n')
- end,
-
- postcall = function(arg)
- local txt = {}
- if arg.creturned or arg.returned then
- table.insert(txt, string.format("THLongTensor_add(arg%d, arg%d, 1);", arg.i, arg.i));
- end
- if arg.creturned then
- -- this next line is actually debatable
- table.insert(txt, string.format('THLongTensor_retain(arg%d);', arg.i))
- table.insert(txt, string.format('luaT_pushudata(L, arg%d, "torch.LongTensor");', arg.i))
- end
- return table.concat(txt, '\n')
- end
-}
-
-for _,typename in ipairs({"ByteTensor", "CharTensor", "ShortTensor", "IntTensor", "LongTensor",
- "FloatTensor", "DoubleTensor"}) do
-
- wrap.argtypes[typename] = {
-
- helpname = function(arg)
- if arg.dim then
- return string.format('%s~%dD', typename, arg.dim)
- else
- return typename
- end
- end,
-
- declare = function(arg)
- local txt = {}
- table.insert(txt, string.format("TH%s *arg%d = NULL;", typename, arg.i))
- if arg.returned then
- table.insert(txt, string.format("int arg%d_idx = 0;", arg.i));
- end
- return table.concat(txt, '\n')
- end,
-
- check = function(arg, idx)
- if arg.dim then
- return string.format('(arg%d = luaT_toudata(L, %d, "torch.%s")) && (arg%d->nDimension == %d)', arg.i, idx, typename, arg.i, arg.dim)
- else
- return string.format('(arg%d = luaT_toudata(L, %d, "torch.%s"))', arg.i, idx, typename)
- end
- end,
-
- read = function(arg, idx)
- if arg.returned then
- return string.format("arg%d_idx = %d;", arg.i, idx)
- end
- end,
-
- init = function(arg)
- if type(arg.default) == 'boolean' then
- return string.format('arg%d = TH%s_new();', arg.i, typename)
- elseif type(arg.default) == 'number' then
- return string.format('arg%d = %s;', arg.i, arg.args[arg.default]:carg())
- else
- error('unknown default tensor type value')
- end
- end,
-
- carg = function(arg)
- return string.format('arg%d', arg.i)
- end,
-
- creturn = function(arg)
- return string.format('arg%d', arg.i)
- end,
-
- precall = function(arg)
- local txt = {}
- if arg.default and arg.returned then
- table.insert(txt, string.format('if(arg%d_idx)', arg.i)) -- means it was passed as arg
- table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
- table.insert(txt, string.format('else'))
- if type(arg.default) == 'boolean' then -- boolean: we did a new()
- table.insert(txt, string.format('luaT_pushudata(L, arg%d, "torch.%s");', arg.i, typename))
- else -- otherwise: point on default tensor --> retain
- table.insert(txt, string.format('{'))
- table.insert(txt, string.format('TH%s_retain(arg%d);', typename, arg.i)) -- so we need a retain
- table.insert(txt, string.format('luaT_pushudata(L, arg%d, "torch.%s");', arg.i, typename))
- table.insert(txt, string.format('}'))
- end
- elseif arg.default then
- -- we would have to deallocate the beast later if we did a new
- -- unlikely anyways, so i do not support it for now
- if type(arg.default) == 'boolean' then
- error('a tensor cannot be optional if not returned')
- end
- elseif arg.returned then
- table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
- end
- return table.concat(txt, '\n')
- end,
-
- postcall = function(arg)
- local txt = {}
- if arg.creturned then
- -- this next line is actually debatable
- table.insert(txt, string.format('TH%s_retain(arg%d);', typename, arg.i))
- table.insert(txt, string.format('luaT_pushudata(L, arg%d, "torch.%s");', arg.i, typename))
- end
- return table.concat(txt, '\n')
- end
- }
-end
+local argtypes = {}
local function interpretdefaultvalue(arg)
local default = arg.default
@@ -274,7 +23,7 @@ local function interpretdefaultvalue(arg)
end
end
-wrap.argtypes.index = {
+argtypes.index = {
helpname = function(arg)
return "index"
@@ -326,7 +75,7 @@ wrap.argtypes.index = {
}
for _,typename in ipairs({"real", "unsigned char", "char", "short", "int", "long", "float", "double"}) do
- wrap.argtypes[typename] = {
+ argtypes[typename] = {
helpname = function(arg)
return typename
@@ -378,9 +127,9 @@ for _,typename in ipairs({"real", "unsigned char", "char", "short", "int", "long
}
end
-wrap.argtypes.byte = wrap.argtypes['unsigned char']
+argtypes.byte = argtypes['unsigned char']
-wrap.argtypes.boolean = {
+argtypes.boolean = {
helpname = function(arg)
return "boolean"
@@ -430,3 +179,5 @@ wrap.argtypes.boolean = {
end
end
}
+
+return argtypes