Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Andreas Köpf <andreas.koepf@xamla.com> 2015-12-16 03:20:35 +0300
committer: soumith <soumith@gmail.com> 2015-12-30 00:41:20 +0300
commit: ad6b29d881ea77310b23211c3d9cd808f1e44ecc (patch)
tree: 57ee2a878ae14df438f9112746b215e3423602a0 /THCUNN.lua
parent: 0ffd5491c0ece953b0182c89396ec960a3875623 (diff)
Add THCUNN/ffi conversion of Abs and AbsCriterion
Diffstat (limited to 'THCUNN.lua')
-rw-r--r-- THCUNN.lua 70
1 files changed, 70 insertions, 0 deletions
diff --git a/THCUNN.lua b/THCUNN.lua
new file mode 100644
index 0000000..7a793ff
--- /dev/null
+++ b/THCUNN.lua
@@ -0,0 +1,70 @@
+local ffi = require 'ffi'
+local THNN = require 'nn.THNN'
+
+-- Module table; it is what this file returns to `require`.
+local THCUNN = {}
+
+-- Pointer ctype used to cast the handle obtained from cutorch.
+-- NOTE(review): `THCState` must already be a known FFI type at this point,
+-- presumably declared by cutorch -- confirm.
+local THCState_ptr = ffi.typeof('THCState*')
+
+--- Return the current cutorch state as a `THCState*` cdata.
+-- Reads the global `cutorch`, so cutorch has to be loaded before this is
+-- called.
+-- @return `cutorch.getState()` converted with the `THCState*` ctype
+function THCUNN.getState()
+ return THCState_ptr(cutorch.getState())
+end
+
+-- C declarations of the CUDA kernels exposed through the FFI. The text is
+-- also scanned further down for `TH_API void THNN_Cuda<Name>` to build the
+-- list of function names, so the `TH_API` prefix has to stay in this copy.
+local THCUNN_h = [[
+typedef void THCState;
+
+TH_API void THNN_CudaAbs_updateOutput(
+ THCState *state,
+ THCudaTensor *input,
+ THCudaTensor *output);
+TH_API void THNN_CudaAbs_updateGradInput(
+ THCState *state,
+ THCudaTensor *input,
+ THCudaTensor *gradOutput,
+ THCudaTensor *gradInput);
+
+TH_API void THNN_CudaAbsCriterion_updateOutput(
+ THCState *state,
+ THCudaTensor *input,
+ THCudaTensor *target,
+ float *output,
+ bool sizeAverage);
+TH_API void THNN_CudaAbsCriterion_updateGradInput(
+ THCState *state,
+ THCudaTensor *input,
+ THCudaTensor *target,
+ THCudaTensor *gradInput,
+ bool sizeAverage);
+]]
+
+
+-- `TH_API` is not defined anywhere in the declarations above, so it is
+-- removed before the text is handed to ffi.cdef. Only the first result of
+-- gsub (the rewritten string) is kept; the substitution count is dropped.
+local preprocessed = THCUNN_h:gsub('TH_API ', '')
+ffi.cdef(preprocessed)
+
+-- Load the native backend library. On OS X it is requested by its explicit
+-- file name; everywhere else ffi.load gets the bare library name.
+local library_name = 'THCUNN'
+if ffi.os == "OSX" then
+ library_name = 'libTHCUNN.dylib'
+end
+
+-- pcall keeps a loader failure from propagating directly, so the reason
+-- can be printed before the module aborts with its own message.
+local ok,result = pcall(ffi.load, library_name)
+if ok then
+ -- Handle to the loaded library, through which the kernels are reached.
+ THCUNN.C = result
+else
+ print(result)
+ error("Ops, could not load 'THCUNN' GPU backend library.")
+end
+
+--- Collect the kernel names declared in a header string.
+-- Each occurrence of `TH_API void THNN_Cuda<Name>` contributes `<Name>`
+-- (a run of letters, digits and underscores), in order of appearance.
+-- @param s header text to scan
+-- @return array of the captured names (empty when nothing matches)
+local function extract_function_names(s)
+ local names = {}
+ for name in string.gmatch(s, 'TH_API void THNN_Cuda([%a%d_]+)') do
+  table.insert(names, name)
+ end
+ return names
+end
+
+-- Build the function table: the names are taken from the header text, then
+-- bound against the loaded library through THNN.bind.
+-- NOTE(review): 'Cuda' presumably is the symbol prefix and THCUNN.getState
+-- the state provider for each call -- confirm against THNN.bind.
+local function_names = extract_function_names(THCUNN_h)
+local cuda_kernels = THNN.bind(THCUNN.C, function_names, 'Cuda', THCUNN.getState)
+
+-- Publish the same table under the tensor type name and on the
+-- torch.CudaTensor metatable.
+THNN.kernels['torch.CudaTensor'] = cuda_kernels
+torch.getmetatable('torch.CudaTensor').THNN = cuda_kernels
+
+return THCUNN