github.com/torch/cutorch.git
author     Adam Lerer <alerer@fb.com>  2015-08-21 10:12:22 +0300
committer  Adam Lerer <alerer@fb.com>  2015-12-26 01:01:04 +0300
commit     df463695f0cd387736917e02abaafa63c00bbed3 (patch)
tree       ded8f3bd463e47a4a9c4b13f744e4010e6345703 /generic/CTensor.c
parent     d3c6fa5e2648f902f497fe81d0fa30c62552e4e9 (diff)

Add generic CudaTensor types to cutorch
Diffstat (limited to 'generic/CTensor.c'):
 generic/CTensor.c | 238
 1 file changed, 117 insertions(+), 121 deletions(-)
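
This commit converts generic/CTensor.c from a float-only source file, hand-instantiated via #define real float / #define Real Cuda, into a TH-style generic file that can be compiled once per element type. The mechanism is two-stage token pasting: a template name such as THCTensor_(copy) expands, through the current Real/CReal definitions, into a concrete symbol such as THCudaIntTensor_copy. Below is a minimal, self-contained sketch of that naming pattern; CONCAT_4 and Tensor_ are illustrative stand-ins for TH_CONCAT_4 and the torch_Tensor_ macros, not the actual TH headers, and the generic body is duplicated inline here rather than re-included from a separate file as TH does.

#include <stdio.h>

/* Two-stage paste, mirroring TH_CONCAT_4: the extra expansion level
 * forces Real to be replaced by its current definition before ##
 * glues the tokens together. (Illustrative names, not TH's.) */
#define CONCAT_4_EXPAND(a, b, c, d) a##b##c##d
#define CONCAT_4(a, b, c, d) CONCAT_4_EXPAND(a, b, c, d)

/* Generic name template: Tensor_(sum) -> e.g. TensorFloat_sum */
#define Tensor_(NAME) CONCAT_4(Tensor, Real, _, NAME)

/* Instantiate the "generic" body for float... */
#define real float
#define Real Float
static real Tensor_(sum)(const real *p, int n) {
  real s = 0;
  for (int i = 0; i < n; i++) s += p[i];
  return s;
}
#undef real
#undef Real

/* ...and again for int: same source text, distinct symbol. */
#define real int
#define Real Int
static real Tensor_(sum)(const real *p, int n) {
  real s = 0;
  for (int i = 0; i < n; i++) s += p[i];
  return s;
}
#undef real
#undef Real

int main(void) {
  float f[] = {1.0f, 2.0f, 3.0f};
  int   i[] = {1, 2, 3};
  printf("%g %d\n", TensorFloat_sum(f, 3), TensorInt_sum(i, 3));
  return 0;
}

In the diff below, THCTensor_(copy) and torch_Tensor are exactly such templates; in the float instantiation they expand back to the old hand-written names (THCudaTensor_copy, "torch.CudaTensor"), which keeps the float path backward compatible while the other Cuda*Tensor types get the same code for free.
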
diff --git a/generic/CTensor.c b/generic/CTensor.c
index 5666dec..79a8a48 100644
--- a/generic/CTensor.c
+++ b/generic/CTensor.c
@@ -1,49 +1,48 @@
-#include "torch/utils.h"
-#include "THC.h"
-#include "THFile.h"
-#include "luaT.h"
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/CTensor.c"
+#else
/* everything is as the generic Storage.c, except few things (see below) */
-#define real float
-#define Real Cuda
-
-#define torch_Storage_(NAME) TH_CONCAT_4(torch_,Real,Storage_,NAME)
-#define torch_Storage TH_CONCAT_STRING_3(torch.,Real,Storage)
-#define torch_Tensor_(NAME) TH_CONCAT_4(torch_,Real,Tensor_,NAME)
-#define torch_Tensor TH_CONCAT_STRING_3(torch.,Real,Tensor)
-
#define TH_GENERIC_FILE "generic/Tensor.c"
#include "generic/Tensor.c"
#undef TH_GENERIC_FILE
-#undef real
-#undef Real
-
/* now we overwrite some methods specific to CudaTensor */
-static int cutorch_CudaTensor_copy(lua_State *L)
+static int cutorch_Tensor_(copy)(lua_State *L)
{
THCState *state = cutorch_getstate(L);
- THCudaTensor *storage = luaT_checkudata(L, 1, "torch.CudaTensor");
+ THCTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
void *src;
if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) )
- THCudaTensor_copy(state, storage, src);
+ THCTensor_(copyCudaFloat)(state, tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaByteTensor")) )
+ THCTensor_(copyCudaByte)(state, tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaCharTensor")) )
+ THCTensor_(copyCudaChar)(state, tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaShortTensor")) )
+ THCTensor_(copyCudaShort)(state, tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaIntTensor")) )
+ THCTensor_(copyCudaInt)(state, tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaLongTensor")) )
+ THCTensor_(copyCudaLong)(state, tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleTensor")) )
+ THCTensor_(copyCudaDouble)(state, tensor, src);
+
else if( (src = luaT_toudata(L, 2, "torch.ByteTensor")) )
- THCudaTensor_copyByte(state, storage, src);
+ THCTensor_(copyByte)(state, tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.CharTensor")) )
- THCudaTensor_copyChar(state, storage, src);
+ THCTensor_(copyChar)(state, tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.ShortTensor")) )
- THCudaTensor_copyShort(state, storage, src);
+ THCTensor_(copyShort)(state, tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.IntTensor")) )
- THCudaTensor_copyInt(state, storage, src);
+ THCTensor_(copyInt)(state, tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.LongTensor")) )
- THCudaTensor_copyLong(state, storage, src);
+ THCTensor_(copyLong)(state, tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.FloatTensor")) )
- THCudaTensor_copyFloat(state, storage, src);
+ THCTensor_(copyFloat)(state, tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.DoubleTensor")) )
- THCudaTensor_copyDouble(state, storage, src);
- else if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) )
- THCudaTensor_copyCuda(state, storage, src);
+ THCTensor_(copyDouble)(state, tensor, src);
else
luaL_typerror(L, 2, "torch.*Tensor");
@@ -51,73 +50,84 @@ static int cutorch_CudaTensor_copy(lua_State *L)
return 1;
}
-static int cutorch_CudaTensor_copyAsync(lua_State *L)
+static int cutorch_Tensor_(copyAsyncCPU)(lua_State *L)
{
+#define STRINGIFY_TENSOR(x) TH_CONCAT_STRING_3(torch.,x,Tensor)
THCState *state = cutorch_getstate(L);
- THCudaTensor *storage = luaT_checkudata(L, 1, "torch.CudaTensor");
+ THCTensor *tensor = luaT_checkudata(L, 1, STRINGIFY_TENSOR(CReal));
void *src;
- if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) )
- THCudaTensor_copy(state, storage, src);
- else if( (src = luaT_toudata(L, 2, "torch.FloatTensor")) )
- THCudaTensor_copyAsyncFloat(state, storage, src);
+ if( (src = luaT_toudata(L, 2, STRINGIFY_TENSOR(CReal))))
+ THCTensor_(copy)(state, tensor, src);
+ else if( (src = luaT_toudata(L, 2, STRINGIFY_TENSOR(Real))))
+ THCTensor_(copyAsyncCPU)(state, tensor, src);
else
- luaL_typerror(L, 2, "torch.FloatTensor or torch.CudaTensor");
+ luaL_typerror(L, 2, STRINGIFY_TENSOR(Real) " or " STRINGIFY_TENSOR(CReal));
lua_settop(L, 1);
return 1;
+#undef STRINGIFY_TENSOR
}
-#define CUDA_IMPLEMENT_TENSOR_COPY(TYPEC) \
- static int cutorch_##TYPEC##Tensor_copy(lua_State *L) \
- { \
- TH##TYPEC##Tensor *storage = luaT_checkudata(L, 1, "torch." #TYPEC "Tensor"); \
- void *src; \
- if( (src = luaT_toudata(L, 2, "torch." #TYPEC "Tensor")) ) \
- TH##TYPEC##Tensor_copy(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.ByteTensor")) ) \
- TH##TYPEC##Tensor_copyByte(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.CharTensor")) ) \
- TH##TYPEC##Tensor_copyChar(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.ShortTensor")) ) \
- TH##TYPEC##Tensor_copyShort(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.IntTensor")) ) \
- TH##TYPEC##Tensor_copyInt(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.LongTensor")) ) \
- TH##TYPEC##Tensor_copyLong(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.FloatTensor")) ) \
- TH##TYPEC##Tensor_copyFloat(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.DoubleTensor")) ) \
- TH##TYPEC##Tensor_copyDouble(storage, src); \
- else if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) ) \
- TH##TYPEC##Tensor_copyCuda(cutorch_getstate(L), storage, src); \
- else \
- luaL_typerror(L, 2, "torch.*Tensor"); \
- \
- lua_settop(L, 1); \
- return 1; \
- }
-CUDA_IMPLEMENT_TENSOR_COPY(Byte)
-CUDA_IMPLEMENT_TENSOR_COPY(Char)
-CUDA_IMPLEMENT_TENSOR_COPY(Short)
-CUDA_IMPLEMENT_TENSOR_COPY(Int)
-CUDA_IMPLEMENT_TENSOR_COPY(Long)
-CUDA_IMPLEMENT_TENSOR_COPY(Float)
-CUDA_IMPLEMENT_TENSOR_COPY(Double)
+static int TH_CONCAT_3(cutorch_,Real,Tensor_copy)(lua_State *L)
+{
+ THTensor *tensor = luaT_checkudata(L, 1, TH_CONCAT_STRING_3(torch.,Real,Tensor));
+ void *src;
+ if( (src = luaT_toudata(L, 2, TH_CONCAT_STRING_3(torch.,Real,Tensor)) ))
+ THTensor_(copy)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.ByteTensor")) )
+ THTensor_(copyByte)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CharTensor")) )
+ THTensor_(copyChar)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.ShortTensor")) )
+ THTensor_(copyShort)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.IntTensor")) )
+ THTensor_(copyInt)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.LongTensor")) )
+ THTensor_(copyLong)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.FloatTensor")) )
+ THTensor_(copyFloat)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.DoubleTensor")) )
+ THTensor_(copyDouble)(tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaByteTensor")) )
+ THTensor_(copyCudaByte)(cutorch_getstate(L), tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaCharTensor")) )
+ THTensor_(copyCudaChar)(cutorch_getstate(L), tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaShortTensor")) )
+ THTensor_(copyCudaShort)(cutorch_getstate(L), tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaIntTensor")) )
+ THTensor_(copyCudaInt)(cutorch_getstate(L), tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaLongTensor")) )
+ THTensor_(copyCudaLong)(cutorch_getstate(L), tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) )
+ THTensor_(copyCudaFloat)(cutorch_getstate(L), tensor, src);
+ else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleTensor")) )
+ THTensor_(copyCudaDouble)(cutorch_getstate(L), tensor, src);
+ else
+ luaL_typerror(L, 2, "torch.*Tensor");
+
+ lua_settop(L, 1);
+ return 1;
+}
-static int cutorch_FloatTensor_copyAsync(lua_State *L)
+static int TH_CONCAT_3(cutorch_,Real,Tensor_copyAsyncCuda)(lua_State *L)
{
- THFloatTensor *storage = luaT_checkudata(L, 1, "torch.FloatTensor");
+#define STRINGIFY_TENSOR(x) TH_CONCAT_STRING_3(torch.,x,Tensor)
+ THTensor *tensor = luaT_checkudata(L, 1, STRINGIFY_TENSOR(Real));
void *src;
- if( (src = luaT_toudata(L, 2, "torch.CudaTensor")) )
- THFloatTensor_copyAsyncCuda(cutorch_getstate(L), storage, src);
+ if( (src = luaT_toudata(L, 2, STRINGIFY_TENSOR(CReal))))
+ THTensor_(copyAsyncCuda)(cutorch_getstate(L), tensor, src);
else
- luaL_typerror(L, 2, "torch.CudaTensor");
+ luaL_typerror(L, 2, STRINGIFY_TENSOR(CReal));
lua_settop(L, 1);
return 1;
+#undef STRINGIFY_TENSOR
}
+
+
+#ifdef THC_REAL_IS_FLOAT
static void THFloatTensor_computesz(THFloatTensor *self, long **sz_, long **st_)
{
long *sz, *st, *szh;
@@ -201,69 +211,55 @@ static int cuda_FloatTensor_fakecopy(lua_State *L)
lua_settop(L, 1);
return 1;
}
+#endif
-static int cutorch_CudaTensor_getDevice(lua_State *L) {
- THCudaTensor *tensor = luaT_checkudata(L, 1, "torch.CudaTensor");
- lua_pushinteger(L, THCudaTensor_getDevice(cutorch_getstate(L), tensor) + 1);
+static int cutorch_Tensor_(getDevice)(lua_State *L) {
+ THCTensor *tensor = luaT_checkudata(L, 1, torch_Tensor);
+ lua_pushinteger(L, THCTensor_(getDevice)(cutorch_getstate(L), tensor) + 1);
return 1;
}
-void cutorch_CudaTensor_init(lua_State* L)
+void cutorch_Tensor_(init)(lua_State* L)
{
/* the standard stuff */
- torch_CudaTensor_init(L);
+ torch_Tensor_(init)(L);
/* additional methods */
+#ifdef THC_REAL_IS_FLOAT
luaT_pushmetatable(L, "torch.FloatTensor");
lua_pushcfunction(L, cuda_FloatTensor_fakecopy);
lua_setfield(L, -2, "fakecopy");
lua_pop(L, 1);
+#endif
- /* the copy methods */
- {
- int i;
-
- const void* tnames[8] = {"torch.ByteTensor",
- "torch.CharTensor",
- "torch.ShortTensor",
- "torch.IntTensor",
- "torch.LongTensor",
- "torch.FloatTensor",
- "torch.DoubleTensor",
- "torch.CudaTensor"};
-
- static int (*funcs[8])(lua_State*) = {cutorch_ByteTensor_copy,
- cutorch_CharTensor_copy,
- cutorch_ShortTensor_copy,
- cutorch_IntTensor_copy,
- cutorch_LongTensor_copy,
- cutorch_FloatTensor_copy,
- cutorch_DoubleTensor_copy,
- cutorch_CudaTensor_copy};
-
- for(i = 0; i < 8; i++)
- {
- luaT_pushmetatable(L, tnames[i]);
- lua_pushcfunction(L, funcs[i]);
- lua_setfield(L, -2, "copy");
- lua_pop(L, 1);
- }
+ // The torch_Tensor macro expands to the CudaTensor type names, so the
+ // plain CPU torch type names have to be constructed by hand here.
+ luaT_pushmetatable(L, TH_CONCAT_STRING_3(torch.,Real,Tensor));
+ lua_pushcfunction(L, TH_CONCAT_3(cutorch_,Real,Tensor_copy));
+ lua_setfield(L, -2, "copy");
+ lua_pop(L, 1);
- // Register async copy methods.
- luaT_pushmetatable(L, "torch.CudaTensor");
- lua_pushcfunction(L, cutorch_CudaTensor_copyAsync);
- lua_setfield(L, -2, "copyAsync");
- lua_pop(L, 1);
+ luaT_pushmetatable(L, torch_Tensor);
+ lua_pushcfunction(L, cutorch_Tensor_(copy));
+ lua_setfield(L, -2, "copy");
+ lua_pop(L, 1);
- luaT_pushmetatable(L, "torch.FloatTensor");
- lua_pushcfunction(L, cutorch_FloatTensor_copyAsync);
- lua_setfield(L, -2, "copyAsync");
- lua_pop(L, 1);
- }
+ // Register async copy methods.
+ luaT_pushmetatable(L, TH_CONCAT_STRING_3(torch.,Real,Tensor));
+ lua_pushcfunction(L, TH_CONCAT_3(cutorch_,Real,Tensor_copyAsyncCuda));
+ lua_setfield(L, -2, "copyAsync");
+ lua_pop(L, 1);
+
+ luaT_pushmetatable(L, torch_Tensor);
+ lua_pushcfunction(L, cutorch_Tensor_(copyAsyncCPU));
+ lua_setfield(L, -2, "copyAsync");
+ lua_pop(L, 1);
- luaT_pushmetatable(L, "torch.CudaTensor");
- lua_pushcfunction(L, cutorch_CudaTensor_getDevice);
+ luaT_pushmetatable(L, torch_Tensor);
+ lua_pushcfunction(L, cutorch_Tensor_(getDevice));
lua_setfield(L, -2, "getDevice");
lua_pop(L, 1);
}
+
+#endif
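
For the guard added at the top of the file to pay off, CTensor.c must be re-included once per CUDA type by a generator that defines real, Real, and CReal before each pass. That generator is not part of this diff (in cutorch it is a separate header), so the following single-file simulation of the include dance is a sketch under assumed names, not the actual cutorch driver:

/* generic_demo.c -- simulates the #ifndef THC_GENERIC_FILE dance in one
 * file. In cutorch the per-type passes are driven by a separate
 * generator header (not shown in this diff); here the driver and the
 * generic body share this file so it builds standalone:
 *   cc generic_demo.c && ./a.out                                      */
#ifndef GENERIC_FILE

#include <stdio.h>
#define GENERIC_FILE "generic_demo.c"   /* this very file */

/* Two-stage paste so CReal expands before ## (cf. TH_CONCAT_3). */
#define CONCAT_EXPAND(a, b, c) a##b##c
#define CONCAT(a, b, c) CONCAT_EXPAND(a, b, c)
#define Tensor_(NAME) CONCAT(TH, CReal, Tensor_##NAME)

/* Pass 2: instantiate for float under the historical "Cuda" name. */
#define real  float
#define CReal Cuda
#include GENERIC_FILE
#undef real
#undef CReal

/* Pass 3: instantiate for int as "CudaInt". */
#define real  int
#define CReal CudaInt
#include GENERIC_FILE
#undef real
#undef CReal

int main(void) {
  /* One generic body produced two concrete functions: */
  printf("%g\n", THCudaTensor_sumOfTwo(1.5f, 2.5f));
  printf("%d\n", THCudaIntTensor_sumOfTwo(1, 2));
  return 0;
}

#else /* GENERIC_FILE is defined: we are inside a per-type pass */

/* The "generic" body; expands to THCudaTensor_sumOfTwo on the float
 * pass and THCudaIntTensor_sumOfTwo on the int pass. */
static real Tensor_(sumOfTwo)(real a, real b) { return a + b; }

#endif

This structure also explains the new #ifdef THC_REAL_IS_FLOAT fences around cuda_FloatTensor_fakecopy and its helpers: they patch the CPU torch.FloatTensor metatable, so they must be compiled and registered during exactly one of the per-type passes.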