diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-01-26 00:13:20 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-01-26 00:13:20 +0300 |
commit | 0d85922d116879448485ef88ae21e83a9255a0b0 (patch) | |
tree | 87449f65566e7c1b6d68e6b6671bb3d18083c600 /THCUNN.lua | |
parent | 87223032d716826207c97bdac72ccc269225790d (diff) |
Revert "Convert real to accreal in libTHCUNN"revert-416-half-fixes
Diffstat (limited to 'THCUNN.lua')
-rw-r--r-- | THCUNN.lua | 40 |
1 files changed, 25 insertions, 15 deletions
@@ -45,7 +45,7 @@ local replacements_generic = ['THCTensor'] = 'THCudaTensor', ['THCIndexTensor'] = 'THCudaLongTensor', ['TYPE'] = 'Cuda', - ['real'] = 'float', + ['real'] = 'float' }, { ['THCTensor'] = 'THCudaDoubleTensor', @@ -55,13 +55,6 @@ local replacements_generic = } } --- gsub(s, 'real', 'float') changes accreal to accfloat. --- typedef accfloat ahead of time. -ffi.cdef("typedef float accfloat;") --- gsub(s, 'real', 'double') changes accreal to accfloat. --- typedef accdouble ahead of time -ffi.cdef("typedef double accdouble;") - if cutorch.hasHalf then ffi.cdef("half THC_float2half(float a);") ffi.cdef("float THC_half2float(half a);") @@ -70,12 +63,9 @@ if cutorch.hasHalf then ['THCTensor'] = 'THCudaHalfTensor', ['THCIndexTensor'] = 'THCudaLongTensor', ['TYPE'] = 'CudaHalf', - ['real'] = 'half', + ['real'] = 'half' } table.insert(replacements_generic, half_replacement) - -- gsub(s, 'real', 'double') changes accreal to accfloat. - -- typedef acchalf ahead of time - ffi.cdef("typedef float acchalf;") end for i=1,#replacements_generic do @@ -143,9 +133,29 @@ THNN.kernels['torch.CudaDoubleTensor'] = THNN.bind(THCUNN.C, function_names_gene torch.getmetatable('torch.CudaDoubleTensor').THNN = THNN.kernels['torch.CudaDoubleTensor'] if cutorch.hasHalf then - local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState) - THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions - torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor'] +-- in order to call 'half' functions from lua, convert real arguments from +-- to half since there is no other defined conversion +local transform_reals_to_half = function(func_name, real_args, ...) + t = {} + -- this select logic is necessary to deal with nil arguments + for i = 1, select('#', ...) do + t[i] = select(i, ...) + end + for k,v in ipairs(real_args[func_name]) do + -- first argument (THCState) is added implicitly by bind + t[v-1] = THC.THC_float2half(t[v-1]) + end + return t +end + +local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState) +for k,v in pairs(raw_half_functions) do + -- select required in case there are trailing nils + raw_half_functions[k] = function(...) v(unpack(transform_reals_to_half(k, real_args, ...), 1, select("#",...))) +end +end +THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions +torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor'] end local function Module__converter(type) |