Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/threads-ffi.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2015-03-24 22:51:12 +0300
committerRonan Collobert <ronan@collobert.com>2015-03-24 22:51:12 +0300
commit0c26c67ed3baded1e6855fac2c604be060780155 (patch)
tree10cf1c5e22eef334afa2b397e9ebe1ff02f7aaf4
parent9cb1daba55c66ab6bc3b2210c0b2e8dafcf1a6a8 (diff)
added specific([boolean]) and fix threads initialization accordingly
-rw-r--r--init.lua170
1 files changed, 121 insertions, 49 deletions
diff --git a/init.lua b/init.lua
index 5de9e44..4509e62 100644
--- a/init.lua
+++ b/init.lua
@@ -43,7 +43,7 @@ function Threads.serialization(name)
end
function Threads:__call(N, ...)
- local self = {N=N, endcallbacks={n=0}, errors={}}
+ local self = {N=N, endcallbacks={n=0}, errors={}, __specific=true}
local funcs = {...}
local serialize = require(Threads.__serialize)
@@ -51,61 +51,45 @@ function Threads:__call(N, ...)
funcs = {function() end}
end
- local initres = {}
setmetatable(self, {__index=Threads})
self.mainworker = Worker(N, Threads.__serialize)
self.threadworker = Worker(N, Threads.__serialize)
+ self.threadspecificworkers = {}
self.threads = {}
for i=1,N do
+ self.threadspecificworkers[i] = Worker(N, Threads.__serialize)
+
local L = C.luaL_newstate()
assert(L ~= nil, string.format('%d-th lua state creation failed', i))
C.luaL_openlibs(L)
- for j=1,#funcs do
- local code_p, sz = serialize.save(funcs[j])
- if j < #funcs then
- checkL(L, C.luaL_loadstring(L, string.format([[
- local serialize = require '%s'
- local ffi = require 'ffi'
- local code = serialize.load(ffi.cast('const char*', %d), %d)
- code(%d)
- ]], Threads.__serialize, tonumber(ffi.cast('intptr_t', code_p)), sz, i)))
- else
- checkL(L, C.luaL_loadstring(L, string.format([[
- local serialize = require '%s'
- local ffi = require 'ffi'
- local code = serialize.load(ffi.cast('const char*', %d), %d)
- __threadid = %d
- __workerinitres_p, __workerinitres_sz = serialize.save{code(%d)}
- __workerinitres_p = tonumber(ffi.cast('intptr_t', __workerinitres_p))
- ]], Threads.__serialize, tonumber(ffi.cast('intptr_t', code_p)), sz, i, i)))
- end
- checkL(L, C.lua_pcall(L, 0, 0, 0) == 0)
- end
-
- C.lua_getfield(L, LUA_GLOBALSINDEX, '__workerinitres_p')
- local workerinitres_p = C.lua_tointeger(L, -1)
- C.lua_getfield(L, LUA_GLOBALSINDEX, '__workerinitres_sz')
- local workerinitres_sz = C.lua_tointeger(L, -1)
- C.lua_settop(L, -3)
- table.insert(initres, serialize.load(ffi.cast('const char*', workerinitres_p), workerinitres_sz))
-
- checkL(L, C.luaL_loadstring(L, [[
+ checkL(L,
+ C.luaL_loadstring(
+ L,
+ string.format(
+ [[
local ffi = require 'ffi'
local sdl = require 'sdl2'
require 'threads.worker'
+ __threadid = %d
local function workerloop(data)
local workers = ffi.cast('struct THWorker**', data)
local mainworker = workers[0]
local threadworker = workers[1]
+ local threadspecificworker = workers[2]
local threadid = __threadid
while __worker_running do
- local status, res, endcallbackid = threadworker:dojob()
+ local status, res, endcallbackid
+ if __worker_specific then
+ status, res, endcallbackid = threadspecificworker:dojob()
+ else
+ status, res, endcallbackid = threadworker:dojob()
+ end
mainworker:addjob(function()
return status, res, endcallbackid, threadid
end)
@@ -115,15 +99,18 @@ function Threads:__call(N, ...)
end
__worker_running = true
+ __worker_specific = true
__workerloop_ptr = tonumber(ffi.cast('intptr_t', ffi.cast('int (*)(void *)', workerloop)))
-]]
-) == 0)
+]],
+ i)
+ ) == 0)
+
checkL(L, C.lua_pcall(L, 0, 0, 0) == 0)
C.lua_getfield(L, LUA_GLOBALSINDEX, '__workerloop_ptr')
local workerloop_ptr = C.lua_tointeger(L, -1)
C.lua_settop(L, -2);
- local workers = ffi.new('struct THWorker*[2]', {self.mainworker, self.threadworker}) -- note: GCed
+ local workers = ffi.new('struct THWorker*[3]', {self.mainworker, self.threadworker, self.threadspecificworkers[i]}) -- note: GCed
local thread = sdl.createThread(ffi.cast('SDL_ThreadFunction', workerloop_ptr), string.format("%s%.2d", Threads.name, i), workers)
assert(thread ~= nil, string.format('%d-th thread creation failed', i))
table.insert(self.threads, {thread=thread, L=L})
@@ -135,9 +122,61 @@ function Threads:__call(N, ...)
self:synchronize()
end
+ local initres = {}
+ for j=1,#funcs do
+ for i=1,self.N do
+ if j ~= #funcs then
+ self:addjob(
+ i, -- specific
+ funcs[j],
+ function()
+ end,
+ i -- passed to callback
+ )
+ else
+ self:addjob(
+ i, -- specific
+ funcs[j],
+ function(...)
+ table.insert(initres, {...})
+ end,
+ i -- passed to callback
+ )
+ end
+ end
+ end
+ self:synchronize()
+ self:specific(false)
+
return self, initres
end
+function Threads:specific(flag)
+ if flag ~= nil then
+ assert(type(flag) == 'boolean', 'boolean expected')
+ self:synchronize() -- finish jobs first
+ if self.__specific ~= flag then
+ if self.__specific then
+ for i=1,self.N do
+ self:addjob(i,
+ function()
+ __worker_specific = false
+ end)
+ end
+ else
+ for i=1,self.N do
+ self:addjob(function()
+ __worker_specific = true
+ end)
+ end
+ end
+ self.__specific = flag
+ end
+ else
+ return self.__specific
+ end
+end
+
function Threads:dojob()
local endcallbacks = self.endcallbacks
local callstatus, args, endcallbackid, threadid = self.mainworker:dojob()
@@ -153,17 +192,41 @@ function Threads:dojob()
endcallbacks.n = endcallbacks.n - 1
end
-function Threads:acceptsjob()
- return self.threadworker.isfull ~= 1
+function Threads:acceptsjob(idx)
+ local threadworker
+ if self:specific() then
+ assert(type(idx) == 'number' and idx >= 1 and idx <= self.N, 'thread index expected')
+ threadworker = self.threadspecificworkers[idx]
+ else
+ threadworker = self.threadworker
+ end
+ return threadworker.isfull ~= 1
end
-function Threads:__addjob__(sync, callback, endcallback, ...) -- endcallback is passed with returned values of callback
+function Threads:__addjob__(sync, ...) -- endcallback is passed with returned values of callback
if #self.errors > 0 then self:synchronize() end -- if errors exist, sync immediately.
local endcallbacks = self.endcallbacks
+ local idx, threadworker, r, callback, endcallback
+ if self:specific() then
+ idx = select(1, ...)
+ assert(type(idx) == 'number' and idx >= 1 and idx <= self.N, 'thread index expected')
+ threadworker = self.threadspecificworkers[idx]
+ callback = select(2, ...)
+ endcallback = select(3, ...)
+ r = 4
+ else
+ callback = select(1, ...)
+ endcallback = select(2, ...)
+ threadworker = self.threadworker
+ r = 3
+ end
+ assert(type(callback) == 'function', 'function callback expected')
+ assert(type(endcallback) == 'function' or type(endcallback) == 'nil', 'function (or nil) endcallback expected')
+
-- first finish running jobs if any
if sync then
- while not self:acceptsjob() do
+ while not self:acceptsjob(idx) do
self:dojob()
end
end
@@ -179,15 +242,15 @@ function Threads:__addjob__(sync, callback, endcallback, ...) -- endcallback is
return status, res, endcallbackid
end
- self.threadworker:addjob(func, ...)
+ threadworker:addjob(func, select(r, ...))
end
-function Threads:addjob(callback, endcallback, ...)
- self:__addjob__(true, callback, endcallback, ...)
+function Threads:addjob(...)
+ self:__addjob__(true, ...)
end
-function Threads:addjobasync(callback, endcallback, ...)
- self:__addjob__(false, callback, endcallback, ...)
+function Threads:addjobasync(...)
+ self:__addjob__(false, ...)
end
function Threads:haserror()
@@ -195,7 +258,7 @@ function Threads:haserror()
end
function Threads:hasjob()
- return (self.mainworker.runningjobs > 0 or self.threadworker.runningjobs > 0 or self.endcallbacks.n > 0)
+ return self.endcallbacks.n > 0
end
function Threads:synchronize()
@@ -212,9 +275,18 @@ end
function Threads:terminate()
-- terminate the threads
for i=1,self.N do
- self:addjob(function()
- __worker_running = false
- end)
+ if self:specific() then
+ self:addjob(
+ i,
+ function()
+ __worker_running = false
+ end)
+ else
+ self:addjob(
+ function()
+ __worker_running = false
+ end)
+ end
end
-- terminate all jobs