diff options
author | Ronan Collobert <ronan@collobert.com> | 2012-08-13 16:25:58 +0400 |
---|---|---|
committer | Ronan Collobert <ronan@collobert.com> | 2012-08-13 16:25:58 +0400 |
commit | 05a6d882c434de34963dad2bfa56797cf82b97a7 (patch) | |
tree | 860505a4673d5352164148a24cf52eb7ab26bc8a /utils.c | |
parent | 59f75333637e08e62236aa2c5a038584274e220d (diff) |
torch now complies with the new luaT API
Diffstat (limited to 'utils.c')
-rw-r--r-- | utils.c | 88 |
1 files changed, 55 insertions, 33 deletions
@@ -7,18 +7,15 @@ #include <omp.h> #endif -static const void* torch_LongStorage_id = NULL; -static const void* torch_default_tensor_id = NULL; - THLongStorage* torch_checklongargs(lua_State *L, int index) { THLongStorage *storage; int i; int narg = lua_gettop(L)-index+1; - if(narg == 1 && luaT_toudata(L, index, torch_LongStorage_id)) + if(narg == 1 && luaT_toudata(L, index, "torch.LongStorage")) { - THLongStorage *storagesrc = luaT_toudata(L, index, torch_LongStorage_id); + THLongStorage *storagesrc = luaT_toudata(L, index, "torch.LongStorage"); storage = THLongStorage_newWithSize(storagesrc->size); THLongStorage_copy(storage, storagesrc); } @@ -42,7 +39,7 @@ int torch_islongargs(lua_State *L, int index) { int narg = lua_gettop(L)-index+1; - if(narg == 1 && luaT_toudata(L, index, torch_LongStorage_id)) + if(narg == 1 && luaT_toudata(L, index, "torch.LongStorage")) { return 1; } @@ -81,34 +78,63 @@ static int torch_lua_toc(lua_State* L) return 1; } -static int torch_lua_setdefaulttensortype(lua_State *L) -{ - const void *id; - - luaL_checkstring(L, 1); - - if(!(id = luaT_typename2id(L, lua_tostring(L, 1)))) \ - return luaL_error(L, "<%s> is not a string describing a torch object", lua_tostring(L, 1)); \ - - torch_default_tensor_id = id; - - return 0; -} - static int torch_lua_getdefaulttensortype(lua_State *L) { - lua_pushstring(L, luaT_id2typename(L, torch_default_tensor_id)); - return 1; -} - -void torch_setdefaulttensorid(const void* id) -{ - torch_default_tensor_id = id; + const char* tname = torch_getdefaulttensortype(L); + if(tname) + { + lua_pushstring(L, tname); + return 1; + } + return 0; } -const void* torch_getdefaulttensorid() +const char* torch_getdefaulttensortype(lua_State *L) { - return torch_default_tensor_id; + lua_getfield(L, LUA_GLOBALSINDEX, "torch"); + if(lua_istable(L, -1)) + { + lua_getfield(L, -1, "Tensor"); + if(lua_istable(L, -1)) + { + if(lua_getmetatable(L, -1)) + { + lua_pushstring(L, "__index"); + lua_rawget(L, -2); + if(lua_istable(L, -1)) + { + lua_rawget(L, LUA_REGISTRYINDEX); + if(lua_isstring(L, -1)) + { + const char *tname = lua_tostring(L, -1); + lua_pop(L, 4); + return tname; + } + } + else + { + lua_pop(L, 4); + return NULL; + } + } + else + { + lua_pop(L, 2); + return NULL; + } + } + else + { + lua_pop(L, 2); + return NULL; + } + } + else + { + lua_pop(L, 1); + return NULL; + } + return NULL; } static int torch_getnumthreads(lua_State *L) @@ -131,7 +157,6 @@ static int torch_setnumthreads(lua_State *L) } static const struct luaL_Reg torch_utils__ [] = { - {"__setdefaulttensortype", torch_lua_setdefaulttensortype}, {"getdefaulttensortype", torch_lua_getdefaulttensortype}, {"tic", torch_lua_tic}, {"toc", torch_lua_toc}, @@ -139,9 +164,7 @@ static const struct luaL_Reg torch_utils__ [] = { {"getnumthreads", torch_getnumthreads}, {"factory", luaT_lua_factory}, {"getconstructortable", luaT_lua_getconstructortable}, - {"id", luaT_lua_id}, {"typename", luaT_lua_typename}, - {"typename2id", luaT_lua_typename2id}, {"isequal", luaT_lua_isequal}, {"getenv", luaT_lua_getenv}, {"setenv", luaT_lua_setenv}, @@ -155,6 +178,5 @@ static const struct luaL_Reg torch_utils__ [] = { void torch_utils_init(lua_State *L) { - torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); luaL_register(L, NULL, torch_utils__); } |