Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/torch
diff options
context:
space:
mode:
authorAdam Paszke <adam.paszke@gmail.com>2016-03-07 13:21:14 +0300
committerAdam Paszke <adam.paszke@gmail.com>2016-03-13 19:29:56 +0300
commit8120786180a1d41fca863dbe33fca7e11c988a7e (patch)
tree996c949d03557ccbb5fbe22dbaa559f98ee4d395 /torch
parentd9d7d2f14cda1889d47a8f2623ac8eb40b7bad0b (diff)
Add FP16 support (CudaHalfStorage, CudaHalfTensor)
Diffstat (limited to 'torch')
-rw-r--r--torch/generic/Storage.c22
-rw-r--r--torch/generic/Tensor.c4
2 files changed, 18 insertions, 8 deletions
diff --git a/torch/generic/Storage.c b/torch/generic/Storage.c
index 9036e49..0623fb4 100644
--- a/torch/generic/Storage.c
+++ b/torch/generic/Storage.c
@@ -26,7 +26,7 @@ static int torch_Storage_(new)(lua_State *L)
THCStorage_(free)(state, storage);
luaL_error(L, "element at index %d is not a number", i);
}
- THCStorage_(set)(state, storage, i-1, (real)lua_tonumber(L, -1));
+ THCStorage_(set)(state, storage, i-1, (hostreal)lua_tonumber(L, -1));
lua_pop(L, 1);
}
}
@@ -118,14 +118,14 @@ static int torch_Storage_(fill)(lua_State *L)
{
THCStorage *storage = luaT_checkudata(L, 1, torch_Storage);
double value = luaL_checknumber(L, 2);
- THCStorage_(fill)(cutorch_getstate(L), storage, (real)value);
+ THCStorage_(fill)(cutorch_getstate(L), storage, (hostreal)value);
lua_settop(L, 1);
return 1;
}
static int torch_Storage_(elementSize)(lua_State *L)
{
- lua_pushnumber(L, THStorage_(elementSize)(cutorch_getstate(L)));
+ lua_pushnumber(L, THCStorage_(elementSize)(cutorch_getstate(L)));
return 1;
}
@@ -143,7 +143,7 @@ static int torch_Storage_(__newindex__)(lua_State *L)
THCStorage *storage = luaT_checkudata(L, 1, torch_Storage);
long index = luaL_checklong(L, 2) - 1;
double number = luaL_checknumber(L, 3);
- THCStorage_(set)(cutorch_getstate(L), storage, index, (real)number);
+ THCStorage_(set)(cutorch_getstate(L), storage, index, (hostreal)number);
lua_pushboolean(L, 1);
}
else
@@ -192,12 +192,18 @@ static int torch_Storage_(totable)(lua_State *L)
{
THCState *state = cutorch_getstate(L);
THCStorage *storage = luaT_checkudata(L, 1, torch_Storage);
- THStorage *host_storage;
long i;
/* Copy storage from device to host. */
- host_storage = THStorage_(newWithSize)(THCStorage_(size)(state, storage));
+#ifndef THC_REAL_IS_HALF
+ THStorage *host_storage =
+ THStorage_(newWithSize)(THCStorage_(size)(state, storage));
THStorage_(copyCuda)(state, host_storage, storage);
+#else
+ THFloatStorage *host_storage =
+ THFloatStorage_newWithSize(THCStorage_(size)(state, storage));
+ THFloatStorage_copyCudaHalf(state, host_storage, storage);
+#endif
lua_newtable(L);
for(i = 0; i < storage->size; i++)
@@ -205,7 +211,11 @@ static int torch_Storage_(totable)(lua_State *L)
lua_pushnumber(L, (lua_Number)host_storage->data[i]);
lua_rawseti(L, -2, i+1);
}
+#ifndef THC_REAL_IS_HALF
THStorage_(free)(host_storage);
+#else
+ THFloatStorage_free(host_storage);
+#endif
return 1;
}
diff --git a/torch/generic/Tensor.c b/torch/generic/Tensor.c
index bbed718..a1c5489 100644
--- a/torch/generic/Tensor.c
+++ b/torch/generic/Tensor.c
@@ -27,7 +27,7 @@ static int torch_Tensor_(size)(lua_State *L)
static int torch_Tensor_(elementSize)(lua_State *L)
{
- lua_pushnumber(L, THStorage_(elementSize)(cutorch_getstate(L)));
+ lua_pushnumber(L, THCStorage_(elementSize)(cutorch_getstate(L)));
return 1;
}
@@ -140,7 +140,7 @@ static int torch_Tensor_(new)(lua_State *L)
THCTensor_(free)(state, tensor);
luaL_error(L, "invalid element (not a number)");
}
- THCStorage_(set)(state, THCTensor_(storage)(state, tensor), si++, (real)lua_tonumber(L, -1));
+ THCStorage_(set)(state, THCTensor_(storage)(state, tensor), si++, (hostreal)lua_tonumber(L, -1));
lua_pop(L, 1);
}