diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-03-28 03:53:39 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-03-28 03:53:39 +0300 |
commit | 7d2a5d54736c6f5f52f04e4f86354c3cde3ee103 (patch) | |
tree | a855acf36861eed45f81f18e0d0e5f52c97a43b5 /torch | |
parent | 07faade60be32c5b860b9e53c0bf1d16b0ac0d5b (diff) | |
parent | 8120786180a1d41fca863dbe33fca7e11c988a7e (diff) |
Merge pull request #355 from apaszke/fp16
Add FP16 support (CudaHalfStorage, CudaHalfTensor)
Diffstat (limited to 'torch')
-rw-r--r-- | torch/generic/Storage.c | 22 | ||||
-rw-r--r-- | torch/generic/Tensor.c | 4 |
2 files changed, 18 insertions, 8 deletions
diff --git a/torch/generic/Storage.c b/torch/generic/Storage.c index 2dfe434..8a82e42 100644 --- a/torch/generic/Storage.c +++ b/torch/generic/Storage.c @@ -26,7 +26,7 @@ static int torch_Storage_(new)(lua_State *L) THCStorage_(free)(state, storage); luaL_error(L, "element at index %d is not a number", i); } - THCStorage_(set)(state, storage, i-1, (real)lua_tonumber(L, -1)); + THCStorage_(set)(state, storage, i-1, (hostreal)lua_tonumber(L, -1)); lua_pop(L, 1); } } @@ -118,14 +118,14 @@ static int torch_Storage_(fill)(lua_State *L) { THCStorage *storage = luaT_checkudata(L, 1, torch_Storage); double value = luaL_checknumber(L, 2); - THCStorage_(fill)(cutorch_getstate(L), storage, (real)value); + THCStorage_(fill)(cutorch_getstate(L), storage, (hostreal)value); lua_settop(L, 1); return 1; } static int torch_Storage_(elementSize)(lua_State *L) { - lua_pushnumber(L, THStorage_(elementSize)(cutorch_getstate(L))); + lua_pushnumber(L, THCStorage_(elementSize)(cutorch_getstate(L))); return 1; } @@ -143,7 +143,7 @@ static int torch_Storage_(__newindex__)(lua_State *L) THCStorage *storage = luaT_checkudata(L, 1, torch_Storage); long index = luaL_checklong(L, 2) - 1; double number = luaL_checknumber(L, 3); - THCStorage_(set)(cutorch_getstate(L), storage, index, (real)number); + THCStorage_(set)(cutorch_getstate(L), storage, index, (hostreal)number); lua_pushboolean(L, 1); } else @@ -173,12 +173,18 @@ static int torch_Storage_(totable)(lua_State *L) { THCState *state = cutorch_getstate(L); THCStorage *storage = luaT_checkudata(L, 1, torch_Storage); - THStorage *host_storage; long i; /* Copy storage from device to host. */ - host_storage = THStorage_(newWithSize)(THCStorage_(size)(state, storage)); +#ifndef THC_REAL_IS_HALF + THStorage *host_storage = + THStorage_(newWithSize)(THCStorage_(size)(state, storage)); THStorage_(copyCuda)(state, host_storage, storage); +#else + THFloatStorage *host_storage = + THFloatStorage_newWithSize(THCStorage_(size)(state, storage)); + THFloatStorage_copyCudaHalf(state, host_storage, storage); +#endif lua_newtable(L); for(i = 0; i < storage->size; i++) @@ -186,7 +192,11 @@ static int torch_Storage_(totable)(lua_State *L) lua_pushnumber(L, (lua_Number)host_storage->data[i]); lua_rawseti(L, -2, i+1); } +#ifndef THC_REAL_IS_HALF THStorage_(free)(host_storage); +#else + THFloatStorage_free(host_storage); +#endif return 1; } diff --git a/torch/generic/Tensor.c b/torch/generic/Tensor.c index bbed718..a1c5489 100644 --- a/torch/generic/Tensor.c +++ b/torch/generic/Tensor.c @@ -27,7 +27,7 @@ static int torch_Tensor_(size)(lua_State *L) static int torch_Tensor_(elementSize)(lua_State *L) { - lua_pushnumber(L, THStorage_(elementSize)(cutorch_getstate(L))); + lua_pushnumber(L, THCStorage_(elementSize)(cutorch_getstate(L))); return 1; } @@ -140,7 +140,7 @@ static int torch_Tensor_(new)(lua_State *L) THCTensor_(free)(state, tensor); luaL_error(L, "invalid element (not a number)"); } - THCStorage_(set)(state, THCTensor_(storage)(state, tensor), si++, (real)lua_tonumber(L, -1)); + THCStorage_(set)(state, THCTensor_(storage)(state, tensor), si++, (hostreal)lua_tonumber(L, -1)); lua_pop(L, 1); } |