Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/threads-ffi.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSam Gross <colesbury@gmail.com>2015-11-24 00:10:40 +0300
committerSam Gross <colesbury@gmail.com>2015-11-24 00:25:00 +0300
commitebb9d1aa3fc7fcf71b1f7d6e719c60908fd94b47 (patch)
treec08a49a3408715c074f80c322ccff6668b0b8c61
parenta5869f64c10f00fb1a4e04cc2babfc8aa99d3ad8 (diff)
Fix deadlock when using coroutine.yield in endcallback
-rw-r--r--.travis.yml1
-rw-r--r--test/test-threads-coroutine.lua40
-rw-r--r--threads.lua8
3 files changed, 45 insertions, 4 deletions
diff --git a/.travis.yml b/.travis.yml
index 14d54b4..bf854c8 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -58,3 +58,4 @@ script:
- ${TESTLUA} test-threads-async.lua
- ${TESTLUA} test-threads-shared.lua
- ${TESTLUA} test-traceback.lua
+- ${TESTLUA} test-threads-coroutine.lua
diff --git a/test/test-threads-coroutine.lua b/test/test-threads-coroutine.lua
new file mode 100644
index 0000000..ac94ee4
--- /dev/null
+++ b/test/test-threads-coroutine.lua
@@ -0,0 +1,40 @@
+local threads = require 'threads'
+
+local t = threads.Threads(1)
+
+-- PUC Lua 5.1 doesn't support coroutine.yield within pcall
+if _VERSION == 'Lua 5.1' then
+ print('Unsupported test for PUC Lua 5.1')
+ return 0
+end
+
+local function loop()
+ t:addjob(function() return 1 end, coroutine.yield)
+ t:addjob(function() return 2 end, coroutine.yield)
+ t:synchronize()
+end
+
+local function test1()
+ local expected = 1
+ for r in coroutine.wrap(loop) do
+ assert(r == expected)
+ expected = expected + 1
+ end
+ assert(expected == 3)
+end
+
+local function test2()
+ for r in coroutine.wrap(loop) do
+ if r == 2 then
+ error('error at two')
+ end
+ end
+end
+
+test1()
+
+local ok = pcall(test2)
+assert(not ok)
+t:synchronize()
+
+print('Done')
diff --git a/threads.lua b/threads.lua
index 8507bb2..f65d678 100644
--- a/threads.lua
+++ b/threads.lua
@@ -165,11 +165,13 @@ end
function Threads:dojob()
checkrunning(self)
- local endcallbacks = self.endcallbacks
local callstatus, args, endcallbackid, threadid = self.mainqueue:dojob()
+ local endcallback = self.endcallbacks[endcallbackid]
+ self.endcallbacks[endcallbackid] = nil
+ self.endcallbacks.n = self.endcallbacks.n - 1
if callstatus then
local endcallstatus, msg = xpcall(
- function() return endcallbacks[endcallbackid](_unpack(args)) end,
+ function() return endcallback(_unpack(args)) end,
debug.traceback)
if not endcallstatus then
table.insert(self.errors, string.format('[thread %d endcallback] %s', threadid, msg))
@@ -177,8 +179,6 @@ function Threads:dojob()
else
table.insert(self.errors, string.format('[thread %d callback] %s', threadid, args[1]))
end
- endcallbacks[endcallbackid] = nil
- endcallbacks.n = endcallbacks.n - 1
end
function Threads:acceptsjob(idx)