diff options
author | Ronan Collobert <ronan@collobert.com> | 2015-04-05 04:57:38 +0300 |
---|---|---|
committer | Ronan Collobert <ronan@collobert.com> | 2015-04-05 04:59:47 +0300 |
commit | 3bd8aa86f457e9d93c4947a5848b07fd16b58842 (patch) | |
tree | a2c4ce29640b82ef696bb0a5244e24ee5ab2b8aa | |
parent | a283ccd0447d44f2b8c4899238f0ca90d126a874 (diff) |
sharedserialize: now handles also torch storages + tds hash
added test case
-rw-r--r-- | sharedserialize.lua | 192 | ||||
-rw-r--r-- | test/simple.lua | 111 |
2 files changed, 255 insertions, 48 deletions
diff --git a/sharedserialize.lua b/sharedserialize.lua index b151cec..9cbcc1a 100644 --- a/sharedserialize.lua +++ b/sharedserialize.lua @@ -3,77 +3,154 @@ local C = ffi.C require 'torch' +local status, tds = pcall(require, 'tds') + ffi.cdef[[ void free(void *ptr); void *malloc(size_t size); THCharStorage* THCharStorage_newWithData(const char *data, long size); void THCharStorage_clearFlag(THCharStorage *storage, const char flag); + +void THByteTensor_retain(THByteTensor *self); +void THCharTensor_retain(THCharTensor *self); +void THShortTensor_retain(THShortTensor *self); +void THIntTensor_retain(THIntTensor *self); +void THLongTensor_retain(THLongTensor *self); +void THFloatTensor_retain(THFloatTensor *self); +void THDoubleTensor_retain(THDoubleTensor *self); + +void THByteStorage_retain(THByteStorage *self); +void THCharStorage_retain(THCharStorage *self); +void THShortStorage_retain(THShortStorage *self); +void THIntStorage_retain(THIntStorage *self); +void THLongStorage_retain(THLongStorage *self); +void THFloatStorage_retain(THFloatStorage *self); +void THDoubleStorage_retain(THDoubleStorage *self); +]] + +if torch.CudaTensor then + ffi.cdef[[ +void THCudaTensor_retain(THCudaTensor *self); +void THCudaStorage_retain(THCudaStorage *self); ]] +end local serialize = {} -local tensor = {} -local tensortypes = {} - -for _, name in ipairs{ - 'ByteTensor', - 'CharTensor', - 'ShortTensor', - 'IntTensor', - 'LongTensor', - 'CudaTensor', - 'FloatTensor', - 'DoubleTensor', - 'CudaTensor'} do - - if torch[name] then - table.insert(tensortypes, name) - tensor[name] = { - read = torch[name].read, - write = torch[name].write - } +local typenames = {} + +-- check if typenames exists +for _, typename in ipairs{ + 'torch.ByteTensor', + 'torch.CharTensor', + 'torch.ShortTensor', + 'torch.IntTensor', + 'torch.LongTensor', + 'torch.CudaTensor', + 'torch.FloatTensor', + 'torch.DoubleTensor', + 'torch.CudaTensor', + 'torch.ByteStorage', + 'torch.CharStorage', + 'torch.ShortStorage', + 'torch.IntStorage', + 'torch.LongStorage', + 'torch.CudaStorage', + 'torch.FloatStorage', + 'torch.DoubleStorage', + 'torch.CudaStorage', + 'tds_hash'} do + + if torch.getmetatable(typename) then + typenames[typename] = {} end end -local function tensor_write(self, f) - f:writeLong(torch.pointer(self)) - local p = self:cdata() - p.refcount = p.refcount + 1 -end +if typenames.tds_hash then + local mt = typenames.tds_hash -local function tensor_read(self, f) - local p = f:readLong() - local z = torch.pushudata(p, torch.typename(self)) - self:set(z) -end + function mt.__factory(f) + local self = f:readLong() + self = ffi.cast('tds_hash&', self) + ffi.gc(self, tds.C.tds_hash_free) + return self + end -local function sharewrite() - for _, name in ipairs(tensortypes) do - torch[name].write = tensor_write + function mt.__write(self, f) + f:writeLong(torch.pointer(self)) + tds.C.tds_hash_retain(self) + end + + function mt.__read(self, f) end end -local function unsharewrite() - for _, name in ipairs(tensortypes) do - torch[name].write = tensor[name].write +for _, typename in ipairs{ + 'torch.ByteTensor', + 'torch.CharTensor', + 'torch.ShortTensor', + 'torch.IntTensor', + 'torch.LongTensor', + 'torch.CudaTensor', + 'torch.FloatTensor', + 'torch.DoubleTensor', + 'torch.CudaTensor', + 'torch.ByteStorage', + 'torch.CharStorage', + 'torch.ShortStorage', + 'torch.IntStorage', + 'torch.LongStorage', + 'torch.CudaStorage', + 'torch.FloatStorage', + 'torch.DoubleStorage', + 'torch.CudaStorage'} do + + if typenames[typename] then + local mt = typenames[typename] + local thname = typename:gsub('torch%.', 'TH') + local retain = C[thname .. '_retain'] + + function mt.__factory(f) + local self = f:readLong() + self = torch.pushudata(self, typename) + return self + end + + function mt.write(self, f) + f:writeLong(torch.pointer(self)) + retain(self:cdata()) + end + + function mt.read(self, f) + end end end -local function shareread() - for _, name in ipairs(tensortypes) do - torch[name].read = tensor_read +local function swapwrite() + for typename, mt in pairs(typenames) do + local mts = torch.getmetatable(typename) + mts.__write, mt.__write = mt.__write, mts.__write + mts.write, mt.write = mt.write, mts.write end end -local function unshareread() - for _, name in ipairs(tensortypes) do - torch[name].read = tensor[name].read +local function swapread() + for typename, mt in pairs(typenames) do + local mts = torch.getmetatable(typename) + mts.__factory, mt.__factory = mt.__factory, mts.__factory + mts.__read, mt.__read = mt.__read, mts.__read + mts.read, mt.read = mt.read, mts.read end end function serialize.save(func) - sharewrite() + local status, msg = pcall(swapwrite) + if not status then + print(string.format('FATAL THREAD PANIC: (write) %s', msg)) + os.exit(-1) + end + local status, code_p, sz = pcall( function() local f = torch.MemoryFile() @@ -88,15 +165,27 @@ function serialize.save(func) return code_p, sz end ) - unsharewrite() if not status then - error(code_p) + print(string.format('FATAL THREAD PANIC: (write) %s', code_p)) + os.exit(-1) end + + local status, msg = pcall(swapwrite) + if not status then + print(string.format('FATAL THREAD PANIC: (write) %s', msg)) + os.exit(-1) + end + return code_p, sz end function serialize.load(code_p, sz) - shareread() + local status, msg = pcall(swapread) + if not status then + print(string.format('FATAL THREAD PANIC: (read) %s', msg)) + os.exit(-1) + end + local status, func = pcall( function() local storage_p = C.THCharStorage_newWithData(code_p, sz) @@ -108,10 +197,17 @@ function serialize.load(code_p, sz) return func end ) - unshareread() if not status then - error(func) + print(string.format('FATAL THREAD PANIC: (read) %s', func)) + os.exit(-1) end + + local status, msg = pcall(swapread) + if not status then + print(string.format('FATAL THREAD PANIC: (read) %s', msg)) + os.exit(-1) + end + return func end diff --git a/test/simple.lua b/test/simple.lua new file mode 100644 index 0000000..e422cd7 --- /dev/null +++ b/test/simple.lua @@ -0,0 +1,111 @@ +require 'torch' + +local Threads = require 'threads' +local sdl = require 'sdl2' +local tds = require 'tds' + +local nthread = 4 +local njob = 10 +local msg = "hello from a satellite thread" + +sdl.init(0) + +--Threads.serialization('threads.sharedserialize') + +local x = {} +local xh = tds.hash() +local xs = {} +local z = tds.hash() +local D = 10 +local K = 100000 -- good luck in non-shared (30M) +for i=1,njob do + x[i] = torch.ones(D) + xh[i] = torch.ones(D) + xs[i] = torch.FloatStorage(D):fill(1) + for j=1,K do + z[(i-1)*K+j] = "blah" .. i .. j + end +end +collectgarbage() +collectgarbage() + +print('GO') + +local threads = Threads( + nthread, + function(threadIdx) + require 'tds' + print('starting a new thread/state number:', threadIdx) + gmsg = msg -- we copy here an upvalue of the main thread + end +) + +local jobdone = 0 +for i=1,njob do + threads:addjob( + function() + assert(x[i]:sum() == D) + assert(xh[i]:sum() == D) + assert(torch.FloatTensor(xs[i]):sum() == D) + for j=1,K do + assert(z[(i-1)*K+j] == "blah" .. i .. j) + end + x[i]:add(1) + xh[i]:add(1) + torch.FloatTensor(xs[i]):add(1) + print(string.format('%s -- thread ID is %x', gmsg, __threadid)) + collectgarbage() + collectgarbage() + return __threadid + end, + + function(id) + print(string.format("task %d finished (ran on thread ID %x)", i, id)) + jobdone = jobdone + 1 + end + ) +end + +for i=1,njob do + threads:addjob( + function() + collectgarbage() + collectgarbage() + end + ) +end + +threads:synchronize() + +print(string.format('%d jobs done', jobdone)) + +threads:terminate() + +-- did we do the job in shared mode? +for i=1,njob do + assert(x[i]:sum() == 2*D) + assert(xh[i]:sum() == 2*D) + assert(torch.FloatTensor(xs[i]):sum() == 2*D) +end + +-- serialize and zero x +local str = torch.serialize(x) +local strh = torch.serialize(xh) +local strs = torch.serialize(xs) +for i=1,njob do + x[i]:zero() + xh[i]:zero() + xs[i]:fill(0) +end + +-- dude, check that unserialized x does not point on x +local y = torch.deserialize(str) +local yh = torch.deserialize(strh) +local ys = torch.deserialize(strs) +for i=1,njob do + assert(y[i]:sum() == 2*D) + assert(yh[i]:sum() == 2*D) + assert(torch.FloatTensor(ys[i]):sum() == 2*D) +end + +print('PASSED') |