diff options
author | Pavan Yalamanchili <pyalamanchili@twitter.com> | 2017-02-17 04:25:33 +0300 |
---|---|---|
committer | Pavan Yalamanchili <pyalamanchili@twitter.com> | 2017-02-17 04:33:03 +0300 |
commit | 3996dbb87ec79d087c37bc6f4fe8f23a3767c88c (patch) | |
tree | d79dced267c31b5250930087ebd0e9fee21be340 /THCUNN.lua | |
parent | 618f847d94ad65baef1c1614ed241d6e4bea7151 (diff) |
Convert real to accreal in libTHCUNN
- This reverts commit 0d85922d116879448485ef88ae21e83a9255a0b0.
- Includes fixes for TemporalRowConvolution
Diffstat (limited to 'THCUNN.lua')
-rw-r--r-- | THCUNN.lua | 40 |
1 files changed, 15 insertions, 25 deletions
@@ -45,7 +45,7 @@ local replacements_generic = ['THCTensor'] = 'THCudaTensor', ['THCIndexTensor'] = 'THCudaLongTensor', ['TYPE'] = 'Cuda', - ['real'] = 'float' + ['real'] = 'float', }, { ['THCTensor'] = 'THCudaDoubleTensor', @@ -55,6 +55,13 @@ local replacements_generic = } } +-- gsub(s, 'real', 'float') changes accreal to accfloat. +-- typedef accfloat ahead of time. +ffi.cdef("typedef float accfloat;") +-- gsub(s, 'real', 'double') changes accreal to accfloat. +-- typedef accdouble ahead of time +ffi.cdef("typedef double accdouble;") + if cutorch.hasHalf then ffi.cdef("half THC_float2half(float a);") ffi.cdef("float THC_half2float(half a);") @@ -63,9 +70,12 @@ if cutorch.hasHalf then ['THCTensor'] = 'THCudaHalfTensor', ['THCIndexTensor'] = 'THCudaLongTensor', ['TYPE'] = 'CudaHalf', - ['real'] = 'half' + ['real'] = 'half', } table.insert(replacements_generic, half_replacement) + -- gsub(s, 'real', 'double') changes accreal to accfloat. + -- typedef acchalf ahead of time + ffi.cdef("typedef float acchalf;") end for i=1,#replacements_generic do @@ -133,29 +143,9 @@ THNN.kernels['torch.CudaDoubleTensor'] = THNN.bind(THCUNN.C, function_names_gene torch.getmetatable('torch.CudaDoubleTensor').THNN = THNN.kernels['torch.CudaDoubleTensor'] if cutorch.hasHalf then --- in order to call 'half' functions from lua, convert real arguments from --- to half since there is no other defined conversion -local transform_reals_to_half = function(func_name, real_args, ...) - local t = {} - -- this select logic is necessary to deal with nil arguments - for i = 1, select('#', ...) do - t[i] = select(i, ...) - end - for k,v in ipairs(real_args[func_name]) do - -- first argument (THCState) is added implicitly by bind - t[v-1] = THC.THC_float2half(t[v-1]) - end - return t -end - -local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState) -for k,v in pairs(raw_half_functions) do - -- select required in case there are trailing nils - raw_half_functions[k] = function(...) v(unpack(transform_reals_to_half(k, real_args, ...), 1, select("#",...))) -end -end -THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions -torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor'] + local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState) + THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions + torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor'] end local function Module__converter(type) |