diff options
author | Ronan Collobert <ronan@collobert.com> | 2012-08-13 16:25:58 +0400 |
---|---|---|
committer | Ronan Collobert <ronan@collobert.com> | 2012-08-13 16:25:58 +0400 |
commit | 05a6d882c434de34963dad2bfa56797cf82b97a7 (patch) | |
tree | 860505a4673d5352164148a24cf52eb7ab26bc8a /generic | |
parent | 59f75333637e08e62236aa2c5a038584274e220d (diff) |
torch now complies with the new luaT API
Diffstat (limited to 'generic')
-rw-r--r-- | generic/Storage.c | 52 | ||||
-rw-r--r-- | generic/Tensor.c | 180 | ||||
-rw-r--r-- | generic/TensorOperator.c | 32 |
3 files changed, 127 insertions, 137 deletions
diff --git a/generic/Storage.c b/generic/Storage.c index 8612d3b..b90d054 100644 --- a/generic/Storage.c +++ b/generic/Storage.c @@ -32,20 +32,20 @@ static int torch_Storage_(new)(lua_State *L) long size = luaL_optlong(L, 1, 0); storage = THStorage_(newWithSize)(size); } - luaT_pushudata(L, storage, torch_Storage_id); + luaT_pushudata(L, storage, torch_Storage); return 1; } static int torch_Storage_(free)(lua_State *L) { - THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + THStorage *storage = luaT_checkudata(L, 1, torch_Storage); THStorage_(free)(storage); return 0; } static int torch_Storage_(resize)(lua_State *L) { - THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + THStorage *storage = luaT_checkudata(L, 1, torch_Storage); long size = luaL_checklong(L, 2); /* int keepContent = luaT_optboolean(L, 3, 0); */ THStorage_(resize)(storage, size);/*, keepContent); */ @@ -55,23 +55,23 @@ static int torch_Storage_(resize)(lua_State *L) static int torch_Storage_(copy)(lua_State *L) { - THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + THStorage *storage = luaT_checkudata(L, 1, torch_Storage); void *src; - if( (src = luaT_toudata(L, 2, torch_Storage_id)) ) + if( (src = luaT_toudata(L, 2, torch_Storage)) ) THStorage_(copy)(storage, src); - else if( (src = luaT_toudata(L, 2, torch_ByteStorage_id)) ) + else if( (src = luaT_toudata(L, 2, "torch.ByteStorage")) ) THStorage_(copyByte)(storage, src); - else if( (src = luaT_toudata(L, 2, torch_CharStorage_id)) ) + else if( (src = luaT_toudata(L, 2, "torch.CharStorage")) ) THStorage_(copyChar)(storage, src); - else if( (src = luaT_toudata(L, 2, torch_ShortStorage_id)) ) + else if( (src = luaT_toudata(L, 2, "torch.ShortStorage")) ) THStorage_(copyShort)(storage, src); - else if( (src = luaT_toudata(L, 2, torch_IntStorage_id)) ) + else if( (src = luaT_toudata(L, 2, "torch.IntStorage")) ) THStorage_(copyInt)(storage, src); - else if( (src = luaT_toudata(L, 2, torch_LongStorage_id)) ) + else if( (src = luaT_toudata(L, 2, "torch.LongStorage")) ) THStorage_(copyLong)(storage, src); - else if( (src = luaT_toudata(L, 2, torch_FloatStorage_id)) ) + else if( (src = luaT_toudata(L, 2, "torch.FloatStorage")) ) THStorage_(copyFloat)(storage, src); - else if( (src = luaT_toudata(L, 2, torch_DoubleStorage_id)) ) + else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) ) THStorage_(copyDouble)(storage, src); else luaL_typerror(L, 2, "torch.*Storage"); @@ -81,7 +81,7 @@ static int torch_Storage_(copy)(lua_State *L) static int torch_Storage_(fill)(lua_State *L) { - THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + THStorage *storage = luaT_checkudata(L, 1, torch_Storage); double value = luaL_checknumber(L, 2); THStorage_(fill)(storage, (real)value); lua_settop(L, 1); @@ -90,7 +90,7 @@ static int torch_Storage_(fill)(lua_State *L) static int torch_Storage_(__len__)(lua_State *L) { - THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + THStorage *storage = luaT_checkudata(L, 1, torch_Storage); lua_pushnumber(L, storage->size); return 1; } @@ -99,7 +99,7 @@ static int torch_Storage_(__newindex__)(lua_State *L) { if(lua_isnumber(L, 2)) { - THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + THStorage *storage = luaT_checkudata(L, 1, torch_Storage); long index = luaL_checklong(L, 2) - 1; double number = luaL_checknumber(L, 3); THStorage_(set)(storage, index, (real)number); @@ -115,7 +115,7 @@ static int torch_Storage_(__index__)(lua_State *L) { if(lua_isnumber(L, 2)) { - THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + THStorage *storage = luaT_checkudata(L, 1, torch_Storage); long index = luaL_checklong(L, 2) - 1; lua_pushnumber(L, THStorage_(get)(storage, index)); lua_pushboolean(L, 1); @@ -131,7 +131,7 @@ static int torch_Storage_(__index__)(lua_State *L) #if defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_BYTE) static int torch_Storage_(string)(lua_State *L) { - THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + THStorage *storage = luaT_checkudata(L, 1, torch_Storage); if(lua_isstring(L, -1)) { size_t len = 0; @@ -149,7 +149,7 @@ static int torch_Storage_(string)(lua_State *L) static int torch_Storage_(totable)(lua_State *L) { - THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); + THStorage *storage = luaT_checkudata(L, 1, torch_Storage); long i; lua_newtable(L); @@ -164,14 +164,14 @@ static int torch_Storage_(totable)(lua_State *L) static int torch_Storage_(factory)(lua_State *L) { THStorage *storage = THStorage_(new)(); - luaT_pushudata(L, storage, torch_Storage_id); + luaT_pushudata(L, storage, torch_Storage); return 1; } static int torch_Storage_(write)(lua_State *L) { - THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); - THFile *file = luaT_checkudata(L, 2, torch_File_id); + THStorage *storage = luaT_checkudata(L, 1, torch_Storage); + THFile *file = luaT_checkudata(L, 2, "torch.File"); THFile_writeLongScalar(file, storage->size); THFile_writeRealRaw(file, storage->data, storage->size); @@ -181,8 +181,8 @@ static int torch_Storage_(write)(lua_State *L) static int torch_Storage_(read)(lua_State *L) { - THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id); - THFile *file = luaT_checkudata(L, 2, torch_File_id); + THStorage *storage = luaT_checkudata(L, 1, torch_Storage); + THFile *file = luaT_checkudata(L, 2, "torch.File"); long size = THFile_readLongScalar(file); THStorage_(resize)(storage, size); @@ -210,10 +210,8 @@ static const struct luaL_Reg torch_Storage_(_) [] = { void torch_Storage_(init)(lua_State *L) { - torch_File_id = luaT_checktypename2id(L, "torch.File"); - - torch_Storage_id = luaT_newmetatable(L, STRING_torchStorage, NULL, - torch_Storage_(new), torch_Storage_(free), torch_Storage_(factory)); + luaT_newmetatable(L, torch_Storage, NULL, + torch_Storage_(new), torch_Storage_(free), torch_Storage_(factory)); luaL_register(L, NULL, torch_Storage_(_)); lua_pop(L, 1); } diff --git a/generic/Tensor.c b/generic/Tensor.c index 06de788..ced0a68 100644 --- a/generic/Tensor.c +++ b/generic/Tensor.c @@ -9,7 +9,7 @@ static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowSt static int torch_Tensor_(size)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); if(lua_isnumber(L,2)) { int dim = luaL_checkint(L, 2)-1; @@ -20,14 +20,14 @@ static int torch_Tensor_(size)(lua_State *L) { THLongStorage *storage = THLongStorage_newWithSize(tensor->nDimension); memmove(storage->data, tensor->size, sizeof(long)*tensor->nDimension); - luaT_pushudata(L, storage, torch_LongStorage_id); + luaT_pushudata(L, storage, "torch.LongStorage"); } return 1; } static int torch_Tensor_(stride)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); if(lua_isnumber(L,2)) { int dim = luaL_checkint(L, 2)-1; @@ -38,25 +38,25 @@ static int torch_Tensor_(stride)(lua_State *L) { THLongStorage *storage = THLongStorage_newWithSize(tensor->nDimension); memmove(storage->data, tensor->stride, sizeof(long)*tensor->nDimension); - luaT_pushudata(L, storage, torch_LongStorage_id); + luaT_pushudata(L, storage, "torch.LongStorage"); } return 1; } static int torch_Tensor_(nDimension)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); lua_pushnumber(L, tensor->nDimension); return 1; } static int torch_Tensor_(storage)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); if(tensor->storage) { THStorage_(retain)(tensor->storage); - luaT_pushudata(L, tensor->storage, torch_Storage_id); + luaT_pushudata(L, tensor->storage, torch_Storage); } else lua_pushnil(L); @@ -66,7 +66,7 @@ static int torch_Tensor_(storage)(lua_State *L) static int torch_Tensor_(storageOffset)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); lua_pushnumber(L, tensor->storageOffset+1); return 1; } @@ -197,13 +197,13 @@ static int torch_Tensor_(new)(lua_State *L) THLongStorage_free(stride); } - luaT_pushudata(L, tensor, torch_Tensor_id); + luaT_pushudata(L, tensor, torch_Tensor); return 1; } static int torch_Tensor_(set)(lua_State *L) { - THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *self = luaT_checkudata(L, 1, torch_Tensor); THStorage *storage; long storageOffset; THLongStorage *size, *stride; @@ -222,25 +222,25 @@ static int torch_Tensor_(set)(lua_State *L) static int torch_Tensor_(clone)(lua_State *L) { - THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *self = luaT_checkudata(L, 1, torch_Tensor); self = THTensor_(newClone)(self); - luaT_pushudata(L, self, torch_Tensor_id); + luaT_pushudata(L, self, torch_Tensor); return 1; } static int torch_Tensor_(contiguous)(lua_State *L) { - THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *self = luaT_checkudata(L, 1, torch_Tensor); self = THTensor_(newContiguous)(self); - luaT_pushudata(L, self, torch_Tensor_id); + luaT_pushudata(L, self, torch_Tensor); return 1; } /* Resize */ static int torch_Tensor_(resizeAs)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); - THTensor *src = luaT_checkudata(L, 2, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); + THTensor *src = luaT_checkudata(L, 2, torch_Tensor); THTensor_(resizeAs)(tensor, src); lua_settop(L, 1); return 1; @@ -248,7 +248,7 @@ static int torch_Tensor_(resizeAs)(lua_State *L) static int torch_Tensor_(resize)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); THLongStorage *size, *stride; torch_Tensor_(c_readSizeStride)(L, 2, 0, &size, &stride); @@ -264,7 +264,7 @@ static int torch_Tensor_(resize)(lua_State *L) static int torch_Tensor_(narrow)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); int dimension = luaL_checkint(L, 2)-1; long firstIndex = luaL_checklong(L, 3)-1; long size = luaL_checklong(L, 4); @@ -275,13 +275,13 @@ static int torch_Tensor_(narrow)(lua_State *L) */ tensor = THTensor_(newWithTensor)(tensor); THTensor_(narrow)(tensor, NULL, dimension, firstIndex, size); - luaT_pushudata(L, tensor, torch_Tensor_id); + luaT_pushudata(L, tensor, torch_Tensor); return 1; } static int torch_Tensor_(sub)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); long d0s = -1, d0e = -1, d1s = -1, d1e = -1, d2s = -1, d2e = -1, d3s = -1, d3e = -1; d0s = luaL_checklong(L, 2)-1; @@ -345,13 +345,13 @@ static int torch_Tensor_(sub)(lua_State *L) THTensor_(narrow)(tensor, NULL, 2, d2s, d2e-d2s+1); if(d3s >= 0) THTensor_(narrow)(tensor, NULL, 3, d3s, d3e-d3s+1); - luaT_pushudata(L, tensor, torch_Tensor_id); + luaT_pushudata(L, tensor, torch_Tensor); return 1; } static int torch_Tensor_(select)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); int dimension = luaL_checkint(L, 2)-1; long sliceIndex = luaL_checklong(L, 3)-1; @@ -364,7 +364,7 @@ static int torch_Tensor_(select)(lua_State *L) { tensor = THTensor_(newWithTensor)(tensor); THTensor_(select)(tensor, NULL, dimension, sliceIndex); - luaT_pushudata(L, tensor, torch_Tensor_id); + luaT_pushudata(L, tensor, torch_Tensor); } else { @@ -378,7 +378,7 @@ static int torch_Tensor_(select)(lua_State *L) static int torch_Tensor_(transpose)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); int dimension1 = luaL_checkint(L, 2)-1; int dimension2 = luaL_checkint(L, 3)-1; @@ -389,25 +389,25 @@ static int torch_Tensor_(transpose)(lua_State *L) tensor = THTensor_(newWithTensor)(tensor); THTensor_(transpose)(tensor, NULL, dimension1, dimension2); - luaT_pushudata(L, tensor, torch_Tensor_id); + luaT_pushudata(L, tensor, torch_Tensor); return 1; } static int torch_Tensor_(t)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); luaL_argcheck(L, tensor->nDimension == 2, 1, "Tensor must have 2 dimensions"); tensor = THTensor_(newWithTensor)(tensor); THTensor_(transpose)(tensor, NULL, 0, 1); - luaT_pushudata(L, tensor, torch_Tensor_id); + luaT_pushudata(L, tensor, torch_Tensor); return 1; } static int torch_Tensor_(unfold)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); int dimension = luaL_checkint(L, 2)-1; long size = luaL_checklong(L, 3); long step = luaL_checklong(L, 4); @@ -420,44 +420,44 @@ static int torch_Tensor_(unfold)(lua_State *L) tensor = THTensor_(newWithTensor)(tensor); THTensor_(unfold)(tensor, NULL, dimension, size, step); - luaT_pushudata(L, tensor, torch_Tensor_id); + luaT_pushudata(L, tensor, torch_Tensor); return 1; } /* is contiguous? [a bit like in TnXIterator] */ static int torch_Tensor_(isContiguous)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); lua_pushboolean(L, THTensor_(isContiguous)(tensor)); return 1; } static int torch_Tensor_(nElement)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); lua_pushnumber(L, THTensor_(nElement)(tensor)); return 1; } static int torch_Tensor_(copy)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); void *src; - if( (src = luaT_toudata(L, 2, torch_Tensor_id)) ) + if( (src = luaT_toudata(L, 2, torch_Tensor)) ) THTensor_(copy)(tensor, src); - else if( (src = luaT_toudata(L, 2, torch_ByteTensor_id)) ) + else if( (src = luaT_toudata(L, 2, "torch.ByteTensor")) ) THTensor_(copyByte)(tensor, src); - else if( (src = luaT_toudata(L, 2, torch_CharTensor_id)) ) + else if( (src = luaT_toudata(L, 2, "torch.CharTensor")) ) THTensor_(copyChar)(tensor, src); - else if( (src = luaT_toudata(L, 2, torch_ShortTensor_id)) ) + else if( (src = luaT_toudata(L, 2, "torch.ShortTensor")) ) THTensor_(copyShort)(tensor, src); - else if( (src = luaT_toudata(L, 2, torch_IntTensor_id)) ) + else if( (src = luaT_toudata(L, 2, "torch.IntTensor")) ) THTensor_(copyInt)(tensor, src); - else if( (src = luaT_toudata(L, 2, torch_LongTensor_id)) ) + else if( (src = luaT_toudata(L, 2, "torch.LongTensor")) ) THTensor_(copyLong)(tensor, src); - else if( (src = luaT_toudata(L, 2, torch_FloatTensor_id)) ) + else if( (src = luaT_toudata(L, 2, "torch.FloatTensor")) ) THTensor_(copyFloat)(tensor, src); - else if( (src = luaT_toudata(L, 2, torch_DoubleTensor_id)) ) + else if( (src = luaT_toudata(L, 2, "torch.DoubleTensor")) ) THTensor_(copyDouble)(tensor, src); else luaL_typerror(L, 2, "torch.*Tensor"); @@ -467,7 +467,7 @@ static int torch_Tensor_(copy)(lua_State *L) static int torch_Tensor_(__newindex__)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); THLongStorage *idx = NULL; THByteTensor *mask; @@ -487,42 +487,42 @@ static int torch_Tensor_(__newindex__)(lua_State *L) THTensor_(fill)(tensor, value); THTensor_(free)(tensor); } - } else if( (src = luaT_toudata(L, 3, torch_Tensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, torch_Tensor)) ) { tensor = THTensor_(newWithTensor)(tensor); THTensor_(narrow)(tensor, NULL, 0, index, 1); THTensor_(copy)(tensor, src); THTensor_(free)(tensor); - } else if( (src = luaT_toudata(L, 3, torch_ByteTensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, "torch.ByteTensor")) ) { tensor = THTensor_(newWithTensor)(tensor); THTensor_(narrow)(tensor, NULL, 0, index, 1); THTensor_(copyByte)(tensor, src); THTensor_(free)(tensor); - } else if( (src = luaT_toudata(L, 3, torch_CharTensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, "torch.CharTensor")) ) { tensor = THTensor_(newWithTensor)(tensor); THTensor_(narrow)(tensor, NULL, 0, index, 1); THTensor_(copyChar)(tensor, src); THTensor_(free)(tensor); - } else if( (src = luaT_toudata(L, 3, torch_ShortTensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, "torch.ShortTensor")) ) { tensor = THTensor_(newWithTensor)(tensor); THTensor_(narrow)(tensor, NULL, 0, index, 1); THTensor_(copyShort)(tensor, src); THTensor_(free)(tensor); - } else if( (src = luaT_toudata(L, 3, torch_IntTensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, "torch.IntTensor")) ) { tensor = THTensor_(newWithTensor)(tensor); THTensor_(narrow)(tensor, NULL, 0, index, 1); THTensor_(copyInt)(tensor, src); THTensor_(free)(tensor); - } else if( (src = luaT_toudata(L, 3, torch_LongTensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, "torch.LongTensor")) ) { tensor = THTensor_(newWithTensor)(tensor); THTensor_(narrow)(tensor, NULL, 0, index, 1); THTensor_(copyLong)(tensor, src); THTensor_(free)(tensor); - } else if( (src = luaT_toudata(L, 3, torch_FloatTensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, "torch.FloatTensor")) ) { tensor = THTensor_(newWithTensor)(tensor); THTensor_(narrow)(tensor, NULL, 0, index, 1); THTensor_(copyFloat)(tensor, src); THTensor_(free)(tensor); - } else if( (src = luaT_toudata(L, 3, torch_DoubleTensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, "torch.DoubleTensor")) ) { tensor = THTensor_(newWithTensor)(tensor); THTensor_(narrow)(tensor, NULL, 0, index, 1); THTensor_(copyDouble)(tensor, src); @@ -532,7 +532,7 @@ static int torch_Tensor_(__newindex__)(lua_State *L) } lua_pushboolean(L, 1); } - else if((idx = luaT_toudata(L, 2, torch_LongStorage_id))) + else if((idx = luaT_toudata(L, 2, "torch.LongStorage"))) { long index = THTensor_(storageOffset)(tensor); real value = (real)luaL_checknumber(L,3); @@ -611,37 +611,37 @@ static int torch_Tensor_(__newindex__)(lua_State *L) void *src; if (lua_isnumber(L,3)) { THTensor_(fill)(tensor, lua_tonumber(L,3)); - } else if( (src = luaT_toudata(L, 3, torch_Tensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, torch_Tensor)) ) { THTensor_(copy)(tensor, src); - } else if( (src = luaT_toudata(L, 3, torch_ByteTensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, "torch.ByteTensor")) ) { THTensor_(copyByte)(tensor, src); - } else if( (src = luaT_toudata(L, 3, torch_CharTensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, "torch.CharTensor")) ) { THTensor_(copyChar)(tensor, src); - } else if( (src = luaT_toudata(L, 3, torch_ShortTensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, "torch.ShortTensor")) ) { THTensor_(copyShort)(tensor, src); - } else if( (src = luaT_toudata(L, 3, torch_IntTensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, "torch.IntTensor")) ) { THTensor_(copyInt)(tensor, src); - } else if( (src = luaT_toudata(L, 3, torch_LongTensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, "torch.LongTensor")) ) { THTensor_(copyLong)(tensor, src); - } else if( (src = luaT_toudata(L, 3, torch_FloatTensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, "torch.FloatTensor")) ) { THTensor_(copyFloat)(tensor, src); - } else if( (src = luaT_toudata(L, 3, torch_DoubleTensor_id)) ) { + } else if( (src = luaT_toudata(L, 3, "torch.DoubleTensor")) ) { THTensor_(copyDouble)(tensor, src); } else { luaL_typerror(L, 3, "torch.*Tensor"); } - } + } THTensor_(free)(tensor); lua_pushboolean(L, 1); } - else if((mask = luaT_toudata(L, 2, torch_ByteTensor_id))) + else if((mask = luaT_toudata(L, 2, "torch.ByteTensor"))) { THTensor *vals; if (lua_isnumber(L, 3)) { THTensor_(maskedFill)(tensor, mask, (real)(luaL_checknumber(L,3))); } - else if((vals = luaT_toudata(L, 3, torch_Tensor_id))) + else if((vals = luaT_toudata(L, 3, torch_Tensor))) { THTensor_(maskedCopy)(tensor, mask, vals); } @@ -658,7 +658,7 @@ static int torch_Tensor_(__newindex__)(lua_State *L) static int torch_Tensor_(__index__)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); THLongStorage *idx = NULL; THByteTensor *mask; @@ -678,12 +678,12 @@ static int torch_Tensor_(__index__)(lua_State *L) { tensor = THTensor_(newWithTensor)(tensor); THTensor_(select)(tensor, NULL, 0, index); - luaT_pushudata(L, tensor, torch_Tensor_id); + luaT_pushudata(L, tensor, torch_Tensor); } lua_pushboolean(L, 1); return 2; } - else if((idx = luaT_toudata(L, 2, torch_LongStorage_id))) + else if((idx = luaT_toudata(L, 2, "torch.LongStorage"))) { long index = THTensor_(storageOffset)(tensor); int dim; @@ -756,18 +756,18 @@ static int torch_Tensor_(__index__)(lua_State *L) } } if(!done) { - luaT_pushudata(L, tensor, torch_Tensor_id); + luaT_pushudata(L, tensor, torch_Tensor); } else { THTensor_(free)(tensor); } lua_pushboolean(L, 1); return 2; } - else if((mask = luaT_toudata(L, 2, torch_ByteTensor_id))) + else if((mask = luaT_toudata(L, 2, "torch.ByteTensor"))) { THTensor *vals = THTensor_(new)(); THTensor_(maskedSelect)(vals, tensor, mask); - luaT_pushudata(L, vals, torch_Tensor_id); + luaT_pushudata(L, vals, torch_Tensor); lua_pushboolean(L, 1); return 2; } @@ -780,7 +780,7 @@ static int torch_Tensor_(__index__)(lua_State *L) static int torch_Tensor_(free)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); THTensor_(free)(tensor); return 0; } @@ -791,11 +791,11 @@ static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowSt THLongStorage *size = NULL; THLongStorage *stride = NULL; - if( (size = luaT_toudata(L, index, torch_LongStorage_id)) ) + if( (size = luaT_toudata(L, index, "torch.LongStorage")) ) { if(!lua_isnoneornil(L, index+1)) { - if( (stride = luaT_toudata(L, index+1, torch_LongStorage_id)) ) + if( (stride = luaT_toudata(L, index+1, "torch.LongStorage")) ) luaL_argcheck(L, stride->size == size->size, index+1, "provided stride and size are inconsistent"); else luaL_argcheck(L, 0, index+1, "torch.LongStorage expected"); @@ -858,7 +858,7 @@ static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index *stride_ = NULL; return; } - else if( allowTensor && (arg1Type == LUA_TUSERDATA) && (src = luaT_toudata(L, index, torch_Tensor_id)) ) + else if( allowTensor && (arg1Type == LUA_TUSERDATA) && (src = luaT_toudata(L, index, torch_Tensor)) ) { *storage_ = src->storage; *storageOffset_ = src->storageOffset; @@ -866,7 +866,7 @@ static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index *stride_ = THTensor_(newStrideOf)(src); return; } - else if( allowStorage && (arg1Type == LUA_TUSERDATA) && (storage = luaT_toudata(L, index, torch_Storage_id)) ) + else if( allowStorage && (arg1Type == LUA_TUSERDATA) && (storage = luaT_toudata(L, index, torch_Storage)) ) { *storage_ = storage; if(lua_isnone(L, index+1)) @@ -882,7 +882,7 @@ static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index } return; } - else if( (arg1Type == LUA_TNUMBER) || (luaT_toudata(L, index, torch_LongStorage_id)) ) + else if( (arg1Type == LUA_TNUMBER) || (luaT_toudata(L, index, "torch.LongStorage")) ) { *storage_ = NULL; *storageOffset_ = 0; @@ -900,7 +900,7 @@ static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index static int torch_Tensor_(apply)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); luaL_checktype(L, 2, LUA_TFUNCTION); lua_settop(L, 2); @@ -924,8 +924,8 @@ static int torch_Tensor_(apply)(lua_State *L) static int torch_Tensor_(map)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); - THTensor *src = luaT_checkudata(L, 2, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); + THTensor *src = luaT_checkudata(L, 2, torch_Tensor); luaL_checktype(L, 3, LUA_TFUNCTION); lua_settop(L, 3); @@ -950,9 +950,9 @@ static int torch_Tensor_(map)(lua_State *L) static int torch_Tensor_(map2)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); - THTensor *src1 = luaT_checkudata(L, 2, torch_Tensor_id); - THTensor *src2 = luaT_checkudata(L, 3, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); + THTensor *src1 = luaT_checkudata(L, 2, torch_Tensor); + THTensor *src2 = luaT_checkudata(L, 3, torch_Tensor); luaL_checktype(L, 4, LUA_TFUNCTION); lua_settop(L, 4); @@ -979,14 +979,14 @@ static int torch_Tensor_(map2)(lua_State *L) static int torch_Tensor_(factory)(lua_State *L) { THTensor *tensor = THTensor_(new)(); - luaT_pushudata(L, tensor, torch_Tensor_id); + luaT_pushudata(L, tensor, torch_Tensor); return 1; } static int torch_Tensor_(write)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); - THFile *file = luaT_checkudata(L, 2, torch_File_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); + THFile *file = luaT_checkudata(L, 2, "torch.File"); THFile_writeIntScalar(file, tensor->nDimension); THFile_writeLongRaw(file, tensor->size, tensor->nDimension); @@ -999,7 +999,7 @@ static int torch_Tensor_(write)(lua_State *L) if(tensor->storage) { THStorage_(retain)(tensor->storage); - luaT_pushudata(L, tensor->storage, torch_Storage_id); + luaT_pushudata(L, tensor->storage, torch_Storage); } else lua_pushnil(L); @@ -1011,8 +1011,8 @@ static int torch_Tensor_(write)(lua_State *L) static int torch_Tensor_(read)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); - THFile *file = luaT_checkudata(L, 2, torch_File_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); + THFile *file = luaT_checkudata(L, 2, "torch.File"); tensor->nDimension = THFile_readIntScalar(file); tensor->size = THAlloc(sizeof(long)*tensor->nDimension); @@ -1026,7 +1026,7 @@ static int torch_Tensor_(read)(lua_State *L) lua_pushvalue(L, 2); /* the file */ lua_call(L, 1, 1); /* call the method */ - tensor->storage = luaT_toudata(L, -1, torch_Storage_id); + tensor->storage = luaT_toudata(L, -1, torch_Storage); if(tensor->storage) THStorage_(retain)(tensor->storage); @@ -1068,12 +1068,8 @@ static const struct luaL_Reg torch_Tensor_(_) [] = { void torch_Tensor_(init)(lua_State *L) { - torch_File_id = luaT_checktypename2id(L, "torch.File"); - torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage"); - torch_Storage_id = luaT_checktypename2id(L, STRING_torchStorage); - - torch_Tensor_id = luaT_newmetatable(L, STRING_torchTensor, NULL, - torch_Tensor_(new), torch_Tensor_(free), torch_Tensor_(factory)); + luaT_newmetatable(L, torch_Tensor, NULL, + torch_Tensor_(new), torch_Tensor_(free), torch_Tensor_(factory)); luaL_register(L, NULL, torch_Tensor_(_)); lua_pop(L, 1); } diff --git a/generic/TensorOperator.c b/generic/TensorOperator.c index 10d8f8e..c69ca38 100644 --- a/generic/TensorOperator.c +++ b/generic/TensorOperator.c @@ -2,12 +2,10 @@ #define TH_GENERIC_FILE "generic/TensorOperator.c" #else -static const void* torch_Tensor_id; - static int torch_TensorOperator_(__add__)(lua_State *L) { - THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor_id); - THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor_id); + THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor); + THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor); THTensor *r; if(!tensor1 && !tensor2) @@ -15,7 +13,7 @@ static int torch_TensorOperator_(__add__)(lua_State *L) else { r = THTensor_(new)(); - luaT_pushudata(L, r, torch_Tensor_id); + luaT_pushudata(L, r, torch_Tensor); if(!tensor1 && tensor2) { @@ -41,8 +39,8 @@ static int torch_TensorOperator_(__add__)(lua_State *L) static int torch_TensorOperator_(__sub__)(lua_State *L) { - THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor_id); - THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor_id); + THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor); + THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor); THTensor *r; if(!tensor1 && !tensor2) @@ -50,7 +48,7 @@ static int torch_TensorOperator_(__sub__)(lua_State *L) else { r = THTensor_(new)(); - luaT_pushudata(L, r, torch_Tensor_id); + luaT_pushudata(L, r, torch_Tensor); if(!tensor1 && tensor2) { @@ -76,11 +74,11 @@ static int torch_TensorOperator_(__sub__)(lua_State *L) static int torch_TensorOperator_(__unm__)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); THTensor *r; r = THTensor_(new)(); - luaT_pushudata(L, r, torch_Tensor_id); + luaT_pushudata(L, r, torch_Tensor); THTensor_(resizeAs)(r, tensor); THTensor_(copy)(r, tensor); THTensor_(mul)(r, r, -1); @@ -90,8 +88,8 @@ static int torch_TensorOperator_(__unm__)(lua_State *L) static int torch_TensorOperator_(__mul__)(lua_State *L) { - THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor_id); - THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor_id); + THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor); + THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor); THTensor *r; if(!tensor1 && !tensor2) @@ -99,7 +97,7 @@ static int torch_TensorOperator_(__mul__)(lua_State *L) else { r = THTensor_(new)(); - luaT_pushudata(L, r, torch_Tensor_id); + luaT_pushudata(L, r, torch_Tensor); if(!tensor1 && tensor2) { @@ -141,13 +139,13 @@ static int torch_TensorOperator_(__mul__)(lua_State *L) static int torch_TensorOperator_(__div__)(lua_State *L) { - THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id); + THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); THTensor *r; luaL_argcheck(L, lua_isnumber(L,2), 2, "number expected"); r = THTensor_(new)(); - luaT_pushudata(L, r, torch_Tensor_id); + luaT_pushudata(L, r, torch_Tensor); THTensor_(resizeAs)(r, tensor); THTensor_(copy)(r, tensor); @@ -167,9 +165,7 @@ static const struct luaL_Reg torch_TensorOperator_(_) [] = { void torch_TensorOperator_(init)(lua_State *L) { - torch_Tensor_id = luaT_checktypename2id(L, STRING_torchTensor); - - luaT_pushmetaclass(L, torch_Tensor_id); + luaT_pushmetatable(L, torch_Tensor); luaL_register(L, NULL, torch_TensorOperator_(_)); lua_pop(L, 1); } |