Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/threads-ffi.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2015-04-08 01:49:02 +0300
committerRonan Collobert <ronan@collobert.com>2015-04-08 05:04:09 +0300
commit5661155854c7f5135bc8f78344b39574f53f7635 (patch)
tree6837ac1eb71f687a9f1b50064e4ef7b6ed61b5f8 /init.lua
parent8234b69d147b37f89936c0ac8bb197a2440879c7 (diff)
refcount job queues
Diffstat (limited to 'init.lua')
-rw-r--r--init.lua38
1 files changed, 35 insertions, 3 deletions
diff --git a/init.lua b/init.lua
index 69c7b43..7ae2381 100644
--- a/init.lua
+++ b/init.lua
@@ -43,7 +43,7 @@ function Threads.serialization(name)
end
function Threads:__call(N, ...)
- local self = {N=N, endcallbacks={n=0}, errors={}, __specific=true}
+ local self = {N=N, endcallbacks={n=0}, errors={}, __specific=true, __running=true}
local funcs = {...}
local serialize = require(Threads.__serialize)
@@ -82,6 +82,12 @@ function Threads:__call(N, ...)
local threadworker = workers[1]
local threadspecificworker = workers[2]
local threadid = __threadid
+ mainworker:retain()
+ threadworker:retain()
+ threadspecificworker:retain()
+ mainworker:gc()
+ threadworker:gc()
+ threadspecificworker:gc()
while __worker_running do
local status, res, endcallbackid
@@ -118,9 +124,10 @@ function Threads:__call(N, ...)
self.__gc__ = newproxy(true)
getmetatable(self.__gc__).__gc =
- function()
+ function()
self:synchronize()
- end
+ self:terminate()
+ end
local initres = {}
for j=1,#funcs do
@@ -150,7 +157,16 @@ function Threads:__call(N, ...)
return self, initres
end
+function Threads:isrunning()
+ return self.__running
+end
+
+local function checkrunning(self)
+ assert(self:isrunning(), 'thread system is not running')
+end
+
function Threads:specific(flag)
+ checkrunning(self)
if flag ~= nil then
assert(type(flag) == 'boolean', 'boolean expected')
self:synchronize() -- finish jobs first
@@ -178,6 +194,7 @@ function Threads:specific(flag)
end
function Threads:dojob()
+ checkrunning(self)
local endcallbacks = self.endcallbacks
local callstatus, args, endcallbackid, threadid = self.mainworker:dojob()
if callstatus then
@@ -193,6 +210,7 @@ function Threads:dojob()
end
function Threads:acceptsjob(idx)
+ checkrunning(self)
local threadworker
if self:specific() then
assert(type(idx) == 'number' and idx >= 1 and idx <= self.N, 'thread index expected')
@@ -204,6 +222,7 @@ function Threads:acceptsjob(idx)
end
function Threads:__addjob__(sync, ...) -- endcallback is passed with returned values of callback
+ checkrunning(self)
if #self.errors > 0 then self:synchronize() end -- if errors exist, sync immediately.
local endcallbacks = self.endcallbacks
@@ -246,22 +265,29 @@ function Threads:__addjob__(sync, ...) -- endcallback is passed with returned va
end
function Threads:addjob(...)
+ checkrunning(self)
self:__addjob__(true, ...)
end
function Threads:addjobasync(...)
+ checkrunning(self)
self:__addjob__(false, ...)
end
function Threads:haserror()
+ checkrunning(self)
return (#self.errors > 0)
end
function Threads:hasjob()
+ checkrunning(self)
return self.endcallbacks.n > 0
end
function Threads:synchronize()
+ if not self:isrunning() then
+ return
+ end
while self:hasjob()do
self:dojob()
end
@@ -273,6 +299,9 @@ function Threads:synchronize()
end
function Threads:terminate()
+ if not self:isrunning() then
+ return
+ end
-- terminate the threads
for i=1,self.N do
if self:specific() then
@@ -298,6 +327,9 @@ function Threads:terminate()
sdl.waitThread(self.threads[i].thread, pvalue)
C.lua_close(self.threads[i].L)
end
+
+ -- make sure you won't run anything
+ self.__running = false
end
return Threads --createThreads