Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2012-08-13 16:25:58 +0400
committerRonan Collobert <ronan@collobert.com>2012-08-13 16:25:58 +0400
commit05a6d882c434de34963dad2bfa56797cf82b97a7 (patch)
tree860505a4673d5352164148a24cf52eb7ab26bc8a /utils.c
parent59f75333637e08e62236aa2c5a038584274e220d (diff)
torch now complies with the new luaT API
Diffstat (limited to 'utils.c')
-rw-r--r--utils.c88
1 files changed, 55 insertions, 33 deletions
diff --git a/utils.c b/utils.c
index 1d0fd0e..c533505 100644
--- a/utils.c
+++ b/utils.c
@@ -7,18 +7,15 @@
#include <omp.h>
#endif
-static const void* torch_LongStorage_id = NULL;
-static const void* torch_default_tensor_id = NULL;
-
THLongStorage* torch_checklongargs(lua_State *L, int index)
{
THLongStorage *storage;
int i;
int narg = lua_gettop(L)-index+1;
- if(narg == 1 && luaT_toudata(L, index, torch_LongStorage_id))
+ if(narg == 1 && luaT_toudata(L, index, "torch.LongStorage"))
{
- THLongStorage *storagesrc = luaT_toudata(L, index, torch_LongStorage_id);
+ THLongStorage *storagesrc = luaT_toudata(L, index, "torch.LongStorage");
storage = THLongStorage_newWithSize(storagesrc->size);
THLongStorage_copy(storage, storagesrc);
}
@@ -42,7 +39,7 @@ int torch_islongargs(lua_State *L, int index)
{
int narg = lua_gettop(L)-index+1;
- if(narg == 1 && luaT_toudata(L, index, torch_LongStorage_id))
+ if(narg == 1 && luaT_toudata(L, index, "torch.LongStorage"))
{
return 1;
}
@@ -81,34 +78,63 @@ static int torch_lua_toc(lua_State* L)
return 1;
}
-static int torch_lua_setdefaulttensortype(lua_State *L)
-{
- const void *id;
-
- luaL_checkstring(L, 1);
-
- if(!(id = luaT_typename2id(L, lua_tostring(L, 1)))) \
- return luaL_error(L, "<%s> is not a string describing a torch object", lua_tostring(L, 1)); \
-
- torch_default_tensor_id = id;
-
- return 0;
-}
-
static int torch_lua_getdefaulttensortype(lua_State *L)
{
- lua_pushstring(L, luaT_id2typename(L, torch_default_tensor_id));
- return 1;
-}
-
-void torch_setdefaulttensorid(const void* id)
-{
- torch_default_tensor_id = id;
+ const char* tname = torch_getdefaulttensortype(L);
+ if(tname)
+ {
+ lua_pushstring(L, tname);
+ return 1;
+ }
+ return 0;
}
-const void* torch_getdefaulttensorid()
+const char* torch_getdefaulttensortype(lua_State *L)
{
- return torch_default_tensor_id;
+ lua_getfield(L, LUA_GLOBALSINDEX, "torch");
+ if(lua_istable(L, -1))
+ {
+ lua_getfield(L, -1, "Tensor");
+ if(lua_istable(L, -1))
+ {
+ if(lua_getmetatable(L, -1))
+ {
+ lua_pushstring(L, "__index");
+ lua_rawget(L, -2);
+ if(lua_istable(L, -1))
+ {
+ lua_rawget(L, LUA_REGISTRYINDEX);
+ if(lua_isstring(L, -1))
+ {
+ const char *tname = lua_tostring(L, -1);
+ lua_pop(L, 4);
+ return tname;
+ }
+ }
+ else
+ {
+ lua_pop(L, 4);
+ return NULL;
+ }
+ }
+ else
+ {
+ lua_pop(L, 2);
+ return NULL;
+ }
+ }
+ else
+ {
+ lua_pop(L, 2);
+ return NULL;
+ }
+ }
+ else
+ {
+ lua_pop(L, 1);
+ return NULL;
+ }
+ return NULL;
}
static int torch_getnumthreads(lua_State *L)
@@ -131,7 +157,6 @@ static int torch_setnumthreads(lua_State *L)
}
static const struct luaL_Reg torch_utils__ [] = {
- {"__setdefaulttensortype", torch_lua_setdefaulttensortype},
{"getdefaulttensortype", torch_lua_getdefaulttensortype},
{"tic", torch_lua_tic},
{"toc", torch_lua_toc},
@@ -139,9 +164,7 @@ static const struct luaL_Reg torch_utils__ [] = {
{"getnumthreads", torch_getnumthreads},
{"factory", luaT_lua_factory},
{"getconstructortable", luaT_lua_getconstructortable},
- {"id", luaT_lua_id},
{"typename", luaT_lua_typename},
- {"typename2id", luaT_lua_typename2id},
{"isequal", luaT_lua_isequal},
{"getenv", luaT_lua_getenv},
{"setenv", luaT_lua_setenv},
@@ -155,6 +178,5 @@ static const struct luaL_Reg torch_utils__ [] = {
void torch_utils_init(lua_State *L)
{
- torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage");
luaL_register(L, NULL, torch_utils__);
}