Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cwrap.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2012-01-25 17:55:20 +0400
committerRonan Collobert <ronan@collobert.com>2012-01-25 17:55:20 +0400
commit096dce92a6dd21f8bf33ebf106f9557d33702e39 (patch)
treef6d7b9ecc3af5d083cce6fdcdec7c0badd9d9e2c
initial revamp of torch7 tree
-rw-r--r--CMakeLists.txt5
-rw-r--r--init.lua353
-rw-r--r--types.lua396
3 files changed, 754 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..f5bda18
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,5 @@
+SET(src)
+SET(luasrc init.lua types.lua)
+
+ADD_TORCH_PACKAGE(wrap "${src}" "${luasrc}")
+#ADD_TORCH_DOK(dok gnuplot "Fundamentals" "Plotting with Gnuplot" 1.)
diff --git a/init.lua b/init.lua
new file mode 100644
index 0000000..4553353
--- /dev/null
+++ b/init.lua
@@ -0,0 +1,353 @@
+wrap = {}
+
+dofile(debug.getinfo(1).source:gsub('init%.lua$', 'types.lua'):gsub('^@', ''))
+
+local CInterface = {}
+wrap.CInterface = CInterface
+
+function CInterface.new()
+ self = {}
+ self.txt = {}
+ self.registry = {}
+ self.argtypes = wrap.argtypes
+ setmetatable(self, {__index=CInterface})
+ return self
+end
+
+function CInterface:luaname2wrapname(name)
+ return string.format("wrapper_%s", name)
+end
+
+function CInterface:print(str)
+ table.insert(self.txt, str)
+end
+
+function CInterface:wrap(luaname, ...)
+ local txt = self.txt
+ local varargs = {...}
+
+ assert(#varargs > 0 and #varargs % 2 == 0, 'must provide both the C function name and the corresponding arguments')
+
+ -- add function to the registry
+ table.insert(self.registry, {name=luaname, wrapname=self:luaname2wrapname(luaname)})
+
+ table.insert(txt, string.format("static int %s(lua_State *L)", self:luaname2wrapname(luaname)))
+ table.insert(txt, "{")
+ table.insert(txt, "int narg = lua_gettop(L);")
+
+ if #varargs == 2 then
+ local cfuncname = varargs[1]
+ local args = varargs[2]
+
+ local helpargs, cargs, argcreturned = self:__writeheaders(txt, args)
+ self:__writechecks(txt, args)
+
+ table.insert(txt, 'else')
+ table.insert(txt, string.format('luaL_error(L, "expected arguments: %s");', table.concat(helpargs, ' ')))
+
+ self:__writecall(txt, args, cfuncname, cargs, argcreturned)
+ else
+ local allcfuncname = {}
+ local allargs = {}
+ local allhelpargs = {}
+ local allcargs = {}
+ local allargcreturned = {}
+
+ table.insert(txt, "int argset = 0;")
+
+ for k=1,#varargs/2 do
+ allcfuncname[k] = varargs[(k-1)*2+1]
+ allargs[k] = varargs[(k-1)*2+2]
+ end
+
+ local argoffset = 0
+ for k=1,#varargs/2 do
+ allhelpargs[k], allcargs[k], allargcreturned[k] = self:__writeheaders(txt, allargs[k], argoffset)
+ argoffset = argoffset + #allargs[k]
+ end
+
+ for k=1,#varargs/2 do
+ self:__writechecks(txt, allargs[k], k)
+ end
+
+ table.insert(txt, 'else')
+ local allconcathelpargs = {}
+ for k=1,#varargs/2 do
+ table.insert(allconcathelpargs, table.concat(allhelpargs[k], ' '))
+ end
+ table.insert(txt, string.format('luaL_error(L, "expected arguments: %s");', table.concat(allconcathelpargs, ' | ')))
+
+ for k=1,#varargs/2 do
+ if k == 1 then
+ table.insert(txt, string.format('if(argset == %d)', k))
+ else
+ table.insert(txt, string.format('else if(argset == %d)', k))
+ end
+ table.insert(txt, '{')
+ self:__writecall(txt, allargs[k], allcfuncname[k], allcargs[k], allargcreturned[k])
+ table.insert(txt, '}')
+ end
+
+ table.insert(txt, 'return 0;')
+ end
+
+ table.insert(txt, '}')
+ table.insert(txt, '')
+end
+
+function CInterface:register(name)
+ local txt = self.txt
+ table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name))
+ for _,reg in ipairs(self.registry) do
+ table.insert(txt, string.format('{"%s", %s},', reg.name, reg.wrapname))
+ end
+ table.insert(txt, '{NULL, NULL}')
+ table.insert(txt, '};')
+ table.insert(txt, '')
+ self.registry = {}
+end
+
+function CInterface:tostdio()
+ return table.concat(self.txt, '\n')
+end
+
+function CInterface:tofile(filename)
+ local f = io.open(filename, 'w')
+ f:write(table.concat(self.txt, '\n'))
+ f:close()
+end
+
+local function bit(p)
+ return 2 ^ (p - 1) -- 1-based indexing
+end
+
+local function hasbit(x, p)
+ return x % (p + p) >= p
+end
+
+local function beautify(txt)
+ local indent = 0
+ for i=1,#txt do
+ if txt[i]:match('}') then
+ indent = indent - 2
+ end
+ if indent > 0 then
+ txt[i] = string.rep(' ', indent) .. txt[i]
+ end
+ if txt[i]:match('{') then
+ indent = indent + 2
+ end
+ end
+end
+
+local function tableinsertcheck(tbl, stuff)
+ if stuff and not stuff:match('^%s*$') then
+ table.insert(tbl, stuff)
+ end
+end
+
+function CInterface:__writeheaders(txt, args, argoffset)
+ local argtypes = self.argtypes
+ local helpargs = {}
+ local cargs = {}
+ local argcreturned
+ argoffset = argoffset or 0
+
+ for i,arg in ipairs(args) do
+ arg.i = i+argoffset
+ arg.args = args -- in case we want to do stuff depending on other args
+ assert(argtypes[arg.name], 'unknown type ' .. arg.name)
+ setmetatable(arg, {__index=argtypes[arg.name]})
+ tableinsertcheck(txt, arg:declare())
+ local helpname = arg:helpname()
+ if arg.returned then
+ helpname = string.format('*%s*', helpname)
+ end
+ if arg.invisible and arg.default == nil then
+ error('Invisible arguments must have a default! How could I guess how to initialize it?')
+ end
+ if arg.default ~= nil then
+ if not arg.invisible then
+ table.insert(helpargs, string.format('[%s]', helpname))
+ end
+ elseif not arg.creturned then
+ table.insert(helpargs, helpname)
+ end
+ if arg.creturned then
+ if argcreturned then
+ error('A C function can only return one argument!')
+ end
+ if arg.default ~= nil then
+ error('Obviously, an "argument" returned by a C function cannot have a default value')
+ end
+ if arg.returned then
+ error('Options "returned" and "creturned" are incompatible')
+ end
+ argcreturned = arg
+ else
+ table.insert(cargs, arg:carg())
+ end
+ end
+
+ for _,arg in ipairs(args) do
+ if arg.userhead then
+ if type(arg.userhead) == 'string' then
+ table.insert(txt, arg.userhead)
+ elseif type(arg.userhead) == 'function' then
+ tableinsertcheck(txt, arg:userhead())
+ else
+ error('userhead must be a string or a function')
+ end
+ end
+ end
+
+ return helpargs, cargs, argcreturned
+end
+
+function CInterface:__writechecks(txt, args, argset)
+ local argtypes = self.argtypes
+
+ local multiargset = argset
+ argset = argset or 1
+
+ local nopt = 0
+ for i,arg in ipairs(args) do
+ if arg.default ~= nil and not arg.invisible then
+ nopt = nopt + 1
+ end
+ end
+
+ for variant=0,math.pow(2, nopt)-1 do
+ local opt = 0
+ local currentargs = {}
+ local optargs = {}
+ local hasvararg = false
+ for i,arg in ipairs(args) do
+ if arg.invisible then
+ table.insert(optargs, arg)
+ elseif arg.default ~= nil then
+ opt = opt + 1
+ if hasbit(variant, bit(opt)) then
+ table.insert(currentargs, arg)
+ else
+ table.insert(optargs, arg)
+ end
+ elseif not arg.creturned then
+ table.insert(currentargs, arg)
+ end
+ end
+
+ for _,arg in ipairs(args) do
+ if arg.vararg then
+ if hasvararg then
+ error('Only one argument can be a "vararg"!')
+ end
+ hasvararg = true
+ end
+ end
+
+ if hasvararg and not currentargs[#currentargs].vararg then
+ error('Only the last argument can be a "vararg"')
+ end
+
+ local compop
+ if hasvararg then
+ compop = '>='
+ else
+ compop = '=='
+ end
+
+ if variant == 0 and argset == 1 then
+ table.insert(txt, string.format('if(narg %s %d', compop, #currentargs))
+ else
+ table.insert(txt, string.format('else if(narg %s %d', compop, #currentargs))
+ end
+
+ for stackidx, arg in ipairs(currentargs) do
+ table.insert(txt, string.format("&& %s", arg:check(stackidx)))
+ end
+ table.insert(txt, ')')
+ table.insert(txt, '{')
+
+ if multiargset then
+ table.insert(txt, string.format('argset = %d;', argset))
+ end
+
+ for stackidx, arg in ipairs(currentargs) do
+ tableinsertcheck(txt, arg:read(stackidx))
+ end
+
+ for _,arg in ipairs(optargs) do
+ tableinsertcheck(txt, arg:init())
+ end
+
+ for _,arg in ipairs(args) do
+ if arg.usercheck then
+ if type(arg.usercheck) == 'string' then
+ table.insert(txt, arg.usercheck)
+ elseif type(arg.usercheck) == 'function' then
+ tableinsertcheck(txt, arg:usercheck())
+ else
+ error('usercheck must be a string or a function')
+ end
+ end
+ end
+
+ table.insert(txt, '}')
+
+ end
+end
+
+function CInterface:__writecall(txt, args, cfuncname, cargs, argcreturned)
+ local argtypes = self.argtypes
+
+ for _,arg in ipairs(args) do
+ tableinsertcheck(txt, arg:precall())
+ end
+
+ for _,arg in ipairs(args) do
+ if arg.userprecall then
+ if type(arg.userprecall) == 'string' then
+ table.insert(txt, arg.userprecall)
+ elseif type(arg.userprecall) == 'function' then
+ tableinsertcheck(txt, arg:userprecall())
+ else
+ error('userprecall must be a string or a function')
+ end
+ end
+ end
+
+ if argcreturned then
+ table.insert(txt, string.format('%s = %s(%s);', argtypes[argcreturned.name].creturn(argcreturned), cfuncname, table.concat(cargs, ',')))
+ else
+ table.insert(txt, string.format('%s(%s);', cfuncname, table.concat(cargs, ',')))
+ end
+
+ for _,arg in ipairs(args) do
+ tableinsertcheck(txt, arg:postcall())
+ end
+
+ for _,arg in ipairs(args) do
+ if arg.userpostcall then
+ if type(arg.userpostcall) == 'string' then
+ table.insert(txt, arg.userpostcall)
+ elseif type(arg.userpostcall) == 'function' then
+ tableinsertcheck(txt, arg:userpostcall())
+ else
+ error('userpostcall must be a string or a function')
+ end
+ end
+ end
+
+ local nret = 0
+ if argcreturned then
+ nret = nret + 1
+ end
+ for _,arg in ipairs(args) do
+ if arg.returned then
+ nret = nret + 1
+ end
+ end
+ table.insert(txt, string.format('return %d;', nret))
+end
+
diff --git a/types.lua b/types.lua
new file mode 100644
index 0000000..3ed227e
--- /dev/null
+++ b/types.lua
@@ -0,0 +1,396 @@
+wrap.argtypes = {}
+
+wrap.argtypes.Tensor = {
+
+ helpname = function(arg)
+ if arg.dim then
+ return string.format("Tensor~%dD", arg.dim)
+ else
+ return "Tensor"
+ end
+ end,
+
+ declare = function(arg)
+ local txt = {}
+ table.insert(txt, string.format("THTensor *arg%d = NULL;", arg.i))
+ if arg.returned then
+ table.insert(txt, string.format("int arg%d_idx = 0;", arg.i));
+ end
+ return table.concat(txt, '\n')
+ end,
+
+ check = function(arg, idx)
+ if arg.dim then
+ return string.format("(arg%d = luaT_toudata(L, %d, torch_(Tensor_id))) && (arg%d->nDimension == %d)", arg.i, idx, arg.i, arg.dim)
+ else
+ return string.format("(arg%d = luaT_toudata(L, %d, torch_(Tensor_id)))", arg.i, idx)
+ end
+ end,
+
+ read = function(arg, idx)
+ if arg.returned then
+ return string.format("arg%d_idx = %d;", arg.i, idx)
+ end
+ end,
+
+ init = function(arg)
+ return string.format('arg%d = TH%s_new();', arg.i, typename)
+ end,
+
+ carg = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ local txt = {}
+ if arg.default and arg.returned then
+ table.insert(txt, string.format('if(arg%d_idx)', arg.i)) -- means it was passed as arg
+ table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
+ table.insert(txt, string.format('else')) -- means we did a new()
+ table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_(Tensor_id));', arg.i))
+ elseif arg.default then
+ error('a tensor cannot be optional if not returned')
+ elseif arg.returned then
+ table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
+ end
+ return table.concat(txt, '\n')
+ end,
+
+ postcall = function(arg)
+ local txt = {}
+ if arg.creturned then
+ -- this next line is actually debatable
+ table.insert(txt, string.format('THTensor_(retain)(arg%d);', arg.i))
+ table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_(Tensor_id));', arg.i))
+ end
+ return table.concat(txt, '\n')
+ end
+}
+
+wrap.argtypes.IndexTensor = {
+
+ helpname = function(arg)
+ return "LongTensor"
+ end,
+
+ declare = function(arg)
+ local txt = {}
+ table.insert(txt, string.format("THLongTensor *arg%d = NULL;", arg.i))
+ if arg.returned then
+ table.insert(txt, string.format("int arg%d_idx = 0;", arg.i));
+ end
+ return table.concat(txt, '\n')
+ end,
+
+ check = function(arg, idx)
+ return string.format("(arg%d = luaT_toudata(L, %d, torch_LongTensor_id))", arg.i, idx)
+ end,
+
+ read = function(arg, idx)
+ local txt = {}
+ table.insert(txt, string.format("THLongTensor_add(arg%d, arg%d, -1);", arg.i, arg.i));
+ if arg.returned then
+ return table.insert(txt, string.format("arg%d_idx = %d;", arg.i, idx))
+ end
+ return table.concat(txt, '\n')
+ end,
+
+ init = function(arg)
+ return string.format('arg%d = THLongTensor_new();', arg.i)
+ end,
+
+ carg = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ local txt = {}
+ if arg.default and arg.returned then
+ table.insert(txt, string.format('if(arg%d_idx)', arg.i)) -- means it was passed as arg
+ table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
+ table.insert(txt, string.format('else')) -- means we did a new()
+ table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_LongTensor_id);', arg.i))
+ elseif arg.default then
+ error('a tensor cannot be optional if not returned')
+ elseif arg.returned then
+ table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
+ end
+ return table.concat(txt, '\n')
+ end,
+
+ postcall = function(arg)
+ local txt = {}
+ if arg.creturned or arg.returned then
+ table.insert(txt, string.format("THLongTensor_add(arg%d, arg%d, 1);", arg.i, arg.i));
+ end
+ if arg.creturned then
+ -- this next line is actually debatable
+ table.insert(txt, string.format('THLongTensor_retain(arg%d);', arg.i))
+ table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_LongTensor_id);', arg.i))
+ end
+ return table.concat(txt, '\n')
+ end
+}
+
+for _,typename in ipairs({"ByteTensor", "CharTensor", "ShortTensor", "IntTensor", "LongTensor",
+ "FloatTensor", "DoubleTensor"}) do
+
+ wrap.argtypes[typename] = {
+
+ helpname = function(arg)
+ if arg.dim then
+ return string.format('%s~%dD', typename, arg.dim)
+ else
+ return typename
+ end
+ end,
+
+ declare = function(arg)
+ local txt = {}
+ table.insert(txt, string.format("TH%s *arg%d = NULL;", typename, arg.i))
+ if arg.returned then
+ table.insert(txt, string.format("int arg%d_idx = 0;", arg.i));
+ end
+ return table.concat(txt, '\n')
+ end,
+
+ check = function(arg, idx)
+ if arg.dim then
+ return string.format("(arg%d = luaT_toudata(L, %d, torch_%s_id)) && (arg%d->nDimension == %d)", arg.i, idx, typename, arg.i, arg.dim)
+ else
+ return string.format("(arg%d = luaT_toudata(L, %d, torch_%s_id))", arg.i, idx, typename)
+ end
+ end,
+
+ read = function(arg, idx)
+ if arg.returned then
+ return string.format("arg%d_idx = %d;", arg.i, idx)
+ end
+ end,
+
+ init = function(arg)
+ return string.format('arg%d = TH%s_new();', arg.i, typename)
+ end,
+
+ carg = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ local txt = {}
+ if arg.default and arg.returned then
+ table.insert(txt, string.format('if(arg%d_idx)', arg.i)) -- means it was passed as arg
+ table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
+ table.insert(txt, string.format('else')) -- means we did a new()
+ table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_%s_id);', arg.i, typename))
+ elseif arg.default then
+ error('a tensor cannot be optional if not returned')
+ elseif arg.returned then
+ table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
+ end
+ return table.concat(txt, '\n')
+ end,
+
+ postcall = function(arg)
+ local txt = {}
+ if arg.creturned then
+ -- this next line is actually debatable
+ table.insert(txt, string.format('TH%s_retain(arg%d);', typename, arg.i))
+ table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_%s_id);', arg.i, typename))
+ end
+ return table.concat(txt, '\n')
+ end
+ }
+end
+
+local function interpretdefaultvalue(arg)
+ local default = arg.default
+ if type(default) == 'boolean' then
+ if default then
+ return '1'
+ else
+ return '0'
+ end
+ elseif type(default) == 'number' then
+ return tostring(default)
+ elseif type(default) == 'string' then
+ return default
+ elseif type(default) == 'function' then
+ default = default(arg)
+ assert(type(default) == 'string', 'a default function must return a string')
+ return default
+ elseif type(default) == 'nil' then
+ return nil
+ else
+ error('unknown default type value')
+ end
+end
+
+wrap.argtypes.index = {
+
+ helpname = function(arg)
+ return "index"
+ end,
+
+ declare = function(arg)
+ -- if it is a number we initialize here
+ local default = tonumber(interpretdefaultvalue(arg)) or 1
+ return string.format("long arg%d = %d;", arg.i, tonumber(default)-1)
+ end,
+
+ check = function(arg, idx)
+ return string.format("lua_isnumber(L, %d)", idx)
+ end,
+
+ read = function(arg, idx)
+ return string.format("arg%d = (long)lua_tonumber(L, %d)-1;", arg.i, idx)
+ end,
+
+ init = function(arg)
+ -- otherwise do it here
+ if arg.default then
+ local default = interpretdefaultvalue(arg)
+ if not tonumber(default) then
+ return string.format("arg%d = %s-1;", arg.i, default)
+ end
+ end
+ end,
+
+ carg = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ if arg.returned then
+ return string.format('lua_pushnumber(L, (lua_Number)arg%d+1);', arg.i)
+ end
+ end,
+
+ postcall = function(arg)
+ if arg.creturned then
+ return string.format('lua_pushnumber(L, (lua_Number)arg%d+1);', arg.i)
+ end
+ end
+}
+
+for _,typename in ipairs({"real", "unsigned char", "char", "short", "int", "long", "float", "double"}) do
+ wrap.argtypes[typename] = {
+
+ helpname = function(arg)
+ return typename
+ end,
+
+ declare = function(arg)
+ -- if it is a number we initialize here
+ local default = tonumber(interpretdefaultvalue(arg)) or 0
+ return string.format("%s arg%d = %d;", typename, arg.i, tonumber(default))
+ end,
+
+ check = function(arg, idx)
+ return string.format("lua_isnumber(L, %d)", idx)
+ end,
+
+ read = function(arg, idx)
+ return string.format("arg%d = (%s)lua_tonumber(L, %d);", arg.i, typename, idx)
+ end,
+
+ init = function(arg)
+ -- otherwise do it here
+ if arg.default then
+ local default = interpretdefaultvalue(arg)
+ if not tonumber(default) then
+ return string.format("arg%d = %s;", arg.i, default)
+ end
+ end
+ end,
+
+ carg = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ if arg.returned then
+ return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
+ end
+ end,
+
+ postcall = function(arg)
+ if arg.creturned then
+ return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
+ end
+ end
+ }
+end
+
+wrap.argtypes.byte = wrap.argtypes['unsigned char']
+
+wrap.argtypes.boolean = {
+
+ helpname = function(arg)
+ return "boolean"
+ end,
+
+ declare = function(arg)
+ -- if it is a number we initialize here
+ local default = tonumber(interpretdefaultvalue(arg)) or 0
+ return string.format("int arg%d = %d;", arg.i, tonumber(default))
+ end,
+
+ check = function(arg, idx)
+ return string.format("lua_isboolean(L, %d)", idx)
+ end,
+
+ read = function(arg, idx)
+ return string.format("arg%d = lua_toboolean(L, %d);", arg.i, idx)
+ end,
+
+ init = function(arg)
+ -- otherwise do it here
+ if arg.default then
+ local default = interpretdefaultvalue(arg)
+ if not tonumber(default) then
+ return string.format("arg%d = %s;", arg.i, default)
+ end
+ end
+ end,
+
+ carg = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ if arg.returned then
+ return string.format('lua_pushboolean(L, arg%d);', arg.i)
+ end
+ end,
+
+ postcall = function(arg)
+ if arg.creturned then
+ return string.format('lua_pushboolean(L, arg%d);', arg.i)
+ end
+ end
+}