--- FFI bindings for the THCUNN (CUDA) kernels.
-- Loading this module:
--   * loads libTHCUNN through the LuaJIT FFI,
--   * declares the generic THCUNN C API once per supported CUDA tensor type,
--   * registers the bound kernel tables in `THNN.kernels` and on the
--     corresponding tensor metatables,
--   * adds `cudaDouble` (and `cudaHalf` when available) to `nn.Module`.
-- NOTE(review): this file reads the globals `cutorch` and `torch` without
-- requiring them — confirm the caller loads them first.

local ffi = require 'ffi'
local THNN = require 'nn.THNN'

local THCUNN = {}

-- load libTHCUNN
-- package.searchpath returns nil plus a message when nothing is found; report
-- that message instead of handing nil to ffi.load.
local lib_path, search_err = package.searchpath('libTHCUNN', package.cpath)
if not lib_path then
   error('could not locate libTHCUNN on package.cpath:' .. tostring(search_err))
end
THCUNN.C = ffi.load(lib_path)

-- load THC
-- NOTE(review): `THC` is not read again in this file; the statement is kept
-- unchanged so THC is still loaded explicitly on Windows.
local THC = ffi.os == 'Windows' and ffi.load('THC') or ffi.C

local THCState_ptr = ffi.typeof('THCState*')

--- Return the state object provided by cutorch, cast to a `THCState*` cdata.
function THCUNN.getState()
   return THCState_ptr(cutorch.getState())
end

local THCUNN_generic_h = require 'cunn.THCUNN_generic_h'
-- strip all lines starting with #
-- to remove preprocessor directives originally present
-- in THNN.h
THCUNN_generic_h = THCUNN_generic_h:gsub("\n#[^\n]*", "")
THCUNN_generic_h = THCUNN_generic_h:gsub("^#[^\n]*\n", "")

-- Rewrite `TH_API void THNN_(Name)` into `void THNN_TYPEName`; the `TYPE`
-- placeholder is substituted per tensor type further below.
local preprocessed_generic = string.gsub(
   THCUNN_generic_h,
   'TH_API void THNN_%(([%a%d_]+)%)',
   'void THNN_TYPE%1'
)

-- NOTE(review): `replacements` is not referenced anywhere else in this file.
local replacements = {
   {
      ['THTensor'] = 'THCudaTensor',
      ['THCIndexTensor'] = 'THCudaLongTensor',
      ['THIndex_t'] = 'long',
      ['THInteger_t'] = 'float'
   }
}

-- C tensor type name -> Lua tensor type name.
-- NOTE(review): this table is only written to in this file, never read.
local cct2lt = {
   ['THCudaFloatTensor'] = 'torch.CudaTensor',
   ['THCudaDoubleTensor'] = 'torch.CudaDoubleTensor',
}

-- One substitution table per tensor type; each is applied to the generic
-- header text before it is handed to ffi.cdef.
local replacements_generic = {
   {
      ['THCTensor'] = 'THCudaTensor',
      ['THCIndexTensor'] = 'THCudaLongTensor',
      ['TYPE'] = 'Cuda',
      ['accreal'] = 'float',
   },
   {
      ['THCTensor'] = 'THCudaDoubleTensor',
      ['THCIndexTensor'] = 'THCudaLongTensor',
      ['TYPE'] = 'CudaDouble',
      ['accreal'] = 'double',
   }
}

if cutorch.hasHalf then
   ffi.cdef("half THC_float2half(float a);")
   ffi.cdef("float THC_half2float(half a);")
   cct2lt['THCudaHalfTensor'] = 'torch.CudaHalfTensor'
   local half_replacement = {
      ['THCTensor'] = 'THCudaHalfTensor',
      ['THCIndexTensor'] = 'THCudaLongTensor',
      ['TYPE'] = 'CudaHalf',
      ['accreal'] = 'float',
   }
   table.insert(replacements_generic, half_replacement)
end

-- Declare the C API once per tensor type. pairs() order is unspecified, but
-- no key occurs inside another key or inside any replacement value, so the
-- substitutions do not interfere with each other.
for i = 1, #replacements_generic do
   local substitutions = replacements_generic[i]
   local declarations = preprocessed_generic
   for from, to in pairs(substitutions) do
      declarations = string.gsub(declarations, from, to)
   end
   ffi.cdef(declarations)
end

--- Collect the function names declared as `TH_API void THNN_(Name)`.
-- @param s header text
-- @return array of the captured names, in order of appearance
local function extract_function_names_generic(s)
   local names = {}
   for name in string.gmatch(s, 'TH_API void THNN_%(([%a%d_]+)%)') do
      names[#names + 1] = name
   end
   return names
end

--- Find the start index of every non-overlapping match of a Lua pattern.
-- @param s string to search
-- @param p Lua pattern (must not match the empty string, otherwise the scan
--          cannot advance)
-- @return array of 1-based start positions, ascending
local function find_positions(s, p)
   local positions = {}
   local init = 1
   while true do
      local start, stop = string.find(s, p, init)
      if start == nil then
         break
      end
      positions[#positions + 1] = start
      init = stop + 1
   end
   return positions
end

--- Map each declared function name to the parameter indices mentioning `real`.
-- Every `TH_API ...;` declaration is scanned for the substring `real`; an
-- occurrence is attributed to the parameter ended by the first comma that
-- follows it, or to the last parameter when no comma follows.
-- NOTE(review): the substring search also matches `real` inside `accreal` —
-- confirm whether that is intended.
-- @param s header text
-- @return table: function name -> array of 1-based parameter indices
local function extract_function_names_and_real_args(s)
   local result = {}
   for declaration in string.gmatch(s, 'TH_API ([^;]+)') do
      local func_name = string.match(declaration, 'void THNN_%(([%a%d_]+)%)')
      -- Skip declarations that do not have the expected shape; indexing the
      -- result table with nil would raise "table index is nil".
      if func_name ~= nil then
         local comma_positions = find_positions(declaration, ',')
         local arg_indices = {}
         for _, real_pos in ipairs(find_positions(declaration, 'real')) do
            -- it is the last param unless a comma follows the occurrence
            local arg_index = #comma_positions + 1
            for comma_index, comma_pos in ipairs(comma_positions) do
               if comma_pos > real_pos then
                  arg_index = comma_index
                  break
               end
            end
            arg_indices[#arg_indices + 1] = arg_index
         end
         result[func_name] = arg_indices
      end
   end
   return result
end

-- NOTE(review): `real_args` is computed but not read anywhere in this file.
local real_args = extract_function_names_and_real_args(THCUNN_generic_h)

-- build function table
local function_names_generic = extract_function_names_generic(THCUNN_generic_h)

-- Bind the kernels for one tensor type and publish them both in THNN.kernels
-- and as the `THNN` field of the tensor's metatable.
local function register_kernels(tensor_type, type_suffix)
   THNN.kernels[tensor_type] =
      THNN.bind(THCUNN.C, function_names_generic, type_suffix, THCUNN.getState)
   torch.getmetatable(tensor_type).THNN = THNN.kernels[tensor_type]
end

register_kernels('torch.CudaTensor', 'Cuda')
register_kernels('torch.CudaDoubleTensor', 'CudaDouble')
if cutorch.hasHalf then
   register_kernels('torch.CudaHalfTensor', 'CudaHalf')
end

-- Build a method that converts a module by calling self:type(type_name).
local function Module__converter(type_name)
   return function(self)
      return self:type(type_name)
   end
end

rawset(torch.getmetatable('nn.Module'), 'cudaDouble', Module__converter('torch.CudaDoubleTensor'))
if cutorch.hasHalf then
   rawset(torch.getmetatable('nn.Module'), 'cudaHalf', Module__converter('torch.CudaHalfTensor'))
end

return THCUNN