diff options
author | Ronan Collobert <ronan@collobert.com> | 2015-04-08 05:08:49 +0300 |
---|---|---|
committer | Ronan Collobert <ronan@collobert.com> | 2015-04-08 05:08:49 +0300 |
commit | b5cefe962e71ff4cc951f0e72ed0f91fca86c040 (patch) | |
tree | edda7dc663f95078c2c43e3f0562c1b80df1b11d /sharedserialize.lua | |
parent | 5661155854c7f5135bc8f78344b39574f53f7635 (diff) |
sharedserialize: do not assume things are available in ffi.C
Diffstat (limited to 'sharedserialize.lua')
-rw-r--r-- | sharedserialize.lua | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/sharedserialize.lua b/sharedserialize.lua index 9cbcc1a..f940f1a 100644 --- a/sharedserialize.lua +++ b/sharedserialize.lua @@ -1,10 +1,12 @@ local ffi = require 'ffi' -local C = ffi.C require 'torch' local status, tds = pcall(require, 'tds') +-- definitions for TH +local TH = ffi.load(package.searchpath('libtorch', package.cpath)) + ffi.cdef[[ void free(void *ptr); void *malloc(size_t size); @@ -28,7 +30,10 @@ void THFloatStorage_retain(THFloatStorage *self); void THDoubleStorage_retain(THDoubleStorage *self); ]] +-- definitions for THC +local THC if torch.CudaTensor then + THC = ffi.load(package.searchpath('libcutorch', package.cpath)) ffi.cdef[[ void THCudaTensor_retain(THCudaTensor *self); void THCudaStorage_retain(THCudaStorage *self); @@ -109,7 +114,7 @@ for _, typename in ipairs{ if typenames[typename] then local mt = typenames[typename] local thname = typename:gsub('torch%.', 'TH') - local retain = C[thname .. '_retain'] + local retain = thname:match('Cuda') and THC[thname .. '_retain'] or TH[thname .. '_retain'] function mt.__factory(f) local self = f:readLong() @@ -160,7 +165,7 @@ function serialize.save(func) local code_p = storage:data() local sz = storage:size() -- refcounted, but do not free mem - C.THCharStorage_clearFlag(storage:cdata(), 4) + TH.THCharStorage_clearFlag(storage:cdata(), 4) f:close() return code_p, sz end @@ -188,7 +193,7 @@ function serialize.load(code_p, sz) local status, func = pcall( function() - local storage_p = C.THCharStorage_newWithData(code_p, sz) + local storage_p = TH.THCharStorage_newWithData(code_p, sz) local storage = torch.pushudata(storage_p, 'torch.CharStorage') local f = torch.MemoryFile(storage) f:binary() |