github.com/torch/threads-ffi.git
author     Ronan Collobert <ronan@collobert.com>  2015-04-05 04:57:38 +0300
committer  Ronan Collobert <ronan@collobert.com>  2015-04-05 04:59:47 +0300
commit     3bd8aa86f457e9d93c4947a5848b07fd16b58842 (patch)
tree       a2c4ce29640b82ef696bb0a5244e24ee5ab2b8aa
parent     a283ccd0447d44f2b8c4899238f0ca90d126a874 (diff)
sharedserialize: now also handles torch storages + tds hash
added test case
-rw-r--r--  sharedserialize.lua  192
-rw-r--r--  test/simple.lua      111
2 files changed, 255 insertions, 48 deletions
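
The diff below swaps pointer-based write/read/factory methods onto the torch and tds metatables for the duration of a (de)serialization, so tensors, storages and tds hashes cross thread boundaries by reference (with a refcount retain) rather than by copy. A minimal usage sketch (not part of the commit) follows; the module path 'threads.sharedserialize' is taken from the commented-out call in test/simple.lua, and the sketch assumes the file returns its serialize table.

-- Minimal sketch (not part of the commit): round-trip a table holding a
-- tensor, a storage and a tds hash through the shared serializer.
require 'torch'
local tds = require 'tds'
local serialize = require 'threads.sharedserialize'  -- module path assumed

local obj = {
   t = torch.ones(10),                  -- tensor (already supported)
   s = torch.FloatStorage(10):fill(1),  -- storage (newly supported)
   h = tds.hash(),                      -- tds hash (newly supported)
}
obj.h[1] = 'hello'

-- save() writes raw pointers and retains each object; load() rebuilds Lua
-- wrappers around the same underlying memory (same process here; in the
-- library the buffer is handed to another thread).
local code_p, sz = serialize.save(obj)
local copy = serialize.load(code_p, sz)

copy.t:add(1)                 -- mutates the shared tensor data
assert(obj.t:sum() == 20)     -- the change is visible through the original
assert(copy.h[1] == 'hello')
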
diff --git a/sharedserialize.lua b/sharedserialize.lua
index b151cec..9cbcc1a 100644
--- a/sharedserialize.lua
+++ b/sharedserialize.lua
@@ -3,77 +3,154 @@ local C = ffi.C
require 'torch'
+local status, tds = pcall(require, 'tds')
+
ffi.cdef[[
void free(void *ptr);
void *malloc(size_t size);
THCharStorage* THCharStorage_newWithData(const char *data, long size);
void THCharStorage_clearFlag(THCharStorage *storage, const char flag);
+
+void THByteTensor_retain(THByteTensor *self);
+void THCharTensor_retain(THCharTensor *self);
+void THShortTensor_retain(THShortTensor *self);
+void THIntTensor_retain(THIntTensor *self);
+void THLongTensor_retain(THLongTensor *self);
+void THFloatTensor_retain(THFloatTensor *self);
+void THDoubleTensor_retain(THDoubleTensor *self);
+
+void THByteStorage_retain(THByteStorage *self);
+void THCharStorage_retain(THCharStorage *self);
+void THShortStorage_retain(THShortStorage *self);
+void THIntStorage_retain(THIntStorage *self);
+void THLongStorage_retain(THLongStorage *self);
+void THFloatStorage_retain(THFloatStorage *self);
+void THDoubleStorage_retain(THDoubleStorage *self);
+]]
+
+if torch.CudaTensor then
+ ffi.cdef[[
+void THCudaTensor_retain(THCudaTensor *self);
+void THCudaStorage_retain(THCudaStorage *self);
]]
+end
local serialize = {}
-local tensor = {}
-local tensortypes = {}
-
-for _, name in ipairs{
- 'ByteTensor',
- 'CharTensor',
- 'ShortTensor',
- 'IntTensor',
- 'LongTensor',
- 'CudaTensor',
- 'FloatTensor',
- 'DoubleTensor',
- 'CudaTensor'} do
-
- if torch[name] then
- table.insert(tensortypes, name)
- tensor[name] = {
- read = torch[name].read,
- write = torch[name].write
- }
+local typenames = {}
+
+-- keep only the typenames that exist in this torch install
+for _, typename in ipairs{
+ 'torch.ByteTensor',
+ 'torch.CharTensor',
+ 'torch.ShortTensor',
+ 'torch.IntTensor',
+ 'torch.LongTensor',
+ 'torch.CudaTensor',
+ 'torch.FloatTensor',
+ 'torch.DoubleTensor',
+ 'torch.CudaTensor',
+ 'torch.ByteStorage',
+ 'torch.CharStorage',
+ 'torch.ShortStorage',
+ 'torch.IntStorage',
+ 'torch.LongStorage',
+ 'torch.CudaStorage',
+ 'torch.FloatStorage',
+ 'torch.DoubleStorage',
+ 'torch.CudaStorage',
+ 'tds_hash'} do
+
+ if torch.getmetatable(typename) then
+ typenames[typename] = {}
end
end
-local function tensor_write(self, f)
- f:writeLong(torch.pointer(self))
- local p = self:cdata()
- p.refcount = p.refcount + 1
-end
+if typenames.tds_hash then
+ local mt = typenames.tds_hash
-local function tensor_read(self, f)
- local p = f:readLong()
- local z = torch.pushudata(p, torch.typename(self))
- self:set(z)
-end
+ function mt.__factory(f)
+ local self = f:readLong()
+ self = ffi.cast('tds_hash&', self)
+ ffi.gc(self, tds.C.tds_hash_free)
+ return self
+ end
-local function sharewrite()
- for _, name in ipairs(tensortypes) do
- torch[name].write = tensor_write
+ function mt.__write(self, f)
+ f:writeLong(torch.pointer(self))
+ tds.C.tds_hash_retain(self)
+ end
+
+ function mt.__read(self, f)
end
end
-local function unsharewrite()
- for _, name in ipairs(tensortypes) do
- torch[name].write = tensor[name].write
+for _, typename in ipairs{
+ 'torch.ByteTensor',
+ 'torch.CharTensor',
+ 'torch.ShortTensor',
+ 'torch.IntTensor',
+ 'torch.LongTensor',
+ 'torch.CudaTensor',
+ 'torch.FloatTensor',
+ 'torch.DoubleTensor',
+ 'torch.CudaTensor',
+ 'torch.ByteStorage',
+ 'torch.CharStorage',
+ 'torch.ShortStorage',
+ 'torch.IntStorage',
+ 'torch.LongStorage',
+ 'torch.CudaStorage',
+ 'torch.FloatStorage',
+ 'torch.DoubleStorage',
+ 'torch.CudaStorage'} do
+
+ if typenames[typename] then
+ local mt = typenames[typename]
+ local thname = typename:gsub('torch%.', 'TH')
+ local retain = C[thname .. '_retain']
+
+ function mt.__factory(f)
+ local self = f:readLong()
+ self = torch.pushudata(self, typename)
+ return self
+ end
+
+ function mt.write(self, f)
+ f:writeLong(torch.pointer(self))
+ retain(self:cdata())
+ end
+
+ function mt.read(self, f)
+ end
end
end
-local function shareread()
- for _, name in ipairs(tensortypes) do
- torch[name].read = tensor_read
+local function swapwrite()
+ for typename, mt in pairs(typenames) do
+ local mts = torch.getmetatable(typename)
+ mts.__write, mt.__write = mt.__write, mts.__write
+ mts.write, mt.write = mt.write, mts.write
end
end
-local function unshareread()
- for _, name in ipairs(tensortypes) do
- torch[name].read = tensor[name].read
+local function swapread()
+ for typename, mt in pairs(typenames) do
+ local mts = torch.getmetatable(typename)
+ mts.__factory, mt.__factory = mt.__factory, mts.__factory
+ mts.__read, mt.__read = mt.__read, mts.__read
+ mts.read, mt.read = mt.read, mts.read
end
end
function serialize.save(func)
- sharewrite()
+ local status, msg = pcall(swapwrite)
+ if not status then
+ print(string.format('FATAL THREAD PANIC: (write) %s', msg))
+ os.exit(-1)
+ end
+
local status, code_p, sz = pcall(
function()
local f = torch.MemoryFile()
@@ -88,15 +165,27 @@ function serialize.save(func)
return code_p, sz
end
)
- unsharewrite()
if not status then
- error(code_p)
+ print(string.format('FATAL THREAD PANIC: (write) %s', code_p))
+ os.exit(-1)
end
+
+ local status, msg = pcall(swapwrite)
+ if not status then
+ print(string.format('FATAL THREAD PANIC: (write) %s', msg))
+ os.exit(-1)
+ end
+
return code_p, sz
end
function serialize.load(code_p, sz)
- shareread()
+ local status, msg = pcall(swapread)
+ if not status then
+ print(string.format('FATAL THREAD PANIC: (read) %s', msg))
+ os.exit(-1)
+ end
+
local status, func = pcall(
function()
local storage_p = C.THCharStorage_newWithData(code_p, sz)
@@ -108,10 +197,17 @@ function serialize.load(code_p, sz)
return func
end
)
- unshareread()
if not status then
- error(func)
+ print(string.format('FATAL THREAD PANIC: (read) %s', func))
+ os.exit(-1)
end
+
+ local status, msg = pcall(swapread)
+ if not status then
+ print(string.format('FATAL THREAD PANIC: (read) %s', msg))
+ os.exit(-1)
+ end
+
return func
end
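
A note on the swapwrite()/swapread() helpers introduced above: each call exchanges the pointer-based methods stored in the typenames table with the live methods on the corresponding torch metatables, so calling the same helper a second time restores the defaults. That is why serialize.save() and serialize.load() invoke them both before and after the pcall'd work. The standalone sketch below illustrates only the swap idiom, with a stand-in write function rather than the commit's retain logic:

require 'torch'

-- Stand-in replacement method; after one swap this slot holds the original.
local alt = {
   write = function(self, f) f:writeLong(torch.pointer(self)) end,
}

local function swap(typename)
   local mt = torch.getmetatable(typename)
   mt.write, alt.write = alt.write, mt.write
end

swap('torch.DoubleTensor')  -- pointer-based write is now active
swap('torch.DoubleTensor')  -- original write restored: the swap is its own inverse
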
diff --git a/test/simple.lua b/test/simple.lua
new file mode 100644
index 0000000..e422cd7
--- /dev/null
+++ b/test/simple.lua
@@ -0,0 +1,111 @@
+require 'torch'
+
+local Threads = require 'threads'
+local sdl = require 'sdl2'
+local tds = require 'tds'
+
+local nthread = 4
+local njob = 10
+local msg = "hello from a satellite thread"
+
+sdl.init(0)
+
+--Threads.serialization('threads.sharedserialize')
+
+local x = {}
+local xh = tds.hash()
+local xs = {}
+local z = tds.hash()
+local D = 10
+local K = 100000 -- good luck in non-shared (30M)
+for i=1,njob do
+ x[i] = torch.ones(D)
+ xh[i] = torch.ones(D)
+ xs[i] = torch.FloatStorage(D):fill(1)
+ for j=1,K do
+ z[(i-1)*K+j] = "blah" .. i .. j
+ end
+end
+collectgarbage()
+collectgarbage()
+
+print('GO')
+
+local threads = Threads(
+ nthread,
+ function(threadIdx)
+ require 'tds'
+ print('starting a new thread/state number:', threadIdx)
+ gmsg = msg -- we copy here an upvalue of the main thread
+ end
+)
+
+local jobdone = 0
+for i=1,njob do
+ threads:addjob(
+ function()
+ assert(x[i]:sum() == D)
+ assert(xh[i]:sum() == D)
+ assert(torch.FloatTensor(xs[i]):sum() == D)
+ for j=1,K do
+ assert(z[(i-1)*K+j] == "blah" .. i .. j)
+ end
+ x[i]:add(1)
+ xh[i]:add(1)
+ torch.FloatTensor(xs[i]):add(1)
+ print(string.format('%s -- thread ID is %x', gmsg, __threadid))
+ collectgarbage()
+ collectgarbage()
+ return __threadid
+ end,
+
+ function(id)
+ print(string.format("task %d finished (ran on thread ID %x)", i, id))
+ jobdone = jobdone + 1
+ end
+ )
+end
+
+for i=1,njob do
+ threads:addjob(
+ function()
+ collectgarbage()
+ collectgarbage()
+ end
+ )
+end
+
+threads:synchronize()
+
+print(string.format('%d jobs done', jobdone))
+
+threads:terminate()
+
+-- did we do the job in shared mode?
+for i=1,njob do
+ assert(x[i]:sum() == 2*D)
+ assert(xh[i]:sum() == 2*D)
+ assert(torch.FloatTensor(xs[i]):sum() == 2*D)
+end
+
+-- serialize and zero x
+local str = torch.serialize(x)
+local strh = torch.serialize(xh)
+local strs = torch.serialize(xs)
+for i=1,njob do
+ x[i]:zero()
+ xh[i]:zero()
+ xs[i]:fill(0)
+end
+
+-- check that the deserialized copies do not share storage with x
+local y = torch.deserialize(str)
+local yh = torch.deserialize(strh)
+local ys = torch.deserialize(strs)
+for i=1,njob do
+ assert(y[i]:sum() == 2*D)
+ assert(yh[i]:sum() == 2*D)
+ assert(torch.FloatTensor(ys[i]):sum() == 2*D)
+end
+
+print('PASSED')