Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/threads-ffi.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2015-04-08 05:08:49 +0300
committerRonan Collobert <ronan@collobert.com>2015-04-08 05:08:49 +0300
commitb5cefe962e71ff4cc951f0e72ed0f91fca86c040 (patch)
treeedda7dc663f95078c2c43e3f0562c1b80df1b11d /sharedserialize.lua
parent5661155854c7f5135bc8f78344b39574f53f7635 (diff)
sharedserialize: do not assume things are available in ffi.C
Diffstat (limited to 'sharedserialize.lua')
-rw-r--r--sharedserialize.lua13
1 files changed, 9 insertions, 4 deletions
diff --git a/sharedserialize.lua b/sharedserialize.lua
index 9cbcc1a..f940f1a 100644
--- a/sharedserialize.lua
+++ b/sharedserialize.lua
@@ -1,10 +1,12 @@
local ffi = require 'ffi'
-local C = ffi.C
require 'torch'
local status, tds = pcall(require, 'tds')
+-- definitions for TH
+local TH = ffi.load(package.searchpath('libtorch', package.cpath))
+
ffi.cdef[[
void free(void *ptr);
void *malloc(size_t size);
@@ -28,7 +30,10 @@ void THFloatStorage_retain(THFloatStorage *self);
void THDoubleStorage_retain(THDoubleStorage *self);
]]
+-- definitions for THC
+local THC
if torch.CudaTensor then
+ THC = ffi.load(package.searchpath('libcutorch', package.cpath))
ffi.cdef[[
void THCudaTensor_retain(THCudaTensor *self);
void THCudaStorage_retain(THCudaStorage *self);
@@ -109,7 +114,7 @@ for _, typename in ipairs{
if typenames[typename] then
local mt = typenames[typename]
local thname = typename:gsub('torch%.', 'TH')
- local retain = C[thname .. '_retain']
+ local retain = thname:match('Cuda') and THC[thname .. '_retain'] or TH[thname .. '_retain']
function mt.__factory(f)
local self = f:readLong()
@@ -160,7 +165,7 @@ function serialize.save(func)
local code_p = storage:data()
local sz = storage:size()
-- refcounted, but do not free mem
- C.THCharStorage_clearFlag(storage:cdata(), 4)
+ TH.THCharStorage_clearFlag(storage:cdata(), 4)
f:close()
return code_p, sz
end
@@ -188,7 +193,7 @@ function serialize.load(code_p, sz)
local status, func = pcall(
function()
- local storage_p = C.THCharStorage_newWithData(code_p, sz)
+ local storage_p = TH.THCharStorage_newWithData(code_p, sz)
local storage = torch.pushudata(storage_p, 'torch.CharStorage')
local f = torch.MemoryFile(storage)
f:binary()