diff options
author | Sam Gross <colesbury@gmail.com> | 2015-11-24 00:10:40 +0300 |
---|---|---|
committer | Sam Gross <colesbury@gmail.com> | 2015-11-24 00:25:00 +0300 |
commit | ebb9d1aa3fc7fcf71b1f7d6e719c60908fd94b47 (patch) | |
tree | c08a49a3408715c074f80c322ccff6668b0b8c61 | |
parent | a5869f64c10f00fb1a4e04cc2babfc8aa99d3ad8 (diff) |
Fix deadlock when using coroutine.yield in endcallback
-rw-r--r-- | .travis.yml | 1 | ||||
-rw-r--r-- | test/test-threads-coroutine.lua | 40 | ||||
-rw-r--r-- | threads.lua | 8 |
3 files changed, 45 insertions, 4 deletions
diff --git a/.travis.yml b/.travis.yml index 14d54b4..bf854c8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -58,3 +58,4 @@ script: - ${TESTLUA} test-threads-async.lua - ${TESTLUA} test-threads-shared.lua - ${TESTLUA} test-traceback.lua +- ${TESTLUA} test-threads-coroutine.lua diff --git a/test/test-threads-coroutine.lua b/test/test-threads-coroutine.lua new file mode 100644 index 0000000..ac94ee4 --- /dev/null +++ b/test/test-threads-coroutine.lua @@ -0,0 +1,40 @@ +local threads = require 'threads' + +local t = threads.Threads(1) + +-- PUC Lua 5.1 doesn't support coroutine.yield within pcall +if _VERSION == 'Lua 5.1' then + print('Unsupported test for PUC Lua 5.1') + return 0 +end + +local function loop() + t:addjob(function() return 1 end, coroutine.yield) + t:addjob(function() return 2 end, coroutine.yield) + t:synchronize() +end + +local function test1() + local expected = 1 + for r in coroutine.wrap(loop) do + assert(r == expected) + expected = expected + 1 + end + assert(expected == 3) +end + +local function test2() + for r in coroutine.wrap(loop) do + if r == 2 then + error('error at two') + end + end +end + +test1() + +local ok = pcall(test2) +assert(not ok) +t:synchronize() + +print('Done') diff --git a/threads.lua b/threads.lua index 8507bb2..f65d678 100644 --- a/threads.lua +++ b/threads.lua @@ -165,11 +165,13 @@ end function Threads:dojob() checkrunning(self) - local endcallbacks = self.endcallbacks local callstatus, args, endcallbackid, threadid = self.mainqueue:dojob() + local endcallback = self.endcallbacks[endcallbackid] + self.endcallbacks[endcallbackid] = nil + self.endcallbacks.n = self.endcallbacks.n - 1 if callstatus then local endcallstatus, msg = xpcall( - function() return endcallbacks[endcallbackid](_unpack(args)) end, + function() return endcallback(_unpack(args)) end, debug.traceback) if not endcallstatus then table.insert(self.errors, string.format('[thread %d endcallback] %s', threadid, msg)) @@ -177,8 +179,6 @@ function Threads:dojob() else table.insert(self.errors, string.format('[thread %d callback] %s', threadid, args[1])) end - endcallbacks[endcallbackid] = nil - endcallbacks.n = endcallbacks.n - 1 end function Threads:acceptsjob(idx) |