diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-06-11 03:11:10 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-06-11 03:11:10 +0300 |
commit | bef86f7dd0158b58f01f8219df0740174acb9b0c (patch) | |
tree | 9a23586388c75401faca1c2ae6e81dcdc3b79dc8 | |
parent | 20e0319f910e50e85f48756aa4fee2e15925bb0c (diff) | |
parent | dc4a00d34d95e70c559a2df76f20682ce90578e6 (diff) |
Merge pull request #427 from torch/halfmath
add half cwrap type and enable math for CudaHalfTensor
-rw-r--r-- | TensorMath.lua | 120 | ||||
-rw-r--r-- | init.c | 7 |
2 files changed, 111 insertions, 16 deletions
diff --git a/TensorMath.lua b/TensorMath.lua index 579f27c..17d7547 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -29,12 +29,14 @@ local unpack = unpack or table.unpack -- specific to CUDA local typenames = {'CudaByteTensor', - 'CudaCharTensor', - 'CudaShortTensor', - 'CudaIntTensor', - 'CudaLongTensor', - 'CudaTensor', - 'CudaDoubleTensor'} + 'CudaCharTensor', + 'CudaShortTensor', + 'CudaIntTensor', + 'CudaLongTensor', + 'CudaTensor', + 'CudaDoubleTensor', + 'CudaHalfTensor' +} for _, typename in ipairs(typenames) do -- cut and paste from wrap/types.lua @@ -190,6 +192,81 @@ wrap.types[typename .. 'Array'] = { } end +local function interpretdefaultvalue(arg) + local default = arg.default + if type(default) == 'boolean' then + if default then + return '1' + else + return '0' + end + elseif type(default) == 'number' then + return tostring(default) + elseif type(default) == 'string' then + return default + elseif type(default) == 'function' then + default = default(arg) + assert(type(default) == 'string', 'a default function must return a string') + return default + elseif type(default) == 'nil' then + return nil + else + error('unknown default type value') + end +end + +wrap.types.half = { + + helpname = function(arg) + return "half" + end, + + declare = function(arg) + -- if it is a number we initialize here + local default = tonumber(interpretdefaultvalue(arg)) or 0 + return string.format("half arg%d = THC_float2half((float) %d);", arg.i, tonumber(default)) + end, + + check = function(arg, idx) + return string.format("lua_isnumber(L, %d)", idx) + end, + + read = function(arg, idx) + return string.format("arg%d = THC_float2half((float) lua_tonumber(L, %d));", arg.i, idx) + end, + + init = function(arg) + -- otherwise do it here + if arg.default then + local default = interpretdefaultvalue(arg) + if not tonumber(default) then + return string.format("arg%d = THC_float2half((float) %s);", arg.i, default) + end + end + end, + + carg = function(arg) + return string.format('arg%d', arg.i) + end, + + creturn = function(arg) + return string.format('arg%d', arg.i) + end, + + precall = function(arg) + if arg.returned then + return string.format('lua_pushnumber(L, (lua_Number) THC_half2float(arg%d));', arg.i) + end + end, + + postcall = function(arg) + if arg.creturned then + return string.format('lua_pushnumber(L, (lua_Number) THC_half2float(arg%d));', arg.i) + end + end + +} + wrap.types.LongArg = { vararg = true, @@ -338,19 +415,26 @@ end -- local handledTypenames = {'CudaByteTensor', - 'CudaCharTensor', - 'CudaShortTensor', - 'CudaIntTensor', - 'CudaLongTensor', - 'CudaDoubleTensor'} + 'CudaCharTensor', + 'CudaShortTensor', + 'CudaIntTensor', + 'CudaLongTensor', + 'CudaDoubleTensor', + 'CudaHalfTensor', +} local handledTypereals = {'unsigned char', - 'char', - 'short', - 'int', - 'long', - 'double'} + 'char', + 'short', + 'int', + 'long', + 'double', + 'half' +} for k, Tensor in pairs(handledTypenames) do + if Tensor == 'CudaHalfTensor' then + interface:print("#ifdef CUDA_HALF_TENSOR") + end local real = handledTypereals[k] function interface.luaname2wrapname(self, name) @@ -461,6 +545,10 @@ void cutorch_%sMath_init(lua_State *L) lua_pop(L, 1); } ]], Tensor, Tensor, Tensor, Tensor)) + + if Tensor == 'CudaHalfTensor' then + interface:print("#endif") + end end @@ -35,6 +35,10 @@ extern void cutorch_CudaIntTensorMath_init(lua_State* L); extern void cutorch_CudaLongTensorMath_init(lua_State* L); extern void cutorch_CudaTensorMath_init(lua_State* L); extern void cutorch_CudaDoubleTensorMath_init(lua_State* L); +#ifdef CUDA_HALF_TENSOR +extern void cutorch_CudaHalfTensorMath_init(lua_State* L); +#endif + /* Iteration utilities for lists of streams and lists of gpus with streams @@ -985,6 +989,9 @@ int luaopen_libcutorch(lua_State *L) cutorch_CudaLongTensorMath_init(L); cutorch_CudaTensorMath_init(L); cutorch_CudaDoubleTensorMath_init(L); +#ifdef CUDA_HALF_TENSOR + cutorch_CudaHalfTensorMath_init(L); +#endif cutorch_Event_init(L); |