Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2012-08-13 16:25:58 +0400
committerRonan Collobert <ronan@collobert.com>2012-08-13 16:25:58 +0400
commit05a6d882c434de34963dad2bfa56797cf82b97a7 (patch)
tree860505a4673d5352164148a24cf52eb7ab26bc8a /generic
parent59f75333637e08e62236aa2c5a038584274e220d (diff)
torch now complies with the new luaT API
Diffstat (limited to 'generic')
-rw-r--r--generic/Storage.c52
-rw-r--r--generic/Tensor.c180
-rw-r--r--generic/TensorOperator.c32
3 files changed, 127 insertions, 137 deletions
diff --git a/generic/Storage.c b/generic/Storage.c
index 8612d3b..b90d054 100644
--- a/generic/Storage.c
+++ b/generic/Storage.c
@@ -32,20 +32,20 @@ static int torch_Storage_(new)(lua_State *L)
long size = luaL_optlong(L, 1, 0);
storage = THStorage_(newWithSize)(size);
}
- luaT_pushudata(L, storage, torch_Storage_id);
+ luaT_pushudata(L, storage, torch_Storage);
return 1;
}
static int torch_Storage_(free)(lua_State *L)
{
- THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
+ THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
THStorage_(free)(storage);
return 0;
}
static int torch_Storage_(resize)(lua_State *L)
{
- THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
+ THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
long size = luaL_checklong(L, 2);
/* int keepContent = luaT_optboolean(L, 3, 0); */
THStorage_(resize)(storage, size);/*, keepContent); */
@@ -55,23 +55,23 @@ static int torch_Storage_(resize)(lua_State *L)
static int torch_Storage_(copy)(lua_State *L)
{
- THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
+ THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
void *src;
- if( (src = luaT_toudata(L, 2, torch_Storage_id)) )
+ if( (src = luaT_toudata(L, 2, torch_Storage)) )
THStorage_(copy)(storage, src);
- else if( (src = luaT_toudata(L, 2, torch_ByteStorage_id)) )
+ else if( (src = luaT_toudata(L, 2, "torch.ByteStorage")) )
THStorage_(copyByte)(storage, src);
- else if( (src = luaT_toudata(L, 2, torch_CharStorage_id)) )
+ else if( (src = luaT_toudata(L, 2, "torch.CharStorage")) )
THStorage_(copyChar)(storage, src);
- else if( (src = luaT_toudata(L, 2, torch_ShortStorage_id)) )
+ else if( (src = luaT_toudata(L, 2, "torch.ShortStorage")) )
THStorage_(copyShort)(storage, src);
- else if( (src = luaT_toudata(L, 2, torch_IntStorage_id)) )
+ else if( (src = luaT_toudata(L, 2, "torch.IntStorage")) )
THStorage_(copyInt)(storage, src);
- else if( (src = luaT_toudata(L, 2, torch_LongStorage_id)) )
+ else if( (src = luaT_toudata(L, 2, "torch.LongStorage")) )
THStorage_(copyLong)(storage, src);
- else if( (src = luaT_toudata(L, 2, torch_FloatStorage_id)) )
+ else if( (src = luaT_toudata(L, 2, "torch.FloatStorage")) )
THStorage_(copyFloat)(storage, src);
- else if( (src = luaT_toudata(L, 2, torch_DoubleStorage_id)) )
+ else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) )
THStorage_(copyDouble)(storage, src);
else
luaL_typerror(L, 2, "torch.*Storage");
@@ -81,7 +81,7 @@ static int torch_Storage_(copy)(lua_State *L)
static int torch_Storage_(fill)(lua_State *L)
{
- THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
+ THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
double value = luaL_checknumber(L, 2);
THStorage_(fill)(storage, (real)value);
lua_settop(L, 1);
@@ -90,7 +90,7 @@ static int torch_Storage_(fill)(lua_State *L)
static int torch_Storage_(__len__)(lua_State *L)
{
- THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
+ THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
lua_pushnumber(L, storage->size);
return 1;
}
@@ -99,7 +99,7 @@ static int torch_Storage_(__newindex__)(lua_State *L)
{
if(lua_isnumber(L, 2))
{
- THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
+ THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
long index = luaL_checklong(L, 2) - 1;
double number = luaL_checknumber(L, 3);
THStorage_(set)(storage, index, (real)number);
@@ -115,7 +115,7 @@ static int torch_Storage_(__index__)(lua_State *L)
{
if(lua_isnumber(L, 2))
{
- THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
+ THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
long index = luaL_checklong(L, 2) - 1;
lua_pushnumber(L, THStorage_(get)(storage, index));
lua_pushboolean(L, 1);
@@ -131,7 +131,7 @@ static int torch_Storage_(__index__)(lua_State *L)
#if defined(TH_REAL_IS_CHAR) || defined(TH_REAL_IS_BYTE)
static int torch_Storage_(string)(lua_State *L)
{
- THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
+ THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
if(lua_isstring(L, -1))
{
size_t len = 0;
@@ -149,7 +149,7 @@ static int torch_Storage_(string)(lua_State *L)
static int torch_Storage_(totable)(lua_State *L)
{
- THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
+ THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
long i;
lua_newtable(L);
@@ -164,14 +164,14 @@ static int torch_Storage_(totable)(lua_State *L)
static int torch_Storage_(factory)(lua_State *L)
{
THStorage *storage = THStorage_(new)();
- luaT_pushudata(L, storage, torch_Storage_id);
+ luaT_pushudata(L, storage, torch_Storage);
return 1;
}
static int torch_Storage_(write)(lua_State *L)
{
- THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
- THFile *file = luaT_checkudata(L, 2, torch_File_id);
+ THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
+ THFile *file = luaT_checkudata(L, 2, "torch.File");
THFile_writeLongScalar(file, storage->size);
THFile_writeRealRaw(file, storage->data, storage->size);
@@ -181,8 +181,8 @@ static int torch_Storage_(write)(lua_State *L)
static int torch_Storage_(read)(lua_State *L)
{
- THStorage *storage = luaT_checkudata(L, 1, torch_Storage_id);
- THFile *file = luaT_checkudata(L, 2, torch_File_id);
+ THStorage *storage = luaT_checkudata(L, 1, torch_Storage);
+ THFile *file = luaT_checkudata(L, 2, "torch.File");
long size = THFile_readLongScalar(file);
THStorage_(resize)(storage, size);
@@ -210,10 +210,8 @@ static const struct luaL_Reg torch_Storage_(_) [] = {
void torch_Storage_(init)(lua_State *L)
{
- torch_File_id = luaT_checktypename2id(L, "torch.File");
-
- torch_Storage_id = luaT_newmetatable(L, STRING_torchStorage, NULL,
- torch_Storage_(new), torch_Storage_(free), torch_Storage_(factory));
+ luaT_newmetatable(L, torch_Storage, NULL,
+ torch_Storage_(new), torch_Storage_(free), torch_Storage_(factory));
luaL_register(L, NULL, torch_Storage_(_));
lua_pop(L, 1);
}
diff --git a/generic/Tensor.c b/generic/Tensor.c
index 06de788..ced0a68 100644
--- a/generic/Tensor.c
+++ b/generic/Tensor.c
@@ -9,7 +9,7 @@ static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowSt
static int torch_Tensor_(size)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
if(lua_isnumber(L,2))
{
int dim = luaL_checkint(L, 2)-1;
@@ -20,14 +20,14 @@ static int torch_Tensor_(size)(lua_State *L)
{
THLongStorage *storage = THLongStorage_newWithSize(tensor->nDimension);
memmove(storage->data, tensor->size, sizeof(long)*tensor->nDimension);
- luaT_pushudata(L, storage, torch_LongStorage_id);
+ luaT_pushudata(L, storage, "torch.LongStorage");
}
return 1;
}
static int torch_Tensor_(stride)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
if(lua_isnumber(L,2))
{
int dim = luaL_checkint(L, 2)-1;
@@ -38,25 +38,25 @@ static int torch_Tensor_(stride)(lua_State *L)
{
THLongStorage *storage = THLongStorage_newWithSize(tensor->nDimension);
memmove(storage->data, tensor->stride, sizeof(long)*tensor->nDimension);
- luaT_pushudata(L, storage, torch_LongStorage_id);
+ luaT_pushudata(L, storage, "torch.LongStorage");
}
return 1;
}
static int torch_Tensor_(nDimension)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
lua_pushnumber(L, tensor->nDimension);
return 1;
}
static int torch_Tensor_(storage)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
if(tensor->storage)
{
THStorage_(retain)(tensor->storage);
- luaT_pushudata(L, tensor->storage, torch_Storage_id);
+ luaT_pushudata(L, tensor->storage, torch_Storage);
}
else
lua_pushnil(L);
@@ -66,7 +66,7 @@ static int torch_Tensor_(storage)(lua_State *L)
static int torch_Tensor_(storageOffset)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
lua_pushnumber(L, tensor->storageOffset+1);
return 1;
}
@@ -197,13 +197,13 @@ static int torch_Tensor_(new)(lua_State *L)
THLongStorage_free(stride);
}
- luaT_pushudata(L, tensor, torch_Tensor_id);
+ luaT_pushudata(L, tensor, torch_Tensor);
return 1;
}
static int torch_Tensor_(set)(lua_State *L)
{
- THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *self = luaT_checkudata(L, 1, torch_Tensor);
THStorage *storage;
long storageOffset;
THLongStorage *size, *stride;
@@ -222,25 +222,25 @@ static int torch_Tensor_(set)(lua_State *L)
static int torch_Tensor_(clone)(lua_State *L)
{
- THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *self = luaT_checkudata(L, 1, torch_Tensor);
self = THTensor_(newClone)(self);
- luaT_pushudata(L, self, torch_Tensor_id);
+ luaT_pushudata(L, self, torch_Tensor);
return 1;
}
static int torch_Tensor_(contiguous)(lua_State *L)
{
- THTensor *self = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *self = luaT_checkudata(L, 1, torch_Tensor);
self = THTensor_(newContiguous)(self);
- luaT_pushudata(L, self, torch_Tensor_id);
+ luaT_pushudata(L, self, torch_Tensor);
return 1;
}
/* Resize */
static int torch_Tensor_(resizeAs)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
- THTensor *src = luaT_checkudata(L, 2, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
+ THTensor *src = luaT_checkudata(L, 2, torch_Tensor);
THTensor_(resizeAs)(tensor, src);
lua_settop(L, 1);
return 1;
@@ -248,7 +248,7 @@ static int torch_Tensor_(resizeAs)(lua_State *L)
static int torch_Tensor_(resize)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
THLongStorage *size, *stride;
torch_Tensor_(c_readSizeStride)(L, 2, 0, &size, &stride);
@@ -264,7 +264,7 @@ static int torch_Tensor_(resize)(lua_State *L)
static int torch_Tensor_(narrow)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
int dimension = luaL_checkint(L, 2)-1;
long firstIndex = luaL_checklong(L, 3)-1;
long size = luaL_checklong(L, 4);
@@ -275,13 +275,13 @@ static int torch_Tensor_(narrow)(lua_State *L)
*/
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, dimension, firstIndex, size);
- luaT_pushudata(L, tensor, torch_Tensor_id);
+ luaT_pushudata(L, tensor, torch_Tensor);
return 1;
}
static int torch_Tensor_(sub)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
long d0s = -1, d0e = -1, d1s = -1, d1e = -1, d2s = -1, d2e = -1, d3s = -1, d3e = -1;
d0s = luaL_checklong(L, 2)-1;
@@ -345,13 +345,13 @@ static int torch_Tensor_(sub)(lua_State *L)
THTensor_(narrow)(tensor, NULL, 2, d2s, d2e-d2s+1);
if(d3s >= 0)
THTensor_(narrow)(tensor, NULL, 3, d3s, d3e-d3s+1);
- luaT_pushudata(L, tensor, torch_Tensor_id);
+ luaT_pushudata(L, tensor, torch_Tensor);
return 1;
}
static int torch_Tensor_(select)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
int dimension = luaL_checkint(L, 2)-1;
long sliceIndex = luaL_checklong(L, 3)-1;
@@ -364,7 +364,7 @@ static int torch_Tensor_(select)(lua_State *L)
{
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(select)(tensor, NULL, dimension, sliceIndex);
- luaT_pushudata(L, tensor, torch_Tensor_id);
+ luaT_pushudata(L, tensor, torch_Tensor);
}
else
{
@@ -378,7 +378,7 @@ static int torch_Tensor_(select)(lua_State *L)
static int torch_Tensor_(transpose)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
int dimension1 = luaL_checkint(L, 2)-1;
int dimension2 = luaL_checkint(L, 3)-1;
@@ -389,25 +389,25 @@ static int torch_Tensor_(transpose)(lua_State *L)
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(transpose)(tensor, NULL, dimension1, dimension2);
- luaT_pushudata(L, tensor, torch_Tensor_id);
+ luaT_pushudata(L, tensor, torch_Tensor);
return 1;
}
static int torch_Tensor_(t)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
luaL_argcheck(L, tensor->nDimension == 2, 1, "Tensor must have 2 dimensions");
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(transpose)(tensor, NULL, 0, 1);
- luaT_pushudata(L, tensor, torch_Tensor_id);
+ luaT_pushudata(L, tensor, torch_Tensor);
return 1;
}
static int torch_Tensor_(unfold)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
int dimension = luaL_checkint(L, 2)-1;
long size = luaL_checklong(L, 3);
long step = luaL_checklong(L, 4);
@@ -420,44 +420,44 @@ static int torch_Tensor_(unfold)(lua_State *L)
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(unfold)(tensor, NULL, dimension, size, step);
- luaT_pushudata(L, tensor, torch_Tensor_id);
+ luaT_pushudata(L, tensor, torch_Tensor);
return 1;
}
/* is contiguous? [a bit like in TnXIterator] */
static int torch_Tensor_(isContiguous)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
lua_pushboolean(L, THTensor_(isContiguous)(tensor));
return 1;
}
static int torch_Tensor_(nElement)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
lua_pushnumber(L, THTensor_(nElement)(tensor));
return 1;
}
static int torch_Tensor_(copy)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
void *src;
- if( (src = luaT_toudata(L, 2, torch_Tensor_id)) )
+ if( (src = luaT_toudata(L, 2, torch_Tensor)) )
THTensor_(copy)(tensor, src);
- else if( (src = luaT_toudata(L, 2, torch_ByteTensor_id)) )
+ else if( (src = luaT_toudata(L, 2, "torch.ByteTensor")) )
THTensor_(copyByte)(tensor, src);
- else if( (src = luaT_toudata(L, 2, torch_CharTensor_id)) )
+ else if( (src = luaT_toudata(L, 2, "torch.CharTensor")) )
THTensor_(copyChar)(tensor, src);
- else if( (src = luaT_toudata(L, 2, torch_ShortTensor_id)) )
+ else if( (src = luaT_toudata(L, 2, "torch.ShortTensor")) )
THTensor_(copyShort)(tensor, src);
- else if( (src = luaT_toudata(L, 2, torch_IntTensor_id)) )
+ else if( (src = luaT_toudata(L, 2, "torch.IntTensor")) )
THTensor_(copyInt)(tensor, src);
- else if( (src = luaT_toudata(L, 2, torch_LongTensor_id)) )
+ else if( (src = luaT_toudata(L, 2, "torch.LongTensor")) )
THTensor_(copyLong)(tensor, src);
- else if( (src = luaT_toudata(L, 2, torch_FloatTensor_id)) )
+ else if( (src = luaT_toudata(L, 2, "torch.FloatTensor")) )
THTensor_(copyFloat)(tensor, src);
- else if( (src = luaT_toudata(L, 2, torch_DoubleTensor_id)) )
+ else if( (src = luaT_toudata(L, 2, "torch.DoubleTensor")) )
THTensor_(copyDouble)(tensor, src);
else
luaL_typerror(L, 2, "torch.*Tensor");
@@ -467,7 +467,7 @@ static int torch_Tensor_(copy)(lua_State *L)
static int torch_Tensor_(__newindex__)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
THLongStorage *idx = NULL;
THByteTensor *mask;
@@ -487,42 +487,42 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
THTensor_(fill)(tensor, value);
THTensor_(free)(tensor);
}
- } else if( (src = luaT_toudata(L, 3, torch_Tensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, torch_Tensor)) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copy)(tensor, src);
THTensor_(free)(tensor);
- } else if( (src = luaT_toudata(L, 3, torch_ByteTensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, "torch.ByteTensor")) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyByte)(tensor, src);
THTensor_(free)(tensor);
- } else if( (src = luaT_toudata(L, 3, torch_CharTensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, "torch.CharTensor")) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyChar)(tensor, src);
THTensor_(free)(tensor);
- } else if( (src = luaT_toudata(L, 3, torch_ShortTensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, "torch.ShortTensor")) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyShort)(tensor, src);
THTensor_(free)(tensor);
- } else if( (src = luaT_toudata(L, 3, torch_IntTensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, "torch.IntTensor")) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyInt)(tensor, src);
THTensor_(free)(tensor);
- } else if( (src = luaT_toudata(L, 3, torch_LongTensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, "torch.LongTensor")) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyLong)(tensor, src);
THTensor_(free)(tensor);
- } else if( (src = luaT_toudata(L, 3, torch_FloatTensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, "torch.FloatTensor")) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyFloat)(tensor, src);
THTensor_(free)(tensor);
- } else if( (src = luaT_toudata(L, 3, torch_DoubleTensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, "torch.DoubleTensor")) ) {
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(narrow)(tensor, NULL, 0, index, 1);
THTensor_(copyDouble)(tensor, src);
@@ -532,7 +532,7 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
}
lua_pushboolean(L, 1);
}
- else if((idx = luaT_toudata(L, 2, torch_LongStorage_id)))
+ else if((idx = luaT_toudata(L, 2, "torch.LongStorage")))
{
long index = THTensor_(storageOffset)(tensor);
real value = (real)luaL_checknumber(L,3);
@@ -611,37 +611,37 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
void *src;
if (lua_isnumber(L,3)) {
THTensor_(fill)(tensor, lua_tonumber(L,3));
- } else if( (src = luaT_toudata(L, 3, torch_Tensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, torch_Tensor)) ) {
THTensor_(copy)(tensor, src);
- } else if( (src = luaT_toudata(L, 3, torch_ByteTensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, "torch.ByteTensor")) ) {
THTensor_(copyByte)(tensor, src);
- } else if( (src = luaT_toudata(L, 3, torch_CharTensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, "torch.CharTensor")) ) {
THTensor_(copyChar)(tensor, src);
- } else if( (src = luaT_toudata(L, 3, torch_ShortTensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, "torch.ShortTensor")) ) {
THTensor_(copyShort)(tensor, src);
- } else if( (src = luaT_toudata(L, 3, torch_IntTensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, "torch.IntTensor")) ) {
THTensor_(copyInt)(tensor, src);
- } else if( (src = luaT_toudata(L, 3, torch_LongTensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, "torch.LongTensor")) ) {
THTensor_(copyLong)(tensor, src);
- } else if( (src = luaT_toudata(L, 3, torch_FloatTensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, "torch.FloatTensor")) ) {
THTensor_(copyFloat)(tensor, src);
- } else if( (src = luaT_toudata(L, 3, torch_DoubleTensor_id)) ) {
+ } else if( (src = luaT_toudata(L, 3, "torch.DoubleTensor")) ) {
THTensor_(copyDouble)(tensor, src);
} else {
luaL_typerror(L, 3, "torch.*Tensor");
}
- }
+ }
THTensor_(free)(tensor);
lua_pushboolean(L, 1);
}
- else if((mask = luaT_toudata(L, 2, torch_ByteTensor_id)))
+ else if((mask = luaT_toudata(L, 2, "torch.ByteTensor")))
{
THTensor *vals;
if (lua_isnumber(L, 3))
{
THTensor_(maskedFill)(tensor, mask, (real)(luaL_checknumber(L,3)));
}
- else if((vals = luaT_toudata(L, 3, torch_Tensor_id)))
+ else if((vals = luaT_toudata(L, 3, torch_Tensor)))
{
THTensor_(maskedCopy)(tensor, mask, vals);
}
@@ -658,7 +658,7 @@ static int torch_Tensor_(__newindex__)(lua_State *L)
static int torch_Tensor_(__index__)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
THLongStorage *idx = NULL;
THByteTensor *mask;
@@ -678,12 +678,12 @@ static int torch_Tensor_(__index__)(lua_State *L)
{
tensor = THTensor_(newWithTensor)(tensor);
THTensor_(select)(tensor, NULL, 0, index);
- luaT_pushudata(L, tensor, torch_Tensor_id);
+ luaT_pushudata(L, tensor, torch_Tensor);
}
lua_pushboolean(L, 1);
return 2;
}
- else if((idx = luaT_toudata(L, 2, torch_LongStorage_id)))
+ else if((idx = luaT_toudata(L, 2, "torch.LongStorage")))
{
long index = THTensor_(storageOffset)(tensor);
int dim;
@@ -756,18 +756,18 @@ static int torch_Tensor_(__index__)(lua_State *L)
}
}
if(!done) {
- luaT_pushudata(L, tensor, torch_Tensor_id);
+ luaT_pushudata(L, tensor, torch_Tensor);
} else {
THTensor_(free)(tensor);
}
lua_pushboolean(L, 1);
return 2;
}
- else if((mask = luaT_toudata(L, 2, torch_ByteTensor_id)))
+ else if((mask = luaT_toudata(L, 2, "torch.ByteTensor")))
{
THTensor *vals = THTensor_(new)();
THTensor_(maskedSelect)(vals, tensor, mask);
- luaT_pushudata(L, vals, torch_Tensor_id);
+ luaT_pushudata(L, vals, torch_Tensor);
lua_pushboolean(L, 1);
return 2;
}
@@ -780,7 +780,7 @@ static int torch_Tensor_(__index__)(lua_State *L)
static int torch_Tensor_(free)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
THTensor_(free)(tensor);
return 0;
}
@@ -791,11 +791,11 @@ static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowSt
THLongStorage *size = NULL;
THLongStorage *stride = NULL;
- if( (size = luaT_toudata(L, index, torch_LongStorage_id)) )
+ if( (size = luaT_toudata(L, index, "torch.LongStorage")) )
{
if(!lua_isnoneornil(L, index+1))
{
- if( (stride = luaT_toudata(L, index+1, torch_LongStorage_id)) )
+ if( (stride = luaT_toudata(L, index+1, "torch.LongStorage")) )
luaL_argcheck(L, stride->size == size->size, index+1, "provided stride and size are inconsistent");
else
luaL_argcheck(L, 0, index+1, "torch.LongStorage expected");
@@ -858,7 +858,7 @@ static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index
*stride_ = NULL;
return;
}
- else if( allowTensor && (arg1Type == LUA_TUSERDATA) && (src = luaT_toudata(L, index, torch_Tensor_id)) )
+ else if( allowTensor && (arg1Type == LUA_TUSERDATA) && (src = luaT_toudata(L, index, torch_Tensor)) )
{
*storage_ = src->storage;
*storageOffset_ = src->storageOffset;
@@ -866,7 +866,7 @@ static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index
*stride_ = THTensor_(newStrideOf)(src);
return;
}
- else if( allowStorage && (arg1Type == LUA_TUSERDATA) && (storage = luaT_toudata(L, index, torch_Storage_id)) )
+ else if( allowStorage && (arg1Type == LUA_TUSERDATA) && (storage = luaT_toudata(L, index, torch_Storage)) )
{
*storage_ = storage;
if(lua_isnone(L, index+1))
@@ -882,7 +882,7 @@ static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index
}
return;
}
- else if( (arg1Type == LUA_TNUMBER) || (luaT_toudata(L, index, torch_LongStorage_id)) )
+ else if( (arg1Type == LUA_TNUMBER) || (luaT_toudata(L, index, "torch.LongStorage")) )
{
*storage_ = NULL;
*storageOffset_ = 0;
@@ -900,7 +900,7 @@ static void torch_Tensor_(c_readTensorStorageSizeStride)(lua_State *L, int index
static int torch_Tensor_(apply)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
luaL_checktype(L, 2, LUA_TFUNCTION);
lua_settop(L, 2);
@@ -924,8 +924,8 @@ static int torch_Tensor_(apply)(lua_State *L)
static int torch_Tensor_(map)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
- THTensor *src = luaT_checkudata(L, 2, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
+ THTensor *src = luaT_checkudata(L, 2, torch_Tensor);
luaL_checktype(L, 3, LUA_TFUNCTION);
lua_settop(L, 3);
@@ -950,9 +950,9 @@ static int torch_Tensor_(map)(lua_State *L)
static int torch_Tensor_(map2)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
- THTensor *src1 = luaT_checkudata(L, 2, torch_Tensor_id);
- THTensor *src2 = luaT_checkudata(L, 3, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
+ THTensor *src1 = luaT_checkudata(L, 2, torch_Tensor);
+ THTensor *src2 = luaT_checkudata(L, 3, torch_Tensor);
luaL_checktype(L, 4, LUA_TFUNCTION);
lua_settop(L, 4);
@@ -979,14 +979,14 @@ static int torch_Tensor_(map2)(lua_State *L)
static int torch_Tensor_(factory)(lua_State *L)
{
THTensor *tensor = THTensor_(new)();
- luaT_pushudata(L, tensor, torch_Tensor_id);
+ luaT_pushudata(L, tensor, torch_Tensor);
return 1;
}
static int torch_Tensor_(write)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
- THFile *file = luaT_checkudata(L, 2, torch_File_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
+ THFile *file = luaT_checkudata(L, 2, "torch.File");
THFile_writeIntScalar(file, tensor->nDimension);
THFile_writeLongRaw(file, tensor->size, tensor->nDimension);
@@ -999,7 +999,7 @@ static int torch_Tensor_(write)(lua_State *L)
if(tensor->storage)
{
THStorage_(retain)(tensor->storage);
- luaT_pushudata(L, tensor->storage, torch_Storage_id);
+ luaT_pushudata(L, tensor->storage, torch_Storage);
}
else
lua_pushnil(L);
@@ -1011,8 +1011,8 @@ static int torch_Tensor_(write)(lua_State *L)
static int torch_Tensor_(read)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
- THFile *file = luaT_checkudata(L, 2, torch_File_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
+ THFile *file = luaT_checkudata(L, 2, "torch.File");
tensor->nDimension = THFile_readIntScalar(file);
tensor->size = THAlloc(sizeof(long)*tensor->nDimension);
@@ -1026,7 +1026,7 @@ static int torch_Tensor_(read)(lua_State *L)
lua_pushvalue(L, 2); /* the file */
lua_call(L, 1, 1); /* call the method */
- tensor->storage = luaT_toudata(L, -1, torch_Storage_id);
+ tensor->storage = luaT_toudata(L, -1, torch_Storage);
if(tensor->storage)
THStorage_(retain)(tensor->storage);
@@ -1068,12 +1068,8 @@ static const struct luaL_Reg torch_Tensor_(_) [] = {
void torch_Tensor_(init)(lua_State *L)
{
- torch_File_id = luaT_checktypename2id(L, "torch.File");
- torch_LongStorage_id = luaT_checktypename2id(L, "torch.LongStorage");
- torch_Storage_id = luaT_checktypename2id(L, STRING_torchStorage);
-
- torch_Tensor_id = luaT_newmetatable(L, STRING_torchTensor, NULL,
- torch_Tensor_(new), torch_Tensor_(free), torch_Tensor_(factory));
+ luaT_newmetatable(L, torch_Tensor, NULL,
+ torch_Tensor_(new), torch_Tensor_(free), torch_Tensor_(factory));
luaL_register(L, NULL, torch_Tensor_(_));
lua_pop(L, 1);
}
diff --git a/generic/TensorOperator.c b/generic/TensorOperator.c
index 10d8f8e..c69ca38 100644
--- a/generic/TensorOperator.c
+++ b/generic/TensorOperator.c
@@ -2,12 +2,10 @@
#define TH_GENERIC_FILE "generic/TensorOperator.c"
#else
-static const void* torch_Tensor_id;
-
static int torch_TensorOperator_(__add__)(lua_State *L)
{
- THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor_id);
- THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor_id);
+ THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor);
+ THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor);
THTensor *r;
if(!tensor1 && !tensor2)
@@ -15,7 +13,7 @@ static int torch_TensorOperator_(__add__)(lua_State *L)
else
{
r = THTensor_(new)();
- luaT_pushudata(L, r, torch_Tensor_id);
+ luaT_pushudata(L, r, torch_Tensor);
if(!tensor1 && tensor2)
{
@@ -41,8 +39,8 @@ static int torch_TensorOperator_(__add__)(lua_State *L)
static int torch_TensorOperator_(__sub__)(lua_State *L)
{
- THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor_id);
- THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor_id);
+ THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor);
+ THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor);
THTensor *r;
if(!tensor1 && !tensor2)
@@ -50,7 +48,7 @@ static int torch_TensorOperator_(__sub__)(lua_State *L)
else
{
r = THTensor_(new)();
- luaT_pushudata(L, r, torch_Tensor_id);
+ luaT_pushudata(L, r, torch_Tensor);
if(!tensor1 && tensor2)
{
@@ -76,11 +74,11 @@ static int torch_TensorOperator_(__sub__)(lua_State *L)
static int torch_TensorOperator_(__unm__)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
THTensor *r;
r = THTensor_(new)();
- luaT_pushudata(L, r, torch_Tensor_id);
+ luaT_pushudata(L, r, torch_Tensor);
THTensor_(resizeAs)(r, tensor);
THTensor_(copy)(r, tensor);
THTensor_(mul)(r, r, -1);
@@ -90,8 +88,8 @@ static int torch_TensorOperator_(__unm__)(lua_State *L)
static int torch_TensorOperator_(__mul__)(lua_State *L)
{
- THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor_id);
- THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor_id);
+ THTensor *tensor1 = luaT_toudata(L, 1, torch_Tensor);
+ THTensor *tensor2 = luaT_toudata(L, 2, torch_Tensor);
THTensor *r;
if(!tensor1 && !tensor2)
@@ -99,7 +97,7 @@ static int torch_TensorOperator_(__mul__)(lua_State *L)
else
{
r = THTensor_(new)();
- luaT_pushudata(L, r, torch_Tensor_id);
+ luaT_pushudata(L, r, torch_Tensor);
if(!tensor1 && tensor2)
{
@@ -141,13 +139,13 @@ static int torch_TensorOperator_(__mul__)(lua_State *L)
static int torch_TensorOperator_(__div__)(lua_State *L)
{
- THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor_id);
+ THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
THTensor *r;
luaL_argcheck(L, lua_isnumber(L,2), 2, "number expected");
r = THTensor_(new)();
- luaT_pushudata(L, r, torch_Tensor_id);
+ luaT_pushudata(L, r, torch_Tensor);
THTensor_(resizeAs)(r, tensor);
THTensor_(copy)(r, tensor);
@@ -167,9 +165,7 @@ static const struct luaL_Reg torch_TensorOperator_(_) [] = {
void torch_TensorOperator_(init)(lua_State *L)
{
- torch_Tensor_id = luaT_checktypename2id(L, STRING_torchTensor);
-
- luaT_pushmetaclass(L, torch_Tensor_id);
+ luaT_pushmetatable(L, torch_Tensor);
luaL_register(L, NULL, torch_TensorOperator_(_));
lua_pop(L, 1);
}