Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/cunn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2017-01-06 19:04:32 +0300
committerGitHub <noreply@github.com>2017-01-06 19:04:32 +0300
commit349df42dfa550389f04ac90ea621f21b2838b00c (patch)
treece1d9668073f0db525faf1c204a9a61c5cc65525
parent103fbffaf5ae32cbbe2bf761c858477f191b08b7 (diff)
parent0cb41a138ed1914294d2a1c4491e3af239c26237 (diff)
Merge pull request #409 from BTNC/windows
explicitly load THC for windows
-rw-r--r--THCUNN.lua5
-rw-r--r--test.lua15
2 files changed, 13 insertions, 7 deletions
diff --git a/THCUNN.lua b/THCUNN.lua
index 8bbca17..6776a23 100644
--- a/THCUNN.lua
+++ b/THCUNN.lua
@@ -6,6 +6,9 @@ local THCUNN = {}
-- load libTHCUNN
THCUNN.C = ffi.load(package.searchpath('libTHCUNN', package.cpath))
+-- load THC
+local THC = ffi.os == 'Windows' and ffi.load('THC') or ffi.C
+
local THCState_ptr = ffi.typeof('THCState*')
function THCUNN.getState()
@@ -140,7 +143,7 @@ local transform_reals_to_half = function(func_name, real_args, ...)
end
for k,v in ipairs(real_args[func_name]) do
-- first argument (THCState) is added implicitly by bind
- t[v-1] = ffi.C.THC_float2half(t[v-1])
+ t[v-1] = THC.THC_float2half(t[v-1])
end
return t
end
diff --git a/test.lua b/test.lua
index 06cdbb6..c3ed9bb 100644
--- a/test.lua
+++ b/test.lua
@@ -5,6 +5,9 @@ local precision_backward = 1e-2
local nloop = 1
local times = {}
+-- load THC
+local THC = ffi.os == 'Windows' and ffi.load('THC') or ffi.C
+
--e.g.: th -lcunn -e "nn.testcuda{'Sigmoid_forward'}"
local typenames = {
@@ -362,17 +365,17 @@ function cunntest.Square_transposed()
end
function cunntest.SoftShrink_forward()
- local r = ffi.C.THC_half2float(ffi.C.THC_float2half(math.random()))
+ local r = THC.THC_half2float(THC.THC_float2half(math.random()))
pointwise_forward(nn.SoftShrink(r), 'SoftShrink', precision_forward)
end
function cunntest.SoftShrink_backward()
- local r = ffi.C.THC_half2float(ffi.C.THC_float2half(math.random()))
+ local r = THC.THC_half2float(THC.THC_float2half(math.random()))
pointwise_backward(nn.SoftShrink(r), 'SoftShrink', precision_backward)
end
function cunntest.SoftShrink_transposed()
- local r = ffi.C.THC_half2float(ffi.C.THC_float2half(math.random()))
+ local r = THC.THC_half2float(THC.THC_float2half(math.random()))
pointwise_transposed(nn.SoftShrink(r), 'SoftShrink', precision_backward)
end
@@ -3399,7 +3402,7 @@ function cunntest.mse()
local cgin = cmod:backward(cinput,ctarget)
if (typename == 'torch.CudaHalfTensor') then
- fout = ffi.C.THC_half2float(ffi.C.THC_float2half(fout))
+ fout = THC.THC_half2float(THC.THC_float2half(fout))
end
mytester:assertlt(math.abs(fout-cout), precision_forward_type(0.02, typename),
string.format('error on output with %s', typename))
@@ -3433,7 +3436,7 @@ function cunntest.SmoothL1()
local cgin = cmod:backward(cinput,ctarget)
if (typename == 'torch.CudaHalfTensor') then
- fout = ffi.C.THC_half2float(ffi.C.THC_float2half(fout))
+ fout = THC.THC_half2float(THC.THC_float2half(fout))
end
mytester:assertlt(math.abs(fout-cout), 0.01, string.format('error on output with %s', typename))
local gerr = cgin:double() - fgin:double()
@@ -3998,7 +4001,7 @@ function cunntest.l1cost()
local cgin = cmod:backward(cinput)
if (typename == 'torch.CudaHalfTensor') then
- fout = ffi.C.THC_half2float(ffi.C.THC_float2half(fout))
+ fout = THC.THC_half2float(THC.THC_float2half(fout))
end
mytester:assertlt(math.abs(fout-cout), precision_forward_type(precision_forward, typename),
string.format('error on output with %s', typename))