Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cwrap.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2012-01-25 17:55:20 +0400
committerRonan Collobert <ronan@collobert.com>2012-01-25 17:55:20 +0400
commit096dce92a6dd21f8bf33ebf106f9557d33702e39 (patch)
treef6d7b9ecc3af5d083cce6fdcdec7c0badd9d9e2c /types.lua
initial revamp of torch7 tree
Diffstat (limited to 'types.lua')
-rw-r--r--types.lua396
1 files changed, 396 insertions, 0 deletions
diff --git a/types.lua b/types.lua
new file mode 100644
index 0000000..3ed227e
--- /dev/null
+++ b/types.lua
@@ -0,0 +1,396 @@
+wrap.argtypes = {}
+
+wrap.argtypes.Tensor = {
+
+ helpname = function(arg)
+ if arg.dim then
+ return string.format("Tensor~%dD", arg.dim)
+ else
+ return "Tensor"
+ end
+ end,
+
+ declare = function(arg)
+ local txt = {}
+ table.insert(txt, string.format("THTensor *arg%d = NULL;", arg.i))
+ if arg.returned then
+ table.insert(txt, string.format("int arg%d_idx = 0;", arg.i));
+ end
+ return table.concat(txt, '\n')
+ end,
+
+ check = function(arg, idx)
+ if arg.dim then
+ return string.format("(arg%d = luaT_toudata(L, %d, torch_(Tensor_id))) && (arg%d->nDimension == %d)", arg.i, idx, arg.i, arg.dim)
+ else
+ return string.format("(arg%d = luaT_toudata(L, %d, torch_(Tensor_id)))", arg.i, idx)
+ end
+ end,
+
+ read = function(arg, idx)
+ if arg.returned then
+ return string.format("arg%d_idx = %d;", arg.i, idx)
+ end
+ end,
+
+ init = function(arg)
+ return string.format('arg%d = TH%s_new();', arg.i, typename)
+ end,
+
+ carg = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ local txt = {}
+ if arg.default and arg.returned then
+ table.insert(txt, string.format('if(arg%d_idx)', arg.i)) -- means it was passed as arg
+ table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
+ table.insert(txt, string.format('else')) -- means we did a new()
+ table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_(Tensor_id));', arg.i))
+ elseif arg.default then
+ error('a tensor cannot be optional if not returned')
+ elseif arg.returned then
+ table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
+ end
+ return table.concat(txt, '\n')
+ end,
+
+ postcall = function(arg)
+ local txt = {}
+ if arg.creturned then
+ -- this next line is actually debatable
+ table.insert(txt, string.format('THTensor_(retain)(arg%d);', arg.i))
+ table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_(Tensor_id));', arg.i))
+ end
+ return table.concat(txt, '\n')
+ end
+}
+
+wrap.argtypes.IndexTensor = {
+
+ helpname = function(arg)
+ return "LongTensor"
+ end,
+
+ declare = function(arg)
+ local txt = {}
+ table.insert(txt, string.format("THLongTensor *arg%d = NULL;", arg.i))
+ if arg.returned then
+ table.insert(txt, string.format("int arg%d_idx = 0;", arg.i));
+ end
+ return table.concat(txt, '\n')
+ end,
+
+ check = function(arg, idx)
+ return string.format("(arg%d = luaT_toudata(L, %d, torch_LongTensor_id))", arg.i, idx)
+ end,
+
+ read = function(arg, idx)
+ local txt = {}
+ table.insert(txt, string.format("THLongTensor_add(arg%d, arg%d, -1);", arg.i, arg.i));
+ if arg.returned then
+ return table.insert(txt, string.format("arg%d_idx = %d;", arg.i, idx))
+ end
+ return table.concat(txt, '\n')
+ end,
+
+ init = function(arg)
+ return string.format('arg%d = THLongTensor_new();', arg.i)
+ end,
+
+ carg = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ local txt = {}
+ if arg.default and arg.returned then
+ table.insert(txt, string.format('if(arg%d_idx)', arg.i)) -- means it was passed as arg
+ table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
+ table.insert(txt, string.format('else')) -- means we did a new()
+ table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_LongTensor_id);', arg.i))
+ elseif arg.default then
+ error('a tensor cannot be optional if not returned')
+ elseif arg.returned then
+ table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
+ end
+ return table.concat(txt, '\n')
+ end,
+
+ postcall = function(arg)
+ local txt = {}
+ if arg.creturned or arg.returned then
+ table.insert(txt, string.format("THLongTensor_add(arg%d, arg%d, 1);", arg.i, arg.i));
+ end
+ if arg.creturned then
+ -- this next line is actually debatable
+ table.insert(txt, string.format('THLongTensor_retain(arg%d);', arg.i))
+ table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_LongTensor_id);', arg.i))
+ end
+ return table.concat(txt, '\n')
+ end
+}
+
+for _,typename in ipairs({"ByteTensor", "CharTensor", "ShortTensor", "IntTensor", "LongTensor",
+ "FloatTensor", "DoubleTensor"}) do
+
+ wrap.argtypes[typename] = {
+
+ helpname = function(arg)
+ if arg.dim then
+ return string.format('%s~%dD', typename, arg.dim)
+ else
+ return typename
+ end
+ end,
+
+ declare = function(arg)
+ local txt = {}
+ table.insert(txt, string.format("TH%s *arg%d = NULL;", typename, arg.i))
+ if arg.returned then
+ table.insert(txt, string.format("int arg%d_idx = 0;", arg.i));
+ end
+ return table.concat(txt, '\n')
+ end,
+
+ check = function(arg, idx)
+ if arg.dim then
+ return string.format("(arg%d = luaT_toudata(L, %d, torch_%s_id)) && (arg%d->nDimension == %d)", arg.i, idx, typename, arg.i, arg.dim)
+ else
+ return string.format("(arg%d = luaT_toudata(L, %d, torch_%s_id))", arg.i, idx, typename)
+ end
+ end,
+
+ read = function(arg, idx)
+ if arg.returned then
+ return string.format("arg%d_idx = %d;", arg.i, idx)
+ end
+ end,
+
+ init = function(arg)
+ return string.format('arg%d = TH%s_new();', arg.i, typename)
+ end,
+
+ carg = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ local txt = {}
+ if arg.default and arg.returned then
+ table.insert(txt, string.format('if(arg%d_idx)', arg.i)) -- means it was passed as arg
+ table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
+ table.insert(txt, string.format('else')) -- means we did a new()
+ table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_%s_id);', arg.i, typename))
+ elseif arg.default then
+ error('a tensor cannot be optional if not returned')
+ elseif arg.returned then
+ table.insert(txt, string.format('lua_pushvalue(L, arg%d_idx);', arg.i))
+ end
+ return table.concat(txt, '\n')
+ end,
+
+ postcall = function(arg)
+ local txt = {}
+ if arg.creturned then
+ -- this next line is actually debatable
+ table.insert(txt, string.format('TH%s_retain(arg%d);', typename, arg.i))
+ table.insert(txt, string.format('luaT_pushudata(L, arg%d, torch_%s_id);', arg.i, typename))
+ end
+ return table.concat(txt, '\n')
+ end
+ }
+end
+
+local function interpretdefaultvalue(arg)
+ local default = arg.default
+ if type(default) == 'boolean' then
+ if default then
+ return '1'
+ else
+ return '0'
+ end
+ elseif type(default) == 'number' then
+ return tostring(default)
+ elseif type(default) == 'string' then
+ return default
+ elseif type(default) == 'function' then
+ default = default(arg)
+ assert(type(default) == 'string', 'a default function must return a string')
+ return default
+ elseif type(default) == 'nil' then
+ return nil
+ else
+ error('unknown default type value')
+ end
+end
+
+wrap.argtypes.index = {
+
+ helpname = function(arg)
+ return "index"
+ end,
+
+ declare = function(arg)
+ -- if it is a number we initialize here
+ local default = tonumber(interpretdefaultvalue(arg)) or 1
+ return string.format("long arg%d = %d;", arg.i, tonumber(default)-1)
+ end,
+
+ check = function(arg, idx)
+ return string.format("lua_isnumber(L, %d)", idx)
+ end,
+
+ read = function(arg, idx)
+ return string.format("arg%d = (long)lua_tonumber(L, %d)-1;", arg.i, idx)
+ end,
+
+ init = function(arg)
+ -- otherwise do it here
+ if arg.default then
+ local default = interpretdefaultvalue(arg)
+ if not tonumber(default) then
+ return string.format("arg%d = %s-1;", arg.i, default)
+ end
+ end
+ end,
+
+ carg = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ if arg.returned then
+ return string.format('lua_pushnumber(L, (lua_Number)arg%d+1);', arg.i)
+ end
+ end,
+
+ postcall = function(arg)
+ if arg.creturned then
+ return string.format('lua_pushnumber(L, (lua_Number)arg%d+1);', arg.i)
+ end
+ end
+}
+
+for _,typename in ipairs({"real", "unsigned char", "char", "short", "int", "long", "float", "double"}) do
+ wrap.argtypes[typename] = {
+
+ helpname = function(arg)
+ return typename
+ end,
+
+ declare = function(arg)
+ -- if it is a number we initialize here
+ local default = tonumber(interpretdefaultvalue(arg)) or 0
+ return string.format("%s arg%d = %d;", typename, arg.i, tonumber(default))
+ end,
+
+ check = function(arg, idx)
+ return string.format("lua_isnumber(L, %d)", idx)
+ end,
+
+ read = function(arg, idx)
+ return string.format("arg%d = (%s)lua_tonumber(L, %d);", arg.i, typename, idx)
+ end,
+
+ init = function(arg)
+ -- otherwise do it here
+ if arg.default then
+ local default = interpretdefaultvalue(arg)
+ if not tonumber(default) then
+ return string.format("arg%d = %s;", arg.i, default)
+ end
+ end
+ end,
+
+ carg = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ if arg.returned then
+ return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
+ end
+ end,
+
+ postcall = function(arg)
+ if arg.creturned then
+ return string.format('lua_pushnumber(L, (lua_Number)arg%d);', arg.i)
+ end
+ end
+ }
+end
+
+wrap.argtypes.byte = wrap.argtypes['unsigned char']
+
+wrap.argtypes.boolean = {
+
+ helpname = function(arg)
+ return "boolean"
+ end,
+
+ declare = function(arg)
+ -- if it is a number we initialize here
+ local default = tonumber(interpretdefaultvalue(arg)) or 0
+ return string.format("int arg%d = %d;", arg.i, tonumber(default))
+ end,
+
+ check = function(arg, idx)
+ return string.format("lua_isboolean(L, %d)", idx)
+ end,
+
+ read = function(arg, idx)
+ return string.format("arg%d = lua_toboolean(L, %d);", arg.i, idx)
+ end,
+
+ init = function(arg)
+ -- otherwise do it here
+ if arg.default then
+ local default = interpretdefaultvalue(arg)
+ if not tonumber(default) then
+ return string.format("arg%d = %s;", arg.i, default)
+ end
+ end
+ end,
+
+ carg = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg, idx)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ if arg.returned then
+ return string.format('lua_pushboolean(L, arg%d);', arg.i)
+ end
+ end,
+
+ postcall = function(arg)
+ if arg.creturned then
+ return string.format('lua_pushboolean(L, arg%d);', arg.i)
+ end
+ end
+}