Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Lerer <alerer@fb.com>2015-08-21 10:12:22 +0300
committerAdam Lerer <alerer@fb.com>2015-12-26 01:01:04 +0300
commitdf463695f0cd387736917e02abaafa63c00bbed3 (patch)
treeded8f3bd463e47a4a9c4b13f744e4010e6345703 /generic
parentd3c6fa5e2648f902f497fe81d0fa30c62552e4e9 (diff)
Add generic CudaTensor types to cutorch
Diffstat (limited to 'generic')
-rw-r--r--generic/CStorage.c193
-rw-r--r--generic/CTensor.c238
2 files changed, 208 insertions, 223 deletions
diff --git a/generic/CStorage.c b/generic/CStorage.c
index 11ea696..c5626d4 100644
--- a/generic/CStorage.c
+++ b/generic/CStorage.c
@@ -1,65 +1,68 @@
-#include "torch/utils.h"
-#include "THC.h"
-#include "THFile.h"
-#include "luaT.h"
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/CStorage.c"
+#else
/* everything is as the generic Storage.c, except few things (see below) */
-#define real float
-#define Real Cuda
-#define TH_GENERIC_FILE "generic/Storage.c"
-
-#define torch_Storage_(NAME) TH_CONCAT_4(torch_,Real,Storage_,NAME)
-
#define THFile_readRealRaw(file, data, size) \
{ \
- float *fdata = (float*)THAlloc(sizeof(float)*size); \
- THFile_readFloatRaw(file, fdata, size); \
- THCudaCheck(cudaMemcpy(data, fdata, size * sizeof(float), cudaMemcpyHostToDevice)); \
+ real *fdata = (real*)THAlloc(sizeof(real)*size); \
+ TH_CONCAT_3(THFile_read,Real,Raw)(file, fdata, size); \
+ THCudaCheck(cudaMemcpy(data, fdata, size * sizeof(real), cudaMemcpyHostToDevice)); \
THFree(fdata); \
}
#define THFile_writeRealRaw(file, data, size) \
{ \
- float *fdata = (float*)THAlloc(sizeof(float)*size); \
- THCudaCheck(cudaMemcpy(fdata, data, size * sizeof(float), cudaMemcpyDeviceToHost)); \
- THFile_writeFloatRaw(file, fdata, size); \
+ real *fdata = (real*)THAlloc(sizeof(real)*size); \
+ THCudaCheck(cudaMemcpy(fdata, data, size * sizeof(real), cudaMemcpyDeviceToHost)); \
+ TH_CONCAT_3(THFile_write,Real,Raw)(file, fdata, size); \
THFree(fdata); \
}
-#define torch_Storage TH_CONCAT_STRING_3(torch.,Real,Storage)
-
+#define TH_GENERIC_FILE "generic/Storage.c"
#include "generic/Storage.c"
-#undef real
-#undef Real
#undef TH_GENERIC_FILE
+#undef THFile_readRealRaw
+#undef THFile_writeRealRaw
/* now we overwrite some methods specific to CudaStorage */
-static int cutorch_CudaStorage_copy(lua_State *L)
+static int cutorch_Storage_(copy)(lua_State *L)
{
THCState *state = cutorch_getstate(L);
- THCudaStorage *storage = luaT_checkudata(L, 1, "torch.CudaStorage");
+ THCStorage *storage = luaT_checkudata(L, 1, torch_Storage);
void *src;
- if( (src = luaT_toudata(L, 2, "torch.CudaStorage")) )
- THCudaStorage_copy(state, storage, src);
+ if( (src = luaT_toudata(L, 2, "torch.CudaByteStorage")) )
+ THCStorage_(copyCudaByte)(state, storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaCharStorage")) )
+ THCStorage_(copyCudaChar)(state, storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaShortStorage")) )
+ THCStorage_(copyCudaShort)(state, storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaIntStorage")) )
+ THCStorage_(copyCudaInt)(state, storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaLongStorage")) )
+ THCStorage_(copyCudaLong)(state, storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaStorage")) )
+ THCStorage_(copyCudaFloat)(state, storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleStorage")) )
+ THCStorage_(copyCudaDouble)(state, storage, src);
+
else if( (src = luaT_toudata(L, 2, "torch.ByteStorage")) )
- THCudaStorage_copyByte(state, storage, src);
+ THCStorage_(copyByte)(state, storage, src);
else if( (src = luaT_toudata(L, 2, "torch.CharStorage")) )
- THCudaStorage_copyChar(state, storage, src);
+ THCStorage_(copyChar)(state, storage, src);
else if( (src = luaT_toudata(L, 2, "torch.ShortStorage")) )
- THCudaStorage_copyShort(state, storage, src);
+ THCStorage_(copyShort)(state, storage, src);
else if( (src = luaT_toudata(L, 2, "torch.IntStorage")) )
- THCudaStorage_copyInt(state, storage, src);
+ THCStorage_(copyInt)(state, storage, src);
else if( (src = luaT_toudata(L, 2, "torch.LongStorage")) )
- THCudaStorage_copyLong(state, storage, src);
+ THCStorage_(copyLong)(state, storage, src);
else if( (src = luaT_toudata(L, 2, "torch.FloatStorage")) )
- THCudaStorage_copyFloat(state, storage, src);
+ THCStorage_(copyFloat)(state, storage, src);
else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) )
- THCudaStorage_copyDouble(state, storage, src);
- else if( (src = luaT_toudata(L, 2, "torch.CudaStorage")) )
- THCudaStorage_copyCuda(state, storage, src);
+ THCStorage_(copyDouble)(state, storage, src);
else
luaL_typerror(L, 2, "torch.*Storage");
@@ -67,77 +70,63 @@ static int cutorch_CudaStorage_copy(lua_State *L)
return 1;
}
-#define CUDA_IMPLEMENT_STORAGE_COPY(TYPEC) \
- static int cutorch_##TYPEC##Storage_copy(lua_State *L) \
- { \
- TH##TYPEC##Storage *storage = luaT_checkudata(L, 1, "torch." #TYPEC "Storage"); \
- void *src; \
- if( (src = luaT_toudata(L, 2, "torch." #TYPEC "Storage")) ) \
- TH##TYPEC##Storage_copy(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.ByteStorage")) ) \
- TH##TYPEC##Storage_copyByte(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.CharStorage")) ) \
- TH##TYPEC##Storage_copyChar(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.ShortStorage")) ) \
- TH##TYPEC##Storage_copyShort(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.IntStorage")) ) \
- TH##TYPEC##Storage_copyInt(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.LongStorage")) ) \
- TH##TYPEC##Storage_copyLong(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.FloatStorage")) ) \
- TH##TYPEC##Storage_copyFloat(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) ) \
- TH##TYPEC##Storage_copyDouble(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.CudaStorage")) ) \
- TH##TYPEC##Storage_copyCuda(cutorch_getstate(L), storage, src); \
- else \
- luaL_typerror(L, 2, "torch.*Storage"); \
- \
- lua_settop(L, 1); \
- return 1; \
-}
+static int TH_CONCAT_3(cutorch_,Real,Storage_copy)(lua_State *L)
+{
+ THStorage *storage = luaT_checkudata(L, 1, TH_CONCAT_STRING_3(torch.,Real,Storage));
+ void *src;
+ if( (src = luaT_toudata(L, 2, TH_CONCAT_STRING_3(torch.,Real,Storage) )))
+ THStorage_(copy)(storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.ByteStorage")) )
+ THStorage_(copyByte)(storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CharStorage")) )
+ THStorage_(copyChar)(storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.ShortStorage")) )
+ THStorage_(copyShort)(storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.IntStorage")) )
+ THStorage_(copyInt)(storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.LongStorage")) )
+ THStorage_(copyLong)(storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.FloatStorage")) )
+ THStorage_(copyFloat)(storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.DoubleStorage")) )
+ THStorage_(copyDouble)(storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaStorage")) )
+ THStorage_(copyCudaFloat)(cutorch_getstate(L), storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaLongStorage")) )
+ THStorage_(copyCudaLong)(cutorch_getstate(L), storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaByteStorage")) )
+ THStorage_(copyCudaByte)(cutorch_getstate(L), storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaCharStorage")) )
+ THStorage_(copyCudaChar)(cutorch_getstate(L), storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaShortStorage")) )
+ THStorage_(copyCudaShort)(cutorch_getstate(L), storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaIntStorage")) )
+ THStorage_(copyCudaInt)(cutorch_getstate(L), storage, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleStorage")) )
+ THStorage_(copyCudaDouble)(cutorch_getstate(L), storage, src);
+ else
+ luaL_typerror(L, 2, "torch.*Storage");
-CUDA_IMPLEMENT_STORAGE_COPY(Byte)
-CUDA_IMPLEMENT_STORAGE_COPY(Char)
-CUDA_IMPLEMENT_STORAGE_COPY(Short)
-CUDA_IMPLEMENT_STORAGE_COPY(Int)
-CUDA_IMPLEMENT_STORAGE_COPY(Long)
-CUDA_IMPLEMENT_STORAGE_COPY(Float)
-CUDA_IMPLEMENT_STORAGE_COPY(Double)
+ lua_settop(L, 1);
+ return 1;
+}
-void cutorch_CudaStorage_init(lua_State* L)
+void cutorch_Storage_(init)(lua_State* L)
{
/* the standard stuff */
- torch_CudaStorage_init(L);
-
- /* the copy methods */
- {
- int i;
-
- const void* tnames[8] = {"torch.ByteStorage",
- "torch.CharStorage",
- "torch.ShortStorage",
- "torch.IntStorage",
- "torch.LongStorage",
- "torch.FloatStorage",
- "torch.DoubleStorage",
- "torch.CudaStorage"};
-
- static int (*funcs[8])(lua_State*) = {cutorch_ByteStorage_copy,
- cutorch_CharStorage_copy,
- cutorch_ShortStorage_copy,
- cutorch_IntStorage_copy,
- cutorch_LongStorage_copy,
- cutorch_FloatStorage_copy,
- cutorch_DoubleStorage_copy,
- cutorch_CudaStorage_copy};
-
- for(i = 0; i < 8; i++)
- {
- luaT_pushmetatable(L, tnames[i]);
- lua_pushcfunction(L, funcs[i]);
- lua_setfield(L, -2, "copy");
- lua_pop(L, 1);
- }
- }
+ torch_Storage_(init)(L);
+
+ // torch_Storage macro is defined in Storage.c produce the CudaTensor types
+ // so I have to construct the normal torch types by hand
+ luaT_pushmetatable(L, TH_CONCAT_STRING_3(torch.,Real,Storage));
+ lua_pushcfunction(L, TH_CONCAT_3(cutorch_,Real,Storage_copy));
+ lua_setfield(L, -2, "copy");
+ lua_pop(L, 1);
+
+ luaT_pushmetatable(L, torch_Storage);
+ lua_pushcfunction(L, cutorch_Storage_(copy));
+ lua_setfield(L, -2, "copy");
+ lua_pop(L, 1);
}
+
+#endif
diff --git a/generic/CTensor.c b/generic/CTensor.c
index 5666dec..79a8a48 100644
--- a/generic/CTensor.c
+++ b/generic/CTensor.c
@@ -1,49 +1,48 @@
-#include "torch/utils.h"
-#include "THC.h"
-#include "THFile.h"
-#include "luaT.h"
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/CTensor.c"
+#else
/* everything is as the generic Storage.c, except few things (see below) */
-#define real float
-#define Real Cuda
-
-#define torch_Storage_(NAME) TH_CONCAT_4(torch_,Real,Storage_,NAME)
-#define torch_Storage TH_CONCAT_STRING_3(torch.,Real,Storage)
-#define torch_Tensor_(NAME) TH_CONCAT_4(torch_,Real,Tensor_,NAME)
-#define torch_Tensor TH_CONCAT_STRING_3(torch.,Real,Tensor)
-
#define TH_GENERIC_FILE "generic/Tensor.c"
#include "generic/Tensor.c"
#undef TH_GENERIC_FILE
-#undef real
-#undef Real
-
/* now we overwrite some methods specific to CudaTensor */
-static int cutorch_CudaTensor_copy(lua_State *L)
+static int cutorch_Tensor_(copy)(lua_State *L)
{
THCState *state = cutorch_getstate(L);
- THCudaTensor *storage = luaT_checkudata(L, 1, "torch.CudaTensor");
+ THCTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
void *src;
if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) )
- THCudaTensor_copy(state, storage, src);
+ THCTensor_(copyCudaFloat)(state, tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaByteTensor")) )
+ THCTensor_(copyCudaByte)(state, tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaCharTensor")) )
+ THCTensor_(copyCudaChar)(state, tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaShortTensor")) )
+ THCTensor_(copyCudaShort)(state, tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaIntTensor")) )
+ THCTensor_(copyCudaInt)(state, tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaLongTensor")) )
+ THCTensor_(copyCudaLong)(state, tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleTensor")) )
+ THCTensor_(copyCudaDouble)(state, tensor, src);
+
else if( (src = luaT_toudata(L, 2, "torch.ByteTensor")) )
- THCudaTensor_copyByte(state, storage, src);
+ THCTensor_(copyByte)(state, tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.CharTensor")) )
- THCudaTensor_copyChar(state, storage, src);
+ THCTensor_(copyChar)(state, tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.ShortTensor")) )
- THCudaTensor_copyShort(state, storage, src);
+ THCTensor_(copyShort)(state, tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.IntTensor")) )
- THCudaTensor_copyInt(state, storage, src);
+ THCTensor_(copyInt)(state, tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.LongTensor")) )
- THCudaTensor_copyLong(state, storage, src);
+ THCTensor_(copyLong)(state, tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.FloatTensor")) )
- THCudaTensor_copyFloat(state, storage, src);
+ THCTensor_(copyFloat)(state, tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.DoubleTensor")) )
- THCudaTensor_copyDouble(state, storage, src);
- else if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) )
- THCudaTensor_copyCuda(state, storage, src);
+ THCTensor_(copyDouble)(state, tensor, src);
else
luaL_typerror(L, 2, "torch.*Tensor");
@@ -51,73 +50,84 @@ static int cutorch_CudaTensor_copy(lua_State *L)
return 1;
}
-static int cutorch_CudaTensor_copyAsync(lua_State *L)
+static int cutorch_Tensor_(copyAsyncCPU)(lua_State *L)
{
+#define STRINGIFY_TENSOR(x) TH_CONCAT_STRING_3(torch.,x,Tensor)
THCState *state = cutorch_getstate(L);
- THCudaTensor *storage = luaT_checkudata(L, 1, "torch.CudaTensor");
+ THCTensor *tensor = luaT_checkudata(L, 1, STRINGIFY_TENSOR(CReal));
void *src;
- if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) )
- THCudaTensor_copy(state, storage, src);
- else if( (src = luaT_toudata(L, 2, "torch.FloatTensor")) )
- THCudaTensor_copyAsyncFloat(state, storage, src);
+ if( (src = luaT_toudata(L, 2, STRINGIFY_TENSOR(CReal))))
+ THCTensor_(copy)(state, tensor, src);
+ else if( (src = luaT_toudata(L, 2, STRINGIFY_TENSOR(Real))))
+ THCTensor_(copyAsyncCPU)(state, tensor, src);
else
- luaL_typerror(L, 2, "torch.FloatTensor or torch.CudaTensor");
+ luaL_typerror(L, 2, STRINGIFY_TENSOR(Real) " or " STRINGIFY_TENSOR(CReal));
lua_settop(L, 1);
return 1;
+#undef STRINGIFY_TENSOR
}
-#define CUDA_IMPLEMENT_TENSOR_COPY(TYPEC) \
- static int cutorch_##TYPEC##Tensor_copy(lua_State *L) \
- { \
- TH##TYPEC##Tensor *storage = luaT_checkudata(L, 1, "torch." #TYPEC "Tensor"); \
- void *src; \
- if( (src = luaT_toudata(L, 2, "torch." #TYPEC "Tensor")) ) \
- TH##TYPEC##Tensor_copy(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.ByteTensor")) ) \
- TH##TYPEC##Tensor_copyByte(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.CharTensor")) ) \
- TH##TYPEC##Tensor_copyChar(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.ShortTensor")) ) \
- TH##TYPEC##Tensor_copyShort(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.IntTensor")) ) \
- TH##TYPEC##Tensor_copyInt(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.LongTensor")) ) \
- TH##TYPEC##Tensor_copyLong(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.FloatTensor")) ) \
- TH##TYPEC##Tensor_copyFloat(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.DoubleTensor")) ) \
- TH##TYPEC##Tensor_copyDouble(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) ) \
- TH##TYPEC##Tensor_copyCuda(cutorch_getstate(L), storage, src); \
- else \
- luaL_typerror(L, 2, "torch.*Tensor"); \
- \
- lua_settop(L, 1); \
- return 1; \
- }
-CUDA_IMPLEMENT_TENSOR_COPY(Byte)
-CUDA_IMPLEMENT_TENSOR_COPY(Char)
-CUDA_IMPLEMENT_TENSOR_COPY(Short)
-CUDA_IMPLEMENT_TENSOR_COPY(Int)
-CUDA_IMPLEMENT_TENSOR_COPY(Long)
-CUDA_IMPLEMENT_TENSOR_COPY(Float)
-CUDA_IMPLEMENT_TENSOR_COPY(Double)
+static int TH_CONCAT_3(cutorch_,Real,Tensor_copy)(lua_State *L)
+{
+ THTensor *tensor = luaT_checkudata(L, 1, TH_CONCAT_STRING_3(torch.,Real,Tensor));
+ void *src;
+ if( (src = luaT_toudata(L, 2, TH_CONCAT_STRING_3(torch.,Real,Tensor)) ))
+ THTensor_(copy)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.ByteTensor")) )
+ THTensor_(copyByte)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CharTensor")) )
+ THTensor_(copyChar)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.ShortTensor")) )
+ THTensor_(copyShort)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.IntTensor")) )
+ THTensor_(copyInt)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.LongTensor")) )
+ THTensor_(copyLong)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.FloatTensor")) )
+ THTensor_(copyFloat)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.DoubleTensor")) )
+ THTensor_(copyDouble)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaByteTensor")) )
+ THTensor_(copyCudaByte)(cutorch_getstate(L), tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaCharTensor")) )
+ THTensor_(copyCudaChar)(cutorch_getstate(L), tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaShortTensor")) )
+ THTensor_(copyCudaShort)(cutorch_getstate(L), tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaIntTensor")) )
+ THTensor_(copyCudaInt)(cutorch_getstate(L), tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaLongTensor")) )
+ THTensor_(copyCudaLong)(cutorch_getstate(L), tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) )
+ THTensor_(copyCudaFloat)(cutorch_getstate(L), tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleTensor")) )
+ THTensor_(copyCudaDouble)(cutorch_getstate(L), tensor, src);
+ else
+ luaL_typerror(L, 2, "torch.*Tensor");
+
+ lua_settop(L, 1);
+ return 1;
+}
-static int cutorch_FloatTensor_copyAsync(lua_State *L)
+static int TH_CONCAT_3(cutorch_,Real,Tensor_copyAsyncCuda)(lua_State *L)
{
- THFloatTensor *storage = luaT_checkudata(L, 1, "torch.FloatTensor");
+#define STRINGIFY_TENSOR(x) TH_CONCAT_STRING_3(torch.,x,Tensor)
+ THTensor *tensor = luaT_checkudata(L, 1, STRINGIFY_TENSOR(Real));
void *src;
- if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) )
- THFloatTensor_copyAsyncCuda(cutorch_getstate(L), storage, src);
+ if( (src = luaT_toudata(L, 2, STRINGIFY_TENSOR(CReal))))
+ THTensor_(copyAsyncCuda)(cutorch_getstate(L), tensor, src);
else
- luaL_typerror(L, 2, "torch.CudaTensor");
+ luaL_typerror(L, 2, STRINGIFY_TENSOR(CReal));
lua_settop(L, 1);
return 1;
+#undef STRINGIFY_TENSOR
}
+
+
+#ifdef THC_REAL_IS_FLOAT
static void THFloatTensor_computesz(THFloatTensor *self, long **sz_, long **st_)
{
long *sz, *st, *szh;
@@ -201,69 +211,55 @@ static int cuda_FloatTensor_fakecopy(lua_State *L)
lua_settop(L, 1);
return 1;
}
+#endif
-static int cutorch_CudaTensor_getDevice(lua_State *L) {
- THCudaTensor *tensor = luaT_checkudata(L, 1, "torch.CudaTensor");
- lua_pushinteger(L, THCudaTensor_getDevice(cutorch_getstate(L), tensor) + 1);
+static int cutorch_Tensor_(getDevice)(lua_State *L) {
+ THCTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
+ lua_pushinteger(L, THCTensor_(getDevice)(cutorch_getstate(L), tensor) + 1);
return 1;
}
-void cutorch_CudaTensor_init(lua_State* L)
+void cutorch_Tensor_(init)(lua_State* L)
{
/* the standard stuff */
- torch_CudaTensor_init(L);
+ torch_Tensor_(init)(L);
/* additional methods */
+#ifdef THC_REAL_IS_FLOAT
luaT_pushmetatable(L, "torch.FloatTensor");
lua_pushcfunction(L, cuda_FloatTensor_fakecopy);
lua_setfield(L, -2, "fakecopy");
lua_pop(L, 1);
+#endif
- /* the copy methods */
- {
- int i;
-
- const void* tnames[8] = {"torch.ByteTensor",
- "torch.CharTensor",
- "torch.ShortTensor",
- "torch.IntTensor",
- "torch.LongTensor",
- "torch.FloatTensor",
- "torch.DoubleTensor",
- "torch.CudaTensor"};
-
- static int (*funcs[8])(lua_State*) = {cutorch_ByteTensor_copy,
- cutorch_CharTensor_copy,
- cutorch_ShortTensor_copy,
- cutorch_IntTensor_copy,
- cutorch_LongTensor_copy,
- cutorch_FloatTensor_copy,
- cutorch_DoubleTensor_copy,
- cutorch_CudaTensor_copy};
-
- for(i = 0; i < 8; i++)
- {
- luaT_pushmetatable(L, tnames[i]);
- lua_pushcfunction(L, funcs[i]);
- lua_setfield(L, -2, "copy");
- lua_pop(L, 1);
- }
+ // torch_Storage macro is defined in Storage.c produce the CudaTensor types
+ // so I have to construct the normal torch types by hand
+ luaT_pushmetatable(L, TH_CONCAT_STRING_3(torch.,Real,Tensor));
+ lua_pushcfunction(L, TH_CONCAT_3(cutorch_,Real,Tensor_copy));
+ lua_setfield(L, -2, "copy");
+ lua_pop(L, 1);
- // Register async copy methods.
- luaT_pushmetatable(L, "torch.CudaTensor");
- lua_pushcfunction(L, cutorch_CudaTensor_copyAsync);
- lua_setfield(L, -2, "copyAsync");
- lua_pop(L, 1);
+ luaT_pushmetatable(L, torch_Tensor);
+ lua_pushcfunction(L, cutorch_Tensor_(copy));
+ lua_setfield(L, -2, "copy");
+ lua_pop(L, 1);
- luaT_pushmetatable(L, "torch.FloatTensor");
- lua_pushcfunction(L, cutorch_FloatTensor_copyAsync);
- lua_setfield(L, -2, "copyAsync");
- lua_pop(L, 1);
- }
+ // Register async copy methods.
+ luaT_pushmetatable(L, TH_CONCAT_STRING_3(torch.,Real,Tensor));
+ lua_pushcfunction(L, TH_CONCAT_3(cutorch_,Real,Tensor_copyAsyncCuda));
+ lua_setfield(L, -2, "copyAsync");
+ lua_pop(L, 1);
+
+ luaT_pushmetatable(L, torch_Tensor);
+ lua_pushcfunction(L, cutorch_Tensor_(copyAsyncCPU));
+ lua_setfield(L, -2, "copyAsync");
+ lua_pop(L, 1);
- luaT_pushmetatable(L, "torch.CudaTensor");
- lua_pushcfunction(L, cutorch_CudaTensor_getDevice);
+ luaT_pushmetatable(L, torch_Tensor);
+ lua_pushcfunction(L, cutorch_Tensor_(getDevice));
lua_setfield(L, -2, "getDevice");
lua_pop(L, 1);
}
+
+#endif