diff options
author | Ronan Collobert <ronan@collobert.com> | 2015-03-24 22:51:12 +0300 |
---|---|---|
committer | Ronan Collobert <ronan@collobert.com> | 2015-03-24 22:51:12 +0300 |
commit | 0c26c67ed3baded1e6855fac2c604be060780155 (patch) | |
tree | 10cf1c5e22eef334afa2b397e9ebe1ff02f7aaf4 | |
parent | 9cb1daba55c66ab6bc3b2210c0b2e8dafcf1a6a8 (diff) |
added specific([boolean]) and fix threads initialization accordingly
-rw-r--r-- | init.lua | 170 |
1 files changed, 121 insertions, 49 deletions
@@ -43,7 +43,7 @@ function Threads.serialization(name) end function Threads:__call(N, ...) - local self = {N=N, endcallbacks={n=0}, errors={}} + local self = {N=N, endcallbacks={n=0}, errors={}, __specific=true} local funcs = {...} local serialize = require(Threads.__serialize) @@ -51,61 +51,45 @@ function Threads:__call(N, ...) funcs = {function() end} end - local initres = {} setmetatable(self, {__index=Threads}) self.mainworker = Worker(N, Threads.__serialize) self.threadworker = Worker(N, Threads.__serialize) + self.threadspecificworkers = {} self.threads = {} for i=1,N do + self.threadspecificworkers[i] = Worker(N, Threads.__serialize) + local L = C.luaL_newstate() assert(L ~= nil, string.format('%d-th lua state creation failed', i)) C.luaL_openlibs(L) - for j=1,#funcs do - local code_p, sz = serialize.save(funcs[j]) - if j < #funcs then - checkL(L, C.luaL_loadstring(L, string.format([[ - local serialize = require '%s' - local ffi = require 'ffi' - local code = serialize.load(ffi.cast('const char*', %d), %d) - code(%d) - ]], Threads.__serialize, tonumber(ffi.cast('intptr_t', code_p)), sz, i))) - else - checkL(L, C.luaL_loadstring(L, string.format([[ - local serialize = require '%s' - local ffi = require 'ffi' - local code = serialize.load(ffi.cast('const char*', %d), %d) - __threadid = %d - __workerinitres_p, __workerinitres_sz = serialize.save{code(%d)} - __workerinitres_p = tonumber(ffi.cast('intptr_t', __workerinitres_p)) - ]], Threads.__serialize, tonumber(ffi.cast('intptr_t', code_p)), sz, i, i))) - end - checkL(L, C.lua_pcall(L, 0, 0, 0) == 0) - end - - C.lua_getfield(L, LUA_GLOBALSINDEX, '__workerinitres_p') - local workerinitres_p = C.lua_tointeger(L, -1) - C.lua_getfield(L, LUA_GLOBALSINDEX, '__workerinitres_sz') - local workerinitres_sz = C.lua_tointeger(L, -1) - C.lua_settop(L, -3) - table.insert(initres, serialize.load(ffi.cast('const char*', workerinitres_p), workerinitres_sz)) - - checkL(L, C.luaL_loadstring(L, [[ + checkL(L, + C.luaL_loadstring( + L, + string.format( + [[ local ffi = require 'ffi' local sdl = require 'sdl2' require 'threads.worker' + __threadid = %d local function workerloop(data) local workers = ffi.cast('struct THWorker**', data) local mainworker = workers[0] local threadworker = workers[1] + local threadspecificworker = workers[2] local threadid = __threadid while __worker_running do - local status, res, endcallbackid = threadworker:dojob() + local status, res, endcallbackid + if __worker_specific then + status, res, endcallbackid = threadspecificworker:dojob() + else + status, res, endcallbackid = threadworker:dojob() + end mainworker:addjob(function() return status, res, endcallbackid, threadid end) @@ -115,15 +99,18 @@ function Threads:__call(N, ...) end __worker_running = true + __worker_specific = true __workerloop_ptr = tonumber(ffi.cast('intptr_t', ffi.cast('int (*)(void *)', workerloop))) -]] -) == 0) +]], + i) + ) == 0) + checkL(L, C.lua_pcall(L, 0, 0, 0) == 0) C.lua_getfield(L, LUA_GLOBALSINDEX, '__workerloop_ptr') local workerloop_ptr = C.lua_tointeger(L, -1) C.lua_settop(L, -2); - local workers = ffi.new('struct THWorker*[2]', {self.mainworker, self.threadworker}) -- note: GCed + local workers = ffi.new('struct THWorker*[3]', {self.mainworker, self.threadworker, self.threadspecificworkers[i]}) -- note: GCed local thread = sdl.createThread(ffi.cast('SDL_ThreadFunction', workerloop_ptr), string.format("%s%.2d", Threads.name, i), workers) assert(thread ~= nil, string.format('%d-th thread creation failed', i)) table.insert(self.threads, {thread=thread, L=L}) @@ -135,9 +122,61 @@ function Threads:__call(N, ...) self:synchronize() end + local initres = {} + for j=1,#funcs do + for i=1,self.N do + if j ~= #funcs then + self:addjob( + i, -- specific + funcs[j], + function() + end, + i -- passed to callback + ) + else + self:addjob( + i, -- specific + funcs[j], + function(...) + table.insert(initres, {...}) + end, + i -- passed to callback + ) + end + end + end + self:synchronize() + self:specific(false) + return self, initres end +function Threads:specific(flag) + if flag ~= nil then + assert(type(flag) == 'boolean', 'boolean expected') + self:synchronize() -- finish jobs first + if self.__specific ~= flag then + if self.__specific then + for i=1,self.N do + self:addjob(i, + function() + __worker_specific = false + end) + end + else + for i=1,self.N do + self:addjob(function() + __worker_specific = true + end) + end + end + self.__specific = flag + end + else + return self.__specific + end +end + function Threads:dojob() local endcallbacks = self.endcallbacks local callstatus, args, endcallbackid, threadid = self.mainworker:dojob() @@ -153,17 +192,41 @@ function Threads:dojob() endcallbacks.n = endcallbacks.n - 1 end -function Threads:acceptsjob() - return self.threadworker.isfull ~= 1 +function Threads:acceptsjob(idx) + local threadworker + if self:specific() then + assert(type(idx) == 'number' and idx >= 1 and idx <= self.N, 'thread index expected') + threadworker = self.threadspecificworkers[idx] + else + threadworker = self.threadworker + end + return threadworker.isfull ~= 1 end -function Threads:__addjob__(sync, callback, endcallback, ...) -- endcallback is passed with returned values of callback +function Threads:__addjob__(sync, ...) -- endcallback is passed with returned values of callback if #self.errors > 0 then self:synchronize() end -- if errors exist, sync immediately. local endcallbacks = self.endcallbacks + local idx, threadworker, r, callback, endcallback + if self:specific() then + idx = select(1, ...) + assert(type(idx) == 'number' and idx >= 1 and idx <= self.N, 'thread index expected') + threadworker = self.threadspecificworkers[idx] + callback = select(2, ...) + endcallback = select(3, ...) + r = 4 + else + callback = select(1, ...) + endcallback = select(2, ...) + threadworker = self.threadworker + r = 3 + end + assert(type(callback) == 'function', 'function callback expected') + assert(type(endcallback) == 'function' or type(endcallback) == 'nil', 'function (or nil) endcallback expected') + -- first finish running jobs if any if sync then - while not self:acceptsjob() do + while not self:acceptsjob(idx) do self:dojob() end end @@ -179,15 +242,15 @@ function Threads:__addjob__(sync, callback, endcallback, ...) -- endcallback is return status, res, endcallbackid end - self.threadworker:addjob(func, ...) + threadworker:addjob(func, select(r, ...)) end -function Threads:addjob(callback, endcallback, ...) - self:__addjob__(true, callback, endcallback, ...) +function Threads:addjob(...) + self:__addjob__(true, ...) end -function Threads:addjobasync(callback, endcallback, ...) - self:__addjob__(false, callback, endcallback, ...) +function Threads:addjobasync(...) + self:__addjob__(false, ...) end function Threads:haserror() @@ -195,7 +258,7 @@ function Threads:haserror() end function Threads:hasjob() - return (self.mainworker.runningjobs > 0 or self.threadworker.runningjobs > 0 or self.endcallbacks.n > 0) + return self.endcallbacks.n > 0 end function Threads:synchronize() @@ -212,9 +275,18 @@ end function Threads:terminate() -- terminate the threads for i=1,self.N do - self:addjob(function() - __worker_running = false - end) + if self:specific() then + self:addjob( + i, + function() + __worker_running = false + end) + else + self:addjob( + function() + __worker_running = false + end) + end end -- terminate all jobs |