Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cwrap.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-07-21 22:08:29 +0300
committerSoumith Chintala <soumith@gmail.com>2015-07-21 22:08:29 +0300
commit05c0b83b255ecd0252ab9704379db5df4e441bf5 (patch)
tree5a1088f440fea36678c0a0023678d2bc9f211618
parent6b7bb2a9f7a002e50b6d1dfc3e653ecdf4c01f4a (diff)
parent5a6ed1768b9c7d372986617e121d3f46344fb0df (diff)
Merge pull request #4 from adamlerer/type_error_message
cwrap type errors: print the observed as well as expected types.
-rw-r--r--cinterface.lua54
1 files changed, 47 insertions, 7 deletions
diff --git a/cinterface.lua b/cinterface.lua
index e8678cc..a32230b 100644
--- a/cinterface.lua
+++ b/cinterface.lua
@@ -30,6 +30,8 @@ function CInterface:wrap(luaname, ...)
-- add function to the registry
table.insert(self.registry, {name=luaname, wrapname=self:luaname2wrapname(luaname)})
+ self:__addchelpers()
+
table.insert(txt, string.format("static int %s(lua_State *L)", self:luaname2wrapname(luaname)))
table.insert(txt, "{")
table.insert(txt, "int narg = lua_gettop(L);")
@@ -41,12 +43,17 @@ function CInterface:wrap(luaname, ...)
if #varargs == 2 then
local cfuncname = varargs[1]
local args = varargs[2]
-
+
local helpargs, cargs, argcreturned = self:__writeheaders(txt, args)
self:__writechecks(txt, args)
-
+
table.insert(txt, 'else')
- table.insert(txt, string.format('luaL_error(L, "expected arguments: %s");', table.concat(helpargs, ' ')))
+ table.insert(txt, '{')
+ table.insert(txt, string.format('char type_buf[512];'))
+ table.insert(txt, string.format('str_arg_types(L, type_buf, 512);'))
+ table.insert(txt, string.format('luaL_error(L, "invalid arguments: %%s\\nexpected arguments: %s", type_buf);',
+ table.concat(helpargs, ' ')))
+ table.insert(txt, '}')
self:__writecall(txt, args, cfuncname, cargs, argcreturned)
else
@@ -78,7 +85,12 @@ function CInterface:wrap(luaname, ...)
for k=1,#varargs/2 do
table.insert(allconcathelpargs, table.concat(allhelpargs[k], ' '))
end
- table.insert(txt, string.format('luaL_error(L, "expected arguments: %s");', table.concat(allconcathelpargs, ' | ')))
+ table.insert(txt, '{')
+ table.insert(txt, string.format('char type_buf[512];'))
+ table.insert(txt, string.format('str_arg_types(L, type_buf, 512);'))
+ table.insert(txt, string.format('luaL_error(L, "invalid arguments: %%s\\nexpected arguments: %s", type_buf);',
+ table.concat(allconcathelpargs, ' | ')))
+ table.insert(txt, '}')
for k=1,#varargs/2 do
if k == 1 then
@@ -98,6 +110,36 @@ function CInterface:wrap(luaname, ...)
table.insert(txt, '')
end
+function CInterface:__addchelpers()
+ if not self.__chelpers_added then
+ local txt = self.txt
+ table.insert(txt, '#ifndef _CWRAP_STR_ARG_TYPES_4821726c1947cdf3eebacade98173939')
+ table.insert(txt, '#define _CWRAP_STR_ARG_TYPES_4821726c1947cdf3eebacade98173939')
+ table.insert(txt, '#include "string.h"')
+ table.insert(txt, 'static void str_arg_types(lua_State *L, char *buf, int n) {')
+ table.insert(txt, ' for (int i = 1; i <= lua_gettop(L); i++) {')
+ table.insert(txt, ' int l;')
+ table.insert(txt, ' const char *torch_type = luaT_typename(L, i);')
+ table.insert(txt, ' if(torch_type && !strncmp(torch_type, "torch.", 6)) torch_type += 6;')
+ table.insert(txt, ' if (torch_type) l = snprintf(buf, n, "%s ", torch_type);')
+ table.insert(txt, ' else if(lua_isnil(L, i)) l = snprintf(buf, n, "%s ", "nil");')
+ table.insert(txt, ' else if(lua_isboolean(L, i)) l = snprintf(buf, n, "%s ", "boolean");')
+ table.insert(txt, ' else if(lua_isnumber(L, i)) l = snprintf(buf, n, "%s ", "number");')
+ table.insert(txt, ' else if(lua_isstring(L, i)) l = snprintf(buf, n, "%s ", "string");')
+ table.insert(txt, ' else if(lua_istable(L, i)) l = snprintf(buf, n, "%s ", "table");')
+ table.insert(txt, ' else if(lua_isuserdata(L, i)) l = snprintf(buf, n, "%s ", "userdata");')
+ table.insert(txt, ' else l = snprintf(buf, n, "%s ", "???");')
+ table.insert(txt, ' if (l >= n) return;')
+ table.insert(txt, ' buf += l;')
+ table.insert(txt, ' n -= l;')
+ table.insert(txt, ' }')
+ table.insert(txt, '}')
+ table.insert(txt, '#endif')
+
+ self.__chelpers_added = true
+ end
+end
+
function CInterface:register(name)
local txt = self.txt
table.insert(txt, string.format('static const struct luaL_Reg %s [] = {', name))
@@ -127,7 +169,7 @@ function CInterface:tofile(filename)
end
local function bit(p)
- return 2 ^ (p - 1) -- 1-based indexing
+ return 2 ^ (p - 1) -- 1-based indexing
end
local function hasbit(x, p)
@@ -318,5 +360,3 @@ function CInterface:__writecall(txt, args, cfuncname, cargs, argcreturned)
end
return CInterface
-
-