Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cutorch.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-06-11 03:11:10 +0300
committerGitHub <noreply@github.com>2016-06-11 03:11:10 +0300
commitbef86f7dd0158b58f01f8219df0740174acb9b0c (patch)
tree9a23586388c75401faca1c2ae6e81dcdc3b79dc8
parent20e0319f910e50e85f48756aa4fee2e15925bb0c (diff)
parentdc4a00d34d95e70c559a2df76f20682ce90578e6 (diff)
Merge pull request #427 from torch/halfmath
add half cwrap type and enable math for CudaHalfTensor
-rw-r--r--TensorMath.lua120
-rw-r--r--init.c7
2 files changed, 111 insertions, 16 deletions
diff --git a/TensorMath.lua b/TensorMath.lua
index 579f27c..17d7547 100644
--- a/TensorMath.lua
+++ b/TensorMath.lua
@@ -29,12 +29,14 @@ local unpack = unpack or table.unpack
-- specific to CUDA
local typenames = {'CudaByteTensor',
- 'CudaCharTensor',
- 'CudaShortTensor',
- 'CudaIntTensor',
- 'CudaLongTensor',
- 'CudaTensor',
- 'CudaDoubleTensor'}
+ 'CudaCharTensor',
+ 'CudaShortTensor',
+ 'CudaIntTensor',
+ 'CudaLongTensor',
+ 'CudaTensor',
+ 'CudaDoubleTensor',
+ 'CudaHalfTensor'
+}
for _, typename in ipairs(typenames) do
-- cut and paste from wrap/types.lua
@@ -190,6 +192,81 @@ wrap.types[typename .. 'Array'] = {
}
end
+local function interpretdefaultvalue(arg)
+ local default = arg.default
+ if type(default) == 'boolean' then
+ if default then
+ return '1'
+ else
+ return '0'
+ end
+ elseif type(default) == 'number' then
+ return tostring(default)
+ elseif type(default) == 'string' then
+ return default
+ elseif type(default) == 'function' then
+ default = default(arg)
+ assert(type(default) == 'string', 'a default function must return a string')
+ return default
+ elseif type(default) == 'nil' then
+ return nil
+ else
+ error('unknown default type value')
+ end
+end
+
+wrap.types.half = {
+
+ helpname = function(arg)
+ return "half"
+ end,
+
+ declare = function(arg)
+ -- if it is a number we initialize here
+ local default = tonumber(interpretdefaultvalue(arg)) or 0
+ return string.format("half arg%d = THC_float2half((float) %d);", arg.i, tonumber(default))
+ end,
+
+ check = function(arg, idx)
+ return string.format("lua_isnumber(L, %d)", idx)
+ end,
+
+ read = function(arg, idx)
+ return string.format("arg%d = THC_float2half((float) lua_tonumber(L, %d));", arg.i, idx)
+ end,
+
+ init = function(arg)
+ -- otherwise do it here
+ if arg.default then
+ local default = interpretdefaultvalue(arg)
+ if not tonumber(default) then
+ return string.format("arg%d = THC_float2half((float) %s);", arg.i, default)
+ end
+ end
+ end,
+
+ carg = function(arg)
+ return string.format('arg%d', arg.i)
+ end,
+
+ creturn = function(arg)
+ return string.format('arg%d', arg.i)
+ end,
+
+ precall = function(arg)
+ if arg.returned then
+ return string.format('lua_pushnumber(L, (lua_Number) THC_half2float(arg%d));', arg.i)
+ end
+ end,
+
+ postcall = function(arg)
+ if arg.creturned then
+ return string.format('lua_pushnumber(L, (lua_Number) THC_half2float(arg%d));', arg.i)
+ end
+ end
+
+}
+
wrap.types.LongArg = {
vararg = true,
@@ -338,19 +415,26 @@ end
--
local handledTypenames = {'CudaByteTensor',
- 'CudaCharTensor',
- 'CudaShortTensor',
- 'CudaIntTensor',
- 'CudaLongTensor',
- 'CudaDoubleTensor'}
+ 'CudaCharTensor',
+ 'CudaShortTensor',
+ 'CudaIntTensor',
+ 'CudaLongTensor',
+ 'CudaDoubleTensor',
+ 'CudaHalfTensor',
+}
local handledTypereals = {'unsigned char',
- 'char',
- 'short',
- 'int',
- 'long',
- 'double'}
+ 'char',
+ 'short',
+ 'int',
+ 'long',
+ 'double',
+ 'half'
+}
for k, Tensor in pairs(handledTypenames) do
+ if Tensor == 'CudaHalfTensor' then
+ interface:print("#ifdef CUDA_HALF_TENSOR")
+ end
local real = handledTypereals[k]
function interface.luaname2wrapname(self, name)
@@ -461,6 +545,10 @@ void cutorch_%sMath_init(lua_State *L)
lua_pop(L, 1);
}
]], Tensor, Tensor, Tensor, Tensor))
+
+ if Tensor == 'CudaHalfTensor' then
+ interface:print("#endif")
+ end
end
diff --git a/init.c b/init.c
index de365f4..9352ef7 100644
--- a/init.c
+++ b/init.c
@@ -35,6 +35,10 @@ extern void cutorch_CudaIntTensorMath_init(lua_State* L);
extern void cutorch_CudaLongTensorMath_init(lua_State* L);
extern void cutorch_CudaTensorMath_init(lua_State* L);
extern void cutorch_CudaDoubleTensorMath_init(lua_State* L);
+#ifdef CUDA_HALF_TENSOR
+extern void cutorch_CudaHalfTensorMath_init(lua_State* L);
+#endif
+
/*
Iteration utilities for lists of streams and lists of gpus with streams
@@ -985,6 +989,9 @@ int luaopen_libcutorch(lua_State *L)
cutorch_CudaLongTensorMath_init(L);
cutorch_CudaTensorMath_init(L);
cutorch_CudaDoubleTensorMath_init(L);
+#ifdef CUDA_HALF_TENSOR
+ cutorch_CudaHalfTensorMath_init(L);
+#endif
cutorch_Event_init(L);