Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2016-11-09 00:48:48 +0300
committerGregory Chanan <gchanan@fb.com>2016-11-09 00:49:30 +0300
commit27479c372040b8cab4e53e9338e8ce840bdb67dd (patch)
treef0a89adfb00e7a49031ca32a9badcf016bc599cb /THCUNN.lua
parent604d6fffa9913cbcffb2cb32a1660a5cc4e893ab (diff)
Remove non-generic support for modules.
Diffstat (limited to 'THCUNN.lua')
-rw-r--r--THCUNN.lua33
1 files changed, 1 insertions, 32 deletions
diff --git a/THCUNN.lua b/THCUNN.lua
index 3dc0186..490cd5c 100644
--- a/THCUNN.lua
+++ b/THCUNN.lua
@@ -12,13 +12,6 @@ function THCUNN.getState()
return THCState_ptr(cutorch.getState());
end
-local THCUNN_h = require 'cunn.THCUNN_h'
--- strip all lines starting with #
--- to remove preprocessor directives originally present
--- in THNN.h
-THCUNN_h = THCUNN_h:gsub("\n#[^\n]*", "")
-THCUNN_h = THCUNN_h:gsub("^#[^\n]*\n", "")
-
local THCUNN_generic_h = require 'cunn.THCUNN_generic_h'
-- strip all lines starting with #
-- to remove preprocessor directives originally present
@@ -26,7 +19,6 @@ local THCUNN_generic_h = require 'cunn.THCUNN_generic_h'
THCUNN_generic_h = THCUNN_generic_h:gsub("\n#[^\n]*", "")
THCUNN_generic_h = THCUNN_generic_h:gsub("^#[^\n]*\n", "")
-local preprocessed = string.gsub(THCUNN_h, 'TH_API ', '')
local preprocessed_generic = string.gsub(THCUNN_generic_h, 'TH_API void THNN_%(([%a%d_]+)%)', 'void THNN_TYPE%1')
local replacements =
@@ -73,15 +65,6 @@ if cutorch.hasHalf then
table.insert(replacements_generic, half_replacement)
end
-for i=1,#replacements do
- local r = replacements[i]
- local s = preprocessed
- for k,v in pairs(r) do
- s = string.gsub(s, k, v)
- end
- ffi.cdef(s)
-end
-
for i=1,#replacements_generic do
local r = replacements_generic[i]
local s = preprocessed_generic
@@ -91,14 +74,6 @@ for i=1,#replacements_generic do
ffi.cdef(s)
end
-local function extract_function_names(s)
- local t = {}
- for n in string.gmatch(s, 'TH_API void THNN_Cuda([%a%d_]+)') do
- t[#t+1] = n
- end
- return t
-end
-
local function extract_function_names_generic(s)
local t = {}
for n in string.gmatch(s, 'TH_API void THNN_%(([%a%d_]+)%)') do
@@ -146,15 +121,9 @@ end
local real_args = extract_function_names_and_real_args(THCUNN_generic_h)
-- build function table
-local function_names = extract_function_names(THCUNN_h)
local function_names_generic = extract_function_names_generic(THCUNN_generic_h)
--- combine function names for CudaTensor
-for k,v in pairs(real_args) do
- function_names[#function_names+1] = k
-end
-
-THNN.kernels['torch.CudaTensor'] = THNN.bind(THCUNN.C, function_names, 'Cuda', THCUNN.getState)
+THNN.kernels['torch.CudaTensor'] = THNN.bind(THCUNN.C, function_names_generic, 'Cuda', THCUNN.getState)
torch.getmetatable('torch.CudaTensor').THNN = THNN.kernels['torch.CudaTensor']
THNN.kernels['torch.CudaDoubleTensor'] = THNN.bind(THCUNN.C, function_names_generic, 'CudaDouble', THCUNN.getState)