Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPavan Yalamanchili <pyalamanchili@twitter.com>2017-02-17 04:25:33 +0300
committerPavan Yalamanchili <pyalamanchili@twitter.com>2017-02-17 04:33:03 +0300
commit3996dbb87ec79d087c37bc6f4fe8f23a3767c88c (patch)
treed79dced267c31b5250930087ebd0e9fee21be340 /THCUNN.lua
parent618f847d94ad65baef1c1614ed241d6e4bea7151 (diff)
Convert real to accreal in libTHCUNN
- This reverts commit 0d85922d116879448485ef88ae21e83a9255a0b0. - Includes fixes for TemporalRowConvolution
Diffstat (limited to 'THCUNN.lua')
-rw-r--r--THCUNN.lua40
1 files changed, 15 insertions, 25 deletions
diff --git a/THCUNN.lua b/THCUNN.lua
index 573690b..d5bf1c2 100644
--- a/THCUNN.lua
+++ b/THCUNN.lua
@@ -45,7 +45,7 @@ local replacements_generic =
['THCTensor'] = 'THCudaTensor',
['THCIndexTensor'] = 'THCudaLongTensor',
['TYPE'] = 'Cuda',
- ['real'] = 'float'
+ ['real'] = 'float',
},
{
['THCTensor'] = 'THCudaDoubleTensor',
@@ -55,6 +55,13 @@ local replacements_generic =
}
}
+-- gsub(s, 'real', 'float') changes accreal to accfloat.
+-- typedef accfloat ahead of time.
+ffi.cdef("typedef float accfloat;")
+-- gsub(s, 'real', 'double') changes accreal to accfloat.
+-- typedef accdouble ahead of time
+ffi.cdef("typedef double accdouble;")
+
if cutorch.hasHalf then
ffi.cdef("half THC_float2half(float a);")
ffi.cdef("float THC_half2float(half a);")
@@ -63,9 +70,12 @@ if cutorch.hasHalf then
['THCTensor'] = 'THCudaHalfTensor',
['THCIndexTensor'] = 'THCudaLongTensor',
['TYPE'] = 'CudaHalf',
- ['real'] = 'half'
+ ['real'] = 'half',
}
table.insert(replacements_generic, half_replacement)
+ -- gsub(s, 'real', 'double') changes accreal to accfloat.
+ -- typedef acchalf ahead of time
+ ffi.cdef("typedef float acchalf;")
end
for i=1,#replacements_generic do
@@ -133,29 +143,9 @@ THNN.kernels['torch.CudaDoubleTensor'] = THNN.bind(THCUNN.C, function_names_gene
torch.getmetatable('torch.CudaDoubleTensor').THNN = THNN.kernels['torch.CudaDoubleTensor']
if cutorch.hasHalf then
--- in order to call 'half' functions from lua, convert real arguments from
--- to half since there is no other defined conversion
-local transform_reals_to_half = function(func_name, real_args, ...)
- local t = {}
- -- this select logic is necessary to deal with nil arguments
- for i = 1, select('#', ...) do
- t[i] = select(i, ...)
- end
- for k,v in ipairs(real_args[func_name]) do
- -- first argument (THCState) is added implicitly by bind
- t[v-1] = THC.THC_float2half(t[v-1])
- end
- return t
-end
-
-local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState)
-for k,v in pairs(raw_half_functions) do
- -- select required in case there are trailing nils
- raw_half_functions[k] = function(...) v(unpack(transform_reals_to_half(k, real_args, ...), 1, select("#",...)))
-end
-end
-THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions
-torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor']
+ local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState)
+ THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions
+ torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor']
end
local function Module__converter(type)