diff options
| | | |
|---|---|---|
| author | Adam Lerer <alerer@fb.com> | 2015-08-21 10:12:22 +0300 |
| committer | Adam Lerer <alerer@fb.com> | 2015-12-26 01:01:04 +0300 |
| commit | df463695f0cd387736917e02abaafa63c00bbed3 (patch) | |
| tree | ded8f3bd463e47a4a9c4b13f744e4010e6345703 /generic | |
| parent | d3c6fa5e2648f902f497fe81d0fa30c62552e4e9 (diff) | |

Add generic CudaTensor types to cutorch
Diffstat (limited to 'generic')

| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | generic/CStorage.c | 193 |
| -rw-r--r-- | generic/CTensor.c | 238 |

2 files changed, 208 insertions, 223 deletions
diff --git a/generic/CStorage.c b/generic/CStorage.c index 11ea696..c5626d4 100644 --- a/generic/CStorage.c +++ b/generic/CStorage.c @@ -1,65 +1,68 @@ -#include "torch/utils.h" -#include "THC.h" -#include "THFile.h" -#include "luaT.h" +#ifndef THC_GENERIC_FILE +#define THC_GENERIC_FILE "generic/CStorage.c" +#else /* everything is as the generic Storage.c, except few things (see below) */ -#define real float -#define Real Cuda -#define TH_GENERIC_FILE "generic/Storage.c" - -#define torch_Storage_(NAME) TH_CONCAT_4(torch_,Real,Storage_,NAME) - #define THFile_readRealRaw(file, data, size) \ { \ - float *fdata = (float*)THAlloc(sizeof(float)*size); \ - THFile_readFloatRaw(file, fdata, size); \ - THCudaCheck(cudaMemcpy(data, fdata, size * sizeof(float), cudaMemcpyHostToDevice)); \ + real *fdata = (real*)THAlloc(sizeof(real)*size); \ + TH_CONCAT_3(THFile_read,Real,Raw)(file, fdata, size); \ + THCudaCheck(cudaMemcpy(data, fdata, size * sizeof(real), cudaMemcpyHostToDevice)); \ THFree(fdata); \ } #define THFile_writeRealRaw(file, data, size) \ { \ - float *fdata = (float*)THAlloc(sizeof(float)*size); \ - THCudaCheck(cudaMemcpy(fdata, data, size * sizeof(float), cudaMemcpyDeviceToHost)); \ - THFile_writeFloatRaw(file, fdata, size); \ + real *fdata = (real*)THAlloc(sizeof(real)*size); \ + THCudaCheck(cudaMemcpy(fdata, data, size * sizeof(real), cudaMemcpyDeviceToHost)); \ + TH_CONCAT_3(THFile_write,Real,Raw)(file, fdata, size); \ THFree(fdata); \ } -#define torch_Storage TH_CONCAT_STRING_3(torch.,Real,Storage) - +#define TH_GENERIC_FILE "generic/Storage.c" #include "generic/Storage.c" -#undef real -#undef Real #undef TH_GENERIC_FILE +#undef THFile_readRealRaw +#undef THFile_writeRealRaw /* now we overwrite some methods specific to CudaStorage */ -static int cutorch_CudaStorage_copy(lua_State *L) +static int cutorch_Storage_(copy)(lua_State *L) { THCState *state = cutorch_getstate(L); - THCudaStorage *storage = luaT_checkudata(L, 1, "torch.CudaStorage"); + THCStorage *storage 
= luaT_checkudata(L, 1, torch_Storage); void *src; - if( (src = luaT_toudata(L, 2, "torch.CudaStorage")) ) - THCudaStorage_copy(state, storage, src); + if( (src = luaT_toudata(L, 2, "torch.CudaByteStorage")) ) + THCStorage_(copyCudaByte)(state, storage, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaCharStorage")) ) + THCStorage_(copyCudaChar)(state, storage, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaShortStorage")) ) + THCStorage_(copyCudaShort)(state, storage, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaIntStorage")) ) + THCStorage_(copyCudaInt)(state, storage, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaLongStorage")) ) + THCStorage_(copyCudaLong)(state, storage, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaStorage")) ) + THCStorage_(copyCudaFloat)(state, storage, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleStorage")) ) + THCStorage_(copyCudaDouble)(state, storage, src); + else if( (src = luaT_toudata(L, 2, "torch.ByteStorage")) ) - THCudaStorage_copyByte(state, storage, src); + THCStorage_(copyByte)(state, storage, src); else if( (src = luaT_toudata(L, 2, "torch.CharStorage")) ) - THCudaStorage_copyChar(state, storage, src); + THCStorage_(copyChar)(state, storage, src); else if( (src = luaT_toudata(L, 2, "torch.ShortStorage")) ) - THCudaStorage_copyShort(state, storage, src); + THCStorage_(copyShort)(state, storage, src); else if( (src = luaT_toudata(L, 2, "torch.IntStorage")) ) - THCudaStorage_copyInt(state, storage, src); + THCStorage_(copyInt)(state, storage, src); else if( (src = luaT_toudata(L, 2, "torch.LongStorage")) ) - THCudaStorage_copyLong(state, storage, src); + THCStorage_(copyLong)(state, storage, src); else if( (src = luaT_toudata(L, 2, "torch.FloatStorage")) ) - THCudaStorage_copyFloat(state, storage, src); + THCStorage_(copyFloat)(state, storage, src); else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) ) - THCudaStorage_copyDouble(state, storage, src); - else if( (src = 
luaT_toudata(L, 2, "torch.CudaStorage")) ) - THCudaStorage_copyCuda(state, storage, src); + THCStorage_(copyDouble)(state, storage, src); else luaL_typerror(L, 2, "torch.*Storage"); @@ -67,77 +70,63 @@ static int cutorch_CudaStorage_copy(lua_State *L) return 1; } -#define CUDA_IMPLEMENT_STORAGE_COPY(TYPEC) \ - static int cutorch_##TYPEC##Storage_copy(lua_State *L) \ - { \ - TH##TYPEC##Storage *storage = luaT_checkudata(L, 1, "torch." #TYPEC "Storage"); \ - void *src; \ - if( (src = luaT_toudata(L, 2, "torch." #TYPEC "Storage")) ) \ - TH##TYPEC##Storage_copy(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.ByteStorage")) ) \ - TH##TYPEC##Storage_copyByte(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.CharStorage")) ) \ - TH##TYPEC##Storage_copyChar(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.ShortStorage")) ) \ - TH##TYPEC##Storage_copyShort(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.IntStorage")) ) \ - TH##TYPEC##Storage_copyInt(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.LongStorage")) ) \ - TH##TYPEC##Storage_copyLong(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.FloatStorage")) ) \ - TH##TYPEC##Storage_copyFloat(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) ) \ - TH##TYPEC##Storage_copyDouble(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.CudaStorage")) ) \ - TH##TYPEC##Storage_copyCuda(cutorch_getstate(L), storage, src); \ - else \ - luaL_typerror(L, 2, "torch.*Storage"); \ - \ - lua_settop(L, 1); \ - return 1; \ -} +static int TH_CONCAT_3(cutorch_,Real,Storage_copy)(lua_State *L) +{ + THStorage *storage = luaT_checkudata(L, 1, TH_CONCAT_STRING_3(torch.,Real,Storage)); + void *src; + if( (src = luaT_toudata(L, 2, TH_CONCAT_STRING_3(torch.,Real,Storage) ))) + THStorage_(copy)(storage, src); + else if( (src = luaT_toudata(L, 2, "torch.ByteStorage")) ) + THStorage_(copyByte)(storage, src); + else if( (src = luaT_toudata(L, 2, 
"torch.CharStorage")) ) + THStorage_(copyChar)(storage, src); + else if( (src = luaT_toudata(L, 2, "torch.ShortStorage")) ) + THStorage_(copyShort)(storage, src); + else if( (src = luaT_toudata(L, 2, "torch.IntStorage")) ) + THStorage_(copyInt)(storage, src); + else if( (src = luaT_toudata(L, 2, "torch.LongStorage")) ) + THStorage_(copyLong)(storage, src); + else if( (src = luaT_toudata(L, 2, "torch.FloatStorage")) ) + THStorage_(copyFloat)(storage, src); + else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) ) + THStorage_(copyDouble)(storage, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaStorage")) ) + THStorage_(copyCudaFloat)(cutorch_getstate(L), storage, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaLongStorage")) ) + THStorage_(copyCudaLong)(cutorch_getstate(L), storage, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaByteStorage")) ) + THStorage_(copyCudaByte)(cutorch_getstate(L), storage, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaCharStorage")) ) + THStorage_(copyCudaChar)(cutorch_getstate(L), storage, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaShortStorage")) ) + THStorage_(copyCudaShort)(cutorch_getstate(L), storage, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaIntStorage")) ) + THStorage_(copyCudaInt)(cutorch_getstate(L), storage, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleStorage")) ) + THStorage_(copyCudaDouble)(cutorch_getstate(L), storage, src); + else + luaL_typerror(L, 2, "torch.*Storage"); -CUDA_IMPLEMENT_STORAGE_COPY(Byte) -CUDA_IMPLEMENT_STORAGE_COPY(Char) -CUDA_IMPLEMENT_STORAGE_COPY(Short) -CUDA_IMPLEMENT_STORAGE_COPY(Int) -CUDA_IMPLEMENT_STORAGE_COPY(Long) -CUDA_IMPLEMENT_STORAGE_COPY(Float) -CUDA_IMPLEMENT_STORAGE_COPY(Double) + lua_settop(L, 1); + return 1; +} -void cutorch_CudaStorage_init(lua_State* L) +void cutorch_Storage_(init)(lua_State* L) { /* the standard stuff */ - torch_CudaStorage_init(L); - - /* the copy methods */ - { - int i; - - const void* 
tnames[8] = {"torch.ByteStorage", - "torch.CharStorage", - "torch.ShortStorage", - "torch.IntStorage", - "torch.LongStorage", - "torch.FloatStorage", - "torch.DoubleStorage", - "torch.CudaStorage"}; - - static int (*funcs[8])(lua_State*) = {cutorch_ByteStorage_copy, - cutorch_CharStorage_copy, - cutorch_ShortStorage_copy, - cutorch_IntStorage_copy, - cutorch_LongStorage_copy, - cutorch_FloatStorage_copy, - cutorch_DoubleStorage_copy, - cutorch_CudaStorage_copy}; - - for(i = 0; i < 8; i++) - { - luaT_pushmetatable(L, tnames[i]); - lua_pushcfunction(L, funcs[i]); - lua_setfield(L, -2, "copy"); - lua_pop(L, 1); - } - } + torch_Storage_(init)(L); + + // torch_Storage macro is defined in Storage.c produce the CudaTensor types + // so I have to construct the normal torch types by hand + luaT_pushmetatable(L, TH_CONCAT_STRING_3(torch.,Real,Storage)); + lua_pushcfunction(L, TH_CONCAT_3(cutorch_,Real,Storage_copy)); + lua_setfield(L, -2, "copy"); + lua_pop(L, 1); + + luaT_pushmetatable(L, torch_Storage); + lua_pushcfunction(L, cutorch_Storage_(copy)); + lua_setfield(L, -2, "copy"); + lua_pop(L, 1); } + +#endif diff --git a/generic/CTensor.c b/generic/CTensor.c index 5666dec..79a8a48 100644 --- a/generic/CTensor.c +++ b/generic/CTensor.c @@ -1,49 +1,48 @@ -#include "torch/utils.h" -#include "THC.h" -#include "THFile.h" -#include "luaT.h" +#ifndef THC_GENERIC_FILE +#define THC_GENERIC_FILE "generic/CTensor.c" +#else /* everything is as the generic Storage.c, except few things (see below) */ -#define real float -#define Real Cuda - -#define torch_Storage_(NAME) TH_CONCAT_4(torch_,Real,Storage_,NAME) -#define torch_Storage TH_CONCAT_STRING_3(torch.,Real,Storage) -#define torch_Tensor_(NAME) TH_CONCAT_4(torch_,Real,Tensor_,NAME) -#define torch_Tensor TH_CONCAT_STRING_3(torch.,Real,Tensor) - #define TH_GENERIC_FILE "generic/Tensor.c" #include "generic/Tensor.c" #undef TH_GENERIC_FILE -#undef real -#undef Real - /* now we overwrite some methods specific to CudaTensor */ -static 
int cutorch_CudaTensor_copy(lua_State *L) +static int cutorch_Tensor_(copy)(lua_State *L) { THCState *state = cutorch_getstate(L); - THCudaTensor *storage = luaT_checkudata(L, 1, "torch.CudaTensor"); + THCTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); void *src; if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) ) - THCudaTensor_copy(state, storage, src); + THCTensor_(copyCudaFloat)(state, tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaByteTensor")) ) + THCTensor_(copyCudaByte)(state, tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaCharTensor")) ) + THCTensor_(copyCudaChar)(state, tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaShortTensor")) ) + THCTensor_(copyCudaShort)(state, tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaIntTensor")) ) + THCTensor_(copyCudaInt)(state, tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaLongTensor")) ) + THCTensor_(copyCudaLong)(state, tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleTensor")) ) + THCTensor_(copyCudaDouble)(state, tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.ByteTensor")) ) - THCudaTensor_copyByte(state, storage, src); + THCTensor_(copyByte)(state, tensor, src); else if( (src = luaT_toudata(L, 2, "torch.CharTensor")) ) - THCudaTensor_copyChar(state, storage, src); + THCTensor_(copyChar)(state, tensor, src); else if( (src = luaT_toudata(L, 2, "torch.ShortTensor")) ) - THCudaTensor_copyShort(state, storage, src); + THCTensor_(copyShort)(state, tensor, src); else if( (src = luaT_toudata(L, 2, "torch.IntTensor")) ) - THCudaTensor_copyInt(state, storage, src); + THCTensor_(copyInt)(state, tensor, src); else if( (src = luaT_toudata(L, 2, "torch.LongTensor")) ) - THCudaTensor_copyLong(state, storage, src); + THCTensor_(copyLong)(state, tensor, src); else if( (src = luaT_toudata(L, 2, "torch.FloatTensor")) ) - THCudaTensor_copyFloat(state, storage, src); + THCTensor_(copyFloat)(state, tensor, src); else if( (src = 
luaT_toudata(L, 2, "torch.DoubleTensor")) ) - THCudaTensor_copyDouble(state, storage, src); - else if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) ) - THCudaTensor_copyCuda(state, storage, src); + THCTensor_(copyDouble)(state, tensor, src); else luaL_typerror(L, 2, "torch.*Tensor"); @@ -51,73 +50,84 @@ static int cutorch_CudaTensor_copy(lua_State *L) return 1; } -static int cutorch_CudaTensor_copyAsync(lua_State *L) +static int cutorch_Tensor_(copyAsyncCPU)(lua_State *L) { +#define STRINGIFY_TENSOR(x) TH_CONCAT_STRING_3(torch.,x,Tensor) THCState *state = cutorch_getstate(L); - THCudaTensor *storage = luaT_checkudata(L, 1, "torch.CudaTensor"); + THCTensor *tensor = luaT_checkudata(L, 1, STRINGIFY_TENSOR(CReal)); void *src; - if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) ) - THCudaTensor_copy(state, storage, src); - else if( (src = luaT_toudata(L, 2, "torch.FloatTensor")) ) - THCudaTensor_copyAsyncFloat(state, storage, src); + if( (src = luaT_toudata(L, 2, STRINGIFY_TENSOR(CReal)))) + THCTensor_(copy)(state, tensor, src); + else if( (src = luaT_toudata(L, 2, STRINGIFY_TENSOR(Real)))) + THCTensor_(copyAsyncCPU)(state, tensor, src); else - luaL_typerror(L, 2, "torch.FloatTensor or torch.CudaTensor"); + luaL_typerror(L, 2, STRINGIFY_TENSOR(Real) " or " STRINGIFY_TENSOR(CReal)); lua_settop(L, 1); return 1; +#undef STRINGIFY_TENSOR } -#define CUDA_IMPLEMENT_TENSOR_COPY(TYPEC) \ - static int cutorch_##TYPEC##Tensor_copy(lua_State *L) \ - { \ - TH##TYPEC##Tensor *storage = luaT_checkudata(L, 1, "torch." #TYPEC "Tensor"); \ - void *src; \ - if( (src = luaT_toudata(L, 2, "torch." 
#TYPEC "Tensor")) ) \ - TH##TYPEC##Tensor_copy(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.ByteTensor")) ) \ - TH##TYPEC##Tensor_copyByte(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.CharTensor")) ) \ - TH##TYPEC##Tensor_copyChar(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.ShortTensor")) ) \ - TH##TYPEC##Tensor_copyShort(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.IntTensor")) ) \ - TH##TYPEC##Tensor_copyInt(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.LongTensor")) ) \ - TH##TYPEC##Tensor_copyLong(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.FloatTensor")) ) \ - TH##TYPEC##Tensor_copyFloat(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.DoubleTensor")) ) \ - TH##TYPEC##Tensor_copyDouble(storage, src); \ - else if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) ) \ - TH##TYPEC##Tensor_copyCuda(cutorch_getstate(L), storage, src); \ - else \ - luaL_typerror(L, 2, "torch.*Tensor"); \ - \ - lua_settop(L, 1); \ - return 1; \ - } -CUDA_IMPLEMENT_TENSOR_COPY(Byte) -CUDA_IMPLEMENT_TENSOR_COPY(Char) -CUDA_IMPLEMENT_TENSOR_COPY(Short) -CUDA_IMPLEMENT_TENSOR_COPY(Int) -CUDA_IMPLEMENT_TENSOR_COPY(Long) -CUDA_IMPLEMENT_TENSOR_COPY(Float) -CUDA_IMPLEMENT_TENSOR_COPY(Double) +static int TH_CONCAT_3(cutorch_,Real,Tensor_copy)(lua_State *L) +{ + THTensor *tensor = luaT_checkudata(L, 1, TH_CONCAT_STRING_3(torch.,Real,Tensor)); + void *src; + if( (src = luaT_toudata(L, 2, TH_CONCAT_STRING_3(torch.,Real,Tensor)) )) + THTensor_(copy)(tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.ByteTensor")) ) + THTensor_(copyByte)(tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.CharTensor")) ) + THTensor_(copyChar)(tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.ShortTensor")) ) + THTensor_(copyShort)(tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.IntTensor")) ) + THTensor_(copyInt)(tensor, src); + else if( (src = luaT_toudata(L, 2, 
"torch.LongTensor")) ) + THTensor_(copyLong)(tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.FloatTensor")) ) + THTensor_(copyFloat)(tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.DoubleTensor")) ) + THTensor_(copyDouble)(tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaByteTensor")) ) + THTensor_(copyCudaByte)(cutorch_getstate(L), tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaCharTensor")) ) + THTensor_(copyCudaChar)(cutorch_getstate(L), tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaShortTensor")) ) + THTensor_(copyCudaShort)(cutorch_getstate(L), tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaIntTensor")) ) + THTensor_(copyCudaInt)(cutorch_getstate(L), tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaLongTensor")) ) + THTensor_(copyCudaLong)(cutorch_getstate(L), tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) ) + THTensor_(copyCudaFloat)(cutorch_getstate(L), tensor, src); + else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleTensor")) ) + THTensor_(copyCudaDouble)(cutorch_getstate(L), tensor, src); + else + luaL_typerror(L, 2, "torch.*Tensor"); + + lua_settop(L, 1); + return 1; +} -static int cutorch_FloatTensor_copyAsync(lua_State *L) +static int TH_CONCAT_3(cutorch_,Real,Tensor_copyAsyncCuda)(lua_State *L) { - THFloatTensor *storage = luaT_checkudata(L, 1, "torch.FloatTensor"); +#define STRINGIFY_TENSOR(x) TH_CONCAT_STRING_3(torch.,x,Tensor) + THTensor *tensor = luaT_checkudata(L, 1, STRINGIFY_TENSOR(Real)); void *src; - if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) ) - THFloatTensor_copyAsyncCuda(cutorch_getstate(L), storage, src); + if( (src = luaT_toudata(L, 2, STRINGIFY_TENSOR(CReal)))) + THTensor_(copyAsyncCuda)(cutorch_getstate(L), tensor, src); else - luaL_typerror(L, 2, "torch.CudaTensor"); + luaL_typerror(L, 2, STRINGIFY_TENSOR(CReal)); lua_settop(L, 1); return 1; +#undef STRINGIFY_TENSOR } + + +#ifdef THC_REAL_IS_FLOAT static void 
THFloatTensor_computesz(THFloatTensor *self, long **sz_, long **st_) { long *sz, *st, *szh; @@ -201,69 +211,55 @@ static int cuda_FloatTensor_fakecopy(lua_State *L) lua_settop(L, 1); return 1; } +#endif -static int cutorch_CudaTensor_getDevice(lua_State *L) { - THCudaTensor *tensor = luaT_checkudata(L, 1, "torch.CudaTensor"); - lua_pushinteger(L, THCudaTensor_getDevice(cutorch_getstate(L), tensor) + 1); +static int cutorch_Tensor_(getDevice)(lua_State *L) { + THCTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); + lua_pushinteger(L, THCTensor_(getDevice)(cutorch_getstate(L), tensor) + 1); return 1; } -void cutorch_CudaTensor_init(lua_State* L) +void cutorch_Tensor_(init)(lua_State* L) { /* the standard stuff */ - torch_CudaTensor_init(L); + torch_Tensor_(init)(L); /* additional methods */ +#ifdef THC_REAL_IS_FLOAT luaT_pushmetatable(L, "torch.FloatTensor"); lua_pushcfunction(L, cuda_FloatTensor_fakecopy); lua_setfield(L, -2, "fakecopy"); lua_pop(L, 1); +#endif - /* the copy methods */ - { - int i; - - const void* tnames[8] = {"torch.ByteTensor", - "torch.CharTensor", - "torch.ShortTensor", - "torch.IntTensor", - "torch.LongTensor", - "torch.FloatTensor", - "torch.DoubleTensor", - "torch.CudaTensor"}; - - static int (*funcs[8])(lua_State*) = {cutorch_ByteTensor_copy, - cutorch_CharTensor_copy, - cutorch_ShortTensor_copy, - cutorch_IntTensor_copy, - cutorch_LongTensor_copy, - cutorch_FloatTensor_copy, - cutorch_DoubleTensor_copy, - cutorch_CudaTensor_copy}; - - for(i = 0; i < 8; i++) - { - luaT_pushmetatable(L, tnames[i]); - lua_pushcfunction(L, funcs[i]); - lua_setfield(L, -2, "copy"); - lua_pop(L, 1); - } + // torch_Storage macro is defined in Storage.c produce the CudaTensor types + // so I have to construct the normal torch types by hand + luaT_pushmetatable(L, TH_CONCAT_STRING_3(torch.,Real,Tensor)); + lua_pushcfunction(L, TH_CONCAT_3(cutorch_,Real,Tensor_copy)); + lua_setfield(L, -2, "copy"); + lua_pop(L, 1); - // Register async copy methods. 
- luaT_pushmetatable(L, "torch.CudaTensor"); - lua_pushcfunction(L, cutorch_CudaTensor_copyAsync); - lua_setfield(L, -2, "copyAsync"); - lua_pop(L, 1); + luaT_pushmetatable(L, torch_Tensor); + lua_pushcfunction(L, cutorch_Tensor_(copy)); + lua_setfield(L, -2, "copy"); + lua_pop(L, 1); - luaT_pushmetatable(L, "torch.FloatTensor"); - lua_pushcfunction(L, cutorch_FloatTensor_copyAsync); - lua_setfield(L, -2, "copyAsync"); - lua_pop(L, 1); - } + // Register async copy methods. + luaT_pushmetatable(L, TH_CONCAT_STRING_3(torch.,Real,Tensor)); + lua_pushcfunction(L, TH_CONCAT_3(cutorch_,Real,Tensor_copyAsyncCuda)); + lua_setfield(L, -2, "copyAsync"); + lua_pop(L, 1); + + luaT_pushmetatable(L, torch_Tensor); + lua_pushcfunction(L, cutorch_Tensor_(copyAsyncCPU)); + lua_setfield(L, -2, "copyAsync"); + lua_pop(L, 1); - luaT_pushmetatable(L, "torch.CudaTensor"); - lua_pushcfunction(L, cutorch_CudaTensor_getDevice); + luaT_pushmetatable(L, torch_Tensor); + lua_pushcfunction(L, cutorch_Tensor_(getDevice)); lua_setfield(L, -2, "getDevice"); lua_pop(L, 1); } + +#endif |