diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2016-03-07 13:21:14 +0300 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2016-03-13 19:29:56 +0300 |
commit | 8120786180a1d41fca863dbe33fca7e11c988a7e (patch) | |
tree | 996c949d03557ccbb5fbe22dbaa559f98ee4d395 /generic | |
parent | d9d7d2f14cda1889d47a8f2623ac8eb40b7bad0b (diff) |
Add FP16 support (CudaHalfStorage, CudaHalfTensor)
Diffstat (limited to 'generic')
-rw-r--r--  generic/CStorage.c | 38 ++++++++++++++++++++++++++++++++------
-rw-r--r--  generic/CTensor.c  | 28 ++++++++++++++++++----------
2 files changed, 55 insertions(+), 11 deletions(-)
diff --git a/generic/CStorage.c b/generic/CStorage.c
index c5626d4..790d1d8 100644
--- a/generic/CStorage.c
+++ b/generic/CStorage.c
@@ -4,22 +4,40 @@
 
 /* everything is as the generic Storage.c, except few things (see below) */
 
+#ifndef THC_REAL_IS_HALF
 #define THFile_readRealRaw(file, data, size) \
   { \
-    real *fdata = (real*)THAlloc(sizeof(real)*size); \
-    TH_CONCAT_3(THFile_read,Real,Raw)(file, fdata, size); \
+    real *fdata = (real*)THAlloc(sizeof(real)*size); \
+    TH_CONCAT_3(THFile_read,Real,Raw)(file, fdata, size); \
     THCudaCheck(cudaMemcpy(data, fdata, size * sizeof(real), cudaMemcpyHostToDevice)); \
     THFree(fdata); \
   }
 
 #define THFile_writeRealRaw(file, data, size) \
   { \
-    real *fdata = (real*)THAlloc(sizeof(real)*size); \
+    real *fdata = (real*)THAlloc(sizeof(real)*size); \
     THCudaCheck(cudaMemcpy(fdata, data, size * sizeof(real), cudaMemcpyDeviceToHost)); \
-    TH_CONCAT_3(THFile_write,Real,Raw)(file, fdata, size); \
+    TH_CONCAT_3(THFile_write,Real,Raw)(file, fdata, size); \
+    THFree(fdata); \
+  }
+#else
+#define THFile_readRealRaw(file, data, size) \
+  { \
+    real *fdata = (real*)THAlloc(sizeof(real)*size); \
+    THFile_readCharRaw(file, (char *)fdata, sizeof(real) * size); \
+    THCudaCheck(cudaMemcpy(data, fdata, size * sizeof(real), cudaMemcpyHostToDevice)); \
     THFree(fdata); \
   }
 
+#define THFile_writeRealRaw(file, data, size) \
+  { \
+    real *fdata = (real*)THAlloc(sizeof(real)*size); \
+    THCudaCheck(cudaMemcpy(fdata, data, size * sizeof(real), cudaMemcpyDeviceToHost)); \
+    THFile_writeCharRaw(file, (char *)fdata, size * sizeof(real)); \
+    THFree(fdata); \
+  }
+#endif
+
 #define TH_GENERIC_FILE "generic/Storage.c"
 #include "generic/Storage.c"
@@ -48,6 +66,10 @@ static int cutorch_Storage_(copy)(lua_State *L)
     THCStorage_(copyCudaFloat)(state, storage, src);
   else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleStorage")) )
     THCStorage_(copyCudaDouble)(state, storage, src);
+#if CUDA_VERSION >= 7050
+  else if( (src = luaT_toudata(L, 2, "torch.CudaHalfStorage")) )
+    THCStorage_(copyCudaHalf)(state, storage, src);
+#endif
   else if( (src = luaT_toudata(L, 2, "torch.ByteStorage")) )
     THCStorage_(copyByte)(state, storage, src);
@@ -70,6 +92,7 @@ static int cutorch_Storage_(copy)(lua_State *L)
   return 1;
 }
 
+#ifndef THC_REAL_IS_HALF
 static int TH_CONCAT_3(cutorch_,Real,Storage_copy)(lua_State *L)
 {
   THStorage *storage = luaT_checkudata(L, 1, TH_CONCAT_STRING_3(torch.,Real,Storage));
@@ -104,12 +127,17 @@ static int TH_CONCAT_3(cutorch_,Real,Storage_copy)(lua_State *L)
     THStorage_(copyCudaInt)(cutorch_getstate(L), storage, src);
   else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleStorage")) )
     THStorage_(copyCudaDouble)(cutorch_getstate(L), storage, src);
+#if CUDA_VERSION >= 7050
+  else if( (src = luaT_toudata(L, 2, "torch.CudaHalfStorage")) )
+    THStorage_(copyCudaHalf)(cutorch_getstate(L), storage, src);
+#endif
   else
     luaL_typerror(L, 2, "torch.*Storage");
 
   lua_settop(L, 1);
   return 1;
 }
+#endif
 
 void cutorch_Storage_(init)(lua_State* L)
 {
@@ -118,10 +146,12 @@ void cutorch_Storage_(init)(lua_State* L)
 
   // torch_Storage macro is defined in Storage.c produce the CudaTensor types
   // so I have to construct the normal torch types by hand
+#ifndef THC_REAL_IS_HALF
   luaT_pushmetatable(L, TH_CONCAT_STRING_3(torch.,Real,Storage));
   lua_pushcfunction(L, TH_CONCAT_3(cutorch_,Real,Storage_copy));
   lua_setfield(L, -2, "copy");
   lua_pop(L, 1);
+#endif
 
   luaT_pushmetatable(L, torch_Storage);
   lua_pushcfunction(L, cutorch_Storage_(copy));
diff --git a/generic/CTensor.c b/generic/CTensor.c
index 79a8a48..c81f92c 100644
--- a/generic/CTensor.c
+++ b/generic/CTensor.c
@@ -28,6 +28,10 @@ static int cutorch_Tensor_(copy)(lua_State *L)
     THCTensor_(copyCudaLong)(state, tensor, src);
   else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleTensor")) )
     THCTensor_(copyCudaDouble)(state, tensor, src);
+#if CUDA_VERSION >= 7050
+  else if( (src = luaT_toudata(L, 2, "torch.CudaHalfTensor")) )
+    THCTensor_(copyCudaHalf)(state, tensor, src);
+#endif
   else if( (src = luaT_toudata(L, 2, "torch.ByteTensor")) )
     THCTensor_(copyByte)(state, tensor, src);
@@ -50,6 +54,7 @@ static int cutorch_Tensor_(copy)(lua_State *L)
   return 1;
 }
 
+#ifndef THC_REAL_IS_HALF
 static int cutorch_Tensor_(copyAsyncCPU)(lua_State *L)
 {
 #define STRINGIFY_TENSOR(x) TH_CONCAT_STRING_3(torch.,x,Tensor)
@@ -67,8 +72,10 @@ static int cutorch_Tensor_(copyAsyncCPU)(lua_State *L)
   return 1;
 #undef STRINGIFY_TENSOR
 }
+#endif
 
+#ifndef THC_REAL_IS_HALF
 static int TH_CONCAT_3(cutorch_,Real,Tensor_copy)(lua_State *L)
 {
   THTensor *tensor = luaT_checkudata(L, 1, TH_CONCAT_STRING_3(torch.,Real,Tensor));
@@ -103,13 +110,19 @@ static int TH_CONCAT_3(cutorch_,Real,Tensor_copy)(lua_State *L)
     THTensor_(copyCudaFloat)(cutorch_getstate(L), tensor, src);
   else if( (src = luaT_toudata(L, 2, "torch.CudaDoubleTensor")) )
     THTensor_(copyCudaDouble)(cutorch_getstate(L), tensor, src);
+#if CUDA_VERSION >= 7050
+  else if( (src = luaT_toudata(L, 2, "torch.CudaHalfTensor")) )
+    THTensor_(copyCudaHalf)(cutorch_getstate(L), tensor, src);
+#endif
   else
     luaL_typerror(L, 2, "torch.*Tensor");
 
   lua_settop(L, 1);
   return 1;
 }
+#endif
 
+#ifndef THC_REAL_IS_HALF
 static int TH_CONCAT_3(cutorch_,Real,Tensor_copyAsyncCuda)(lua_State *L)
 {
 #define STRINGIFY_TENSOR(x) TH_CONCAT_STRING_3(torch.,x,Tensor)
@@ -124,6 +137,7 @@ static int TH_CONCAT_3(cutorch_,Real,Tensor_copyAsyncCuda)(lua_State *L)
   return 1;
 #undef STRINGIFY_TENSOR
 }
+#endif
@@ -232,18 +246,12 @@ void cutorch_Tensor_(init)(lua_State* L)
   lua_pop(L, 1);
 #endif
 
-  // torch_Storage macro is defined in Storage.c produce the CudaTensor types
-  // so I have to construct the normal torch types by hand
+#ifndef THC_REAL_IS_HALF
   luaT_pushmetatable(L, TH_CONCAT_STRING_3(torch.,Real,Tensor));
   lua_pushcfunction(L, TH_CONCAT_3(cutorch_,Real,Tensor_copy));
   lua_setfield(L, -2, "copy");
   lua_pop(L, 1);
 
-  luaT_pushmetatable(L, torch_Tensor);
-  lua_pushcfunction(L, cutorch_Tensor_(copy));
-  lua_setfield(L, -2, "copy");
-  lua_pop(L, 1);
-
   // Register async copy methods.
   luaT_pushmetatable(L, TH_CONCAT_STRING_3(torch.,Real,Tensor));
   lua_pushcfunction(L, TH_CONCAT_3(cutorch_,Real,Tensor_copyAsyncCuda));
@@ -254,6 +262,12 @@ void cutorch_Tensor_(init)(lua_State* L)
   lua_pushcfunction(L, cutorch_Tensor_(copyAsyncCPU));
   lua_setfield(L, -2, "copyAsync");
   lua_pop(L, 1);
+#endif
+
+  luaT_pushmetatable(L, torch_Tensor);
+  lua_pushcfunction(L, cutorch_Tensor_(copy));
+  lua_setfield(L, -2, "copy");
+  lua_pop(L, 1);
 
   luaT_pushmetatable(L, torch_Tensor);
   lua_pushcfunction(L, cutorch_Tensor_(getDevice));