github.com/torch/cunn.git
author    Soumith Chintala <soumith@gmail.com>    2017-01-26 00:13:20 +0300
committer GitHub <noreply@github.com>             2017-01-26 00:13:20 +0300
commit    0d85922d116879448485ef88ae21e83a9255a0b0 (patch)
tree      87449f65566e7c1b6d68e6b6671bb3d18083c600 /THCUNN.lua
parent    87223032d716826207c97bdac72ccc269225790d (diff)
Revert "Convert real to accreal in libTHCUNN" (revert-416-half-fixes)
Diffstat (limited to 'THCUNN.lua')
-rw-r--r--  THCUNN.lua  40
1 file changed, 25 insertions, 15 deletions
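
The hunks below revert THCUNN.lua's type-substitution tables and the acc-typedef workaround they rely on. The following is a minimal, self-contained Lua sketch of that mechanism, not the library's own code: the two-line cdef template is invented for illustration, and the CudaDouble entries are inferred from the surrounding diff. It shows how a generic declaration is specialized per concrete CUDA type with gsub, and why replacing 'real' with 'float' also turns 'accreal' into 'accfloat', which is what the removed typedef lines compensated for.

-- Illustrative template only; real declarations come from the generic THCUNN header.
local generic_cdef = [[
void THNN_TYPEAbs_updateOutput(void *state, THCTensor *input, THCTensor *output);
void THNN_TYPESoftPlus_updateOutput(void *state, THCTensor *input, THCTensor *output, accreal beta, accreal threshold);
]]

-- replacement tables mirroring replacements_generic in the diff
local replacements = {
   {
      ['THCTensor'] = 'THCudaTensor',
      ['TYPE'] = 'Cuda',
      ['real'] = 'float',
   },
   {
      ['THCTensor'] = 'THCudaDoubleTensor',
      ['TYPE'] = 'CudaDouble',
      ['real'] = 'double',
   },
}

for _, map in ipairs(replacements) do
   local s = generic_cdef
   for pattern, value in pairs(map) do
      s = s:gsub(pattern, value)
   end
   -- replacing 'real' also rewrites 'accreal' into 'accfloat' / 'accdouble',
   -- which is what the removed "typedef float accfloat;" lines prepared for
   print(s)
   -- THCUNN.lua hands strings like this to LuaJIT via ffi.cdef(s)
end

Run under LuaJIT, this prints the Cuda and CudaDouble variants of the two declarations; the library registers the same kind of specialized string with ffi.cdef and then binds the resulting symbols.
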
diff --git a/THCUNN.lua b/THCUNN.lua
index d5bf1c2..6776a23 100644
--- a/THCUNN.lua
+++ b/THCUNN.lua
@@ -45,7 +45,7 @@ local replacements_generic =
['THCTensor'] = 'THCudaTensor',
['THCIndexTensor'] = 'THCudaLongTensor',
['TYPE'] = 'Cuda',
- ['real'] = 'float',
+ ['real'] = 'float'
},
{
['THCTensor'] = 'THCudaDoubleTensor',
@@ -55,13 +55,6 @@ local replacements_generic =
}
}
--- gsub(s, 'real', 'float') changes accreal to accfloat.
--- typedef accfloat ahead of time.
-ffi.cdef("typedef float accfloat;")
--- gsub(s, 'real', 'double') changes accreal to accfloat.
--- typedef accdouble ahead of time
-ffi.cdef("typedef double accdouble;")
-
if cutorch.hasHalf then
ffi.cdef("half THC_float2half(float a);")
ffi.cdef("float THC_half2float(half a);")
@@ -70,12 +63,9 @@ if cutorch.hasHalf then
['THCTensor'] = 'THCudaHalfTensor',
['THCIndexTensor'] = 'THCudaLongTensor',
['TYPE'] = 'CudaHalf',
- ['real'] = 'half',
+ ['real'] = 'half'
}
table.insert(replacements_generic, half_replacement)
- -- gsub(s, 'real', 'double') changes accreal to accfloat.
- -- typedef acchalf ahead of time
- ffi.cdef("typedef float acchalf;")
end
for i=1,#replacements_generic do
@@ -143,9 +133,29 @@ THNN.kernels['torch.CudaDoubleTensor'] = THNN.bind(THCUNN.C, function_names_gene
torch.getmetatable('torch.CudaDoubleTensor').THNN = THNN.kernels['torch.CudaDoubleTensor']
if cutorch.hasHalf then
- local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState)
- THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions
- torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor']
+-- in order to call 'half' functions from lua, convert real arguments from
+-- to half since there is no other defined conversion
+local transform_reals_to_half = function(func_name, real_args, ...)
+ t = {}
+ -- this select logic is necessary to deal with nil arguments
+ for i = 1, select('#', ...) do
+ t[i] = select(i, ...)
+ end
+ for k,v in ipairs(real_args[func_name]) do
+ -- first argument (THCState) is added implicitly by bind
+ t[v-1] = THC.THC_float2half(t[v-1])
+ end
+ return t
+end
+
+local raw_half_functions = THNN.bind(THCUNN.C, function_names_generic, 'CudaHalf', THCUNN.getState)
+for k,v in pairs(raw_half_functions) do
+ -- select required in case there are trailing nils
+ raw_half_functions[k] = function(...) v(unpack(transform_reals_to_half(k, real_args, ...), 1, select("#",...)))
+end
+end
+THNN.kernels['torch.CudaHalfTensor'] = raw_half_functions
+torch.getmetatable('torch.CudaHalfTensor').THNN = THNN.kernels['torch.CudaHalfTensor']
end
local function Module__converter(type)
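
The largest hunk restores the wrapper that converts float ('real') arguments to half before a CudaHalf kernel is invoked, using the per-function argument positions recorded in real_args. Below is a simplified, self-contained sketch of that wrapping pattern; the function name SoftPlus_updateOutput, the positions {4, 5}, and fake_float2half are stand-ins invented for the example, whereas the real code calls THC.THC_float2half and wraps the FFI functions returned by THNN.bind.

local unpack = unpack or table.unpack   -- LuaJIT / Lua 5.2+ compatibility

-- stand-in for THC.THC_float2half: wrap the float so the conversion is visible
local function fake_float2half(x) return { half = x } end

-- per-function positions of 'real' arguments, counting the implicit THCState
-- as argument 1; in THCUNN.lua this table is built while parsing the header
local real_args = { SoftPlus_updateOutput = { 4, 5 } }

local function transform_reals_to_half(func_name, real_args, ...)
   local t = {}
   -- select() copies every argument, even when some of them are nil
   for i = 1, select('#', ...) do t[i] = select(i, ...) end
   for _, pos in ipairs(real_args[func_name] or {}) do
      -- argument 1 (THCState) is added implicitly by bind, hence pos - 1
      t[pos - 1] = fake_float2half(t[pos - 1])
   end
   return t
end

-- a fake raw kernel standing in for an FFI function returned by THNN.bind
local raw = {
   SoftPlus_updateOutput = function(input, output, beta, threshold)
      print(type(beta), type(threshold))   -- both arrive already converted
   end,
}

-- wrap every entry point the same way the restored loop does
for name, fn in pairs(raw) do
   raw[name] = function(...)
      return fn(unpack(transform_reals_to_half(name, real_args, ...),
                       1, select('#', ...)))
   end
end

raw.SoftPlus_updateOutput({}, {}, 1.0, 20.0)   -- prints: table   table

The select('#', ...) bookkeeping mirrors the restored code: it preserves the exact argument count so that trailing nil arguments are neither dropped nor padded when the wrapped kernel is called.
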