Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2016-03-07 13:21:14 +0300
committerAdam Paszke <adam.paszke@gmail.com>2016-03-13 19:29:56 +0300
commit8120786180a1d41fca863dbe33fca7e11c988a7e (patch)
tree996c949d03557ccbb5fbe22dbaa559f98ee4d395 /generic
parentd9d7d2f14cda1889d47a8f2623ac8eb40b7bad0b (diff)
Add FP16 support (CudaHalfStorage, CudaHalfTensor)
Diffstat (limited to 'generic')
-rw-r--r-- generic/CStorage.c | 38
-rw-r--r-- generic/CTensor.c  | 28
2 files changed, 55 insertions(+), 11 deletions(-)
diff --git a/generic/CStorage.c b/generic/CStorage.c
index c5626d4..790d1d8 100644
--- a/generic/CStorage.c
+++ b/generic/CStorage.c
@@ -4,22 +4,40 @@
/* everything is as the generic Storage.c, except few things (see below) */
+#ifndef THC_REAL_IS_HALF
#define THFile_readRealRaw(file, data, size) \
{ \
- real *fdata = (real*)THAlloc(sizeof(real)*size); \
- TH_CONCAT_3(THFile_read,Real,Raw)(file, fdata, size); \
+ real *fdata = (real*)THAlloc(sizeof(real)*size); \
+ TH_CONCAT_3(THFile_read,Real,Raw)(file, fdata, size); \
THCudaCheck(cudaMemcpy(data, fdata, size * sizeof(real), cudaMemcpyHostToDevice)); \
THFree(fdata); \
}
#define THFile_writeRealRaw(file, data, size) \
{ \
- real *fdata = (real*)THAlloc(sizeof(real)*size); \
+ real *fdata = (real*)THAlloc(sizeof(real)*size); \
THCudaCheck(cudaMemcpy(fdata, data, size * sizeof(real), cudaMemcpyDeviceToHost)); \
- TH_CONCAT_3(THFile_write,Real,Raw)(file, fdata, size); \
+ TH_CONCAT_3(THFile_write,Real,Raw)(file, fdata, size); \
+ THFree(fdata); \
+ }
+#else
+#define THFile_readRealRaw(file, data, size) \
+ { \
+ real *fdata = (real*)THAlloc(sizeof(real)*size); \
+ THFile_readCharRaw(file, (char *)fdata, sizeof(real) * size); \
+ THCudaCheck(cudaMemcpy(data, fdata, size * sizeof(real), cudaMemcpyHostToDevice)); \
THFree(fdata); \
}
+#define THFile_writeRealRaw(file, data, size) \
+ { \
+ real *fdata = (real*)THAlloc(sizeof(real)*size); \
+ THCudaCheck(cudaMemcpy(fdata, data, size * sizeof(real), cudaMemcpyDeviceToHost)); \
+ THFile_writeCharRaw(file, (char *)fdata, size * sizeof(real)); \
+ THFree(fdata); \
+ }
+#endif
+
#define TH_GENERIC_FILE "generic/Storage.c"
#include "generic/Storage.c"
@@ -48,6 +66,10 @@ static int cutorch_Storage_(copy)(lua_State *L)
THCStorage_(copyCudaFloat)(state, storage, src);
else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleStorage")) )
THCStorage_(copyCudaDouble)(state, storage, src);
+#if CUDA_VERSION >= 7050
+ else if( (src = luaT_toudata(L, 2, "torch.CudaHalfStorage")) )
+ THCStorage_(copyCudaHalf)(state, storage, src);
+#endif
else if( (src = luaT_toudata(L, 2, "torch.ByteStorage")) )
THCStorage_(copyByte)(state, storage, src);
@@ -70,6 +92,7 @@ static int cutorch_Storage_(copy)(lua_State *L)
return 1;
}
+#ifndef THC_REAL_IS_HALF
static int TH_CONCAT_3(cutorch_,Real,Storage_copy)(lua_State *L)
{
THStorage *storage = luaT_checkudata(L, 1, TH_CONCAT_STRING_3(torch.,Real,Storage));
@@ -104,12 +127,17 @@ static int TH_CONCAT_3(cutorch_,Real,Storage_copy)(lua_State *L)
THStorage_(copyCudaInt)(cutorch_getstate(L), storage, src);
else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleStorage")) )
THStorage_(copyCudaDouble)(cutorch_getstate(L), storage, src);
+#if CUDA_VERSION >= 7050
+ else if( (src = luaT_toudata(L, 2, "torch.CudaHalfStorage")) )
+ THStorage_(copyCudaHalf)(cutorch_getstate(L), storage, src);
+#endif
else
luaL_typerror(L, 2, "torch.*Storage");
lua_settop(L, 1);
return 1;
}
+#endif
void cutorch_Storage_(init)(lua_State* L)
{
@@ -118,10 +146,12 @@ void cutorch_Storage_(init)(lua_State* L)
// torch_Storage macro is defined in Storage.c produce the CudaTensor types
// so I have to construct the normal torch types by hand
+#ifndef THC_REAL_IS_HALF
luaT_pushmetatable(L, TH_CONCAT_STRING_3(torch.,Real,Storage));
lua_pushcfunction(L, TH_CONCAT_3(cutorch_,Real,Storage_copy));
lua_setfield(L, -2, "copy");
lua_pop(L, 1);
+#endif
luaT_pushmetatable(L, torch_Storage);
lua_pushcfunction(L, cutorch_Storage_(copy));
diff --git a/generic/CTensor.c b/generic/CTensor.c
index 79a8a48..c81f92c 100644
--- a/generic/CTensor.c
+++ b/generic/CTensor.c
@@ -28,6 +28,10 @@ static int cutorch_Tensor_(copy)(lua_State *L)
THCTensor_(copyCudaLong)(state, tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleTensor")) )
THCTensor_(copyCudaDouble)(state, tensor, src);
+#if CUDA_VERSION >= 7050
+ else if( (src = luaT_toudata(L, 2, "torch.CudaHalfTensor")) )
+ THCTensor_(copyCudaHalf)(state, tensor, src);
+#endif
else if( (src = luaT_toudata(L, 2, "torch.ByteTensor")) )
THCTensor_(copyByte)(state, tensor, src);
@@ -50,6 +54,7 @@ static int cutorch_Tensor_(copy)(lua_State *L)
return 1;
}
+#ifndef THC_REAL_IS_HALF
static int cutorch_Tensor_(copyAsyncCPU)(lua_State *L)
{
#define STRINGIFY_TENSOR(x) TH_CONCAT_STRING_3(torch.,x,Tensor)
@@ -67,8 +72,10 @@ static int cutorch_Tensor_(copyAsyncCPU)(lua_State *L)
return 1;
#undef STRINGIFY_TENSOR
}
+#endif
+#ifndef THC_REAL_IS_HALF
static int TH_CONCAT_3(cutorch_,Real,Tensor_copy)(lua_State *L)
{
THTensor *tensor = luaT_checkudata(L, 1, TH_CONCAT_STRING_3(torch.,Real,Tensor));
@@ -103,13 +110,19 @@ static int TH_CONCAT_3(cutorch_,Real,Tensor_copy)(lua_State *L)
THTensor_(copyCudaFloat)(cutorch_getstate(L), tensor, src);
else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleTensor")) )
THTensor_(copyCudaDouble)(cutorch_getstate(L), tensor, src);
+#if CUDA_VERSION >= 7050
+ else if( (src = luaT_toudata(L, 2, "torch.CudaHalfTensor")) )
+ THTensor_(copyCudaHalf)(cutorch_getstate(L), tensor, src);
+#endif
else
luaL_typerror(L, 2, "torch.*Tensor");
lua_settop(L, 1);
return 1;
}
+#endif
+#ifndef THC_REAL_IS_HALF
static int TH_CONCAT_3(cutorch_,Real,Tensor_copyAsyncCuda)(lua_State *L)
{
#define STRINGIFY_TENSOR(x) TH_CONCAT_STRING_3(torch.,x,Tensor)
@@ -124,6 +137,7 @@ static int TH_CONCAT_3(cutorch_,Real,Tensor_copyAsyncCuda)(lua_State *L)
return 1;
#undef STRINGIFY_TENSOR
}
+#endif
@@ -232,18 +246,12 @@ void cutorch_Tensor_(init)(lua_State* L)
lua_pop(L, 1);
#endif
- // torch_Storage macro is defined in Storage.c produce the CudaTensor types
- // so I have to construct the normal torch types by hand
+#ifndef THC_REAL_IS_HALF
luaT_pushmetatable(L, TH_CONCAT_STRING_3(torch.,Real,Tensor));
lua_pushcfunction(L, TH_CONCAT_3(cutorch_,Real,Tensor_copy));
lua_setfield(L, -2, "copy");
lua_pop(L, 1);
- luaT_pushmetatable(L, torch_Tensor);
- lua_pushcfunction(L, cutorch_Tensor_(copy));
- lua_setfield(L, -2, "copy");
- lua_pop(L, 1);
-
// Register async copy methods.
luaT_pushmetatable(L, TH_CONCAT_STRING_3(torch.,Real,Tensor));
lua_pushcfunction(L, TH_CONCAT_3(cutorch_,Real,Tensor_copyAsyncCuda));
@@ -254,6 +262,12 @@ void cutorch_Tensor_(init)(lua_State* L)
lua_pushcfunction(L, cutorch_Tensor_(copyAsyncCPU));
lua_setfield(L, -2, "copyAsync");
lua_pop(L, 1);
+#endif
+
+ luaT_pushmetatable(L, torch_Tensor);
+ lua_pushcfunction(L, cutorch_Tensor_(copy));
+ lua_setfield(L, -2, "copy");
+ lua_pop(L, 1);
luaT_pushmetatable(L, torch_Tensor);
lua_pushcfunction(L, cutorch_Tensor_(getDevice));