Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/threads-ffi.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2016-10-19 00:21:00 +0300
committerRonan Collobert <ronan@collobert.com>2016-10-19 00:23:51 +0300
commit7e14ba31d0b1d2e48907d5bf5b2f876e865669ae (patch)
tree4c6e53d888be5b08d1f56195fc70e245b593da62
parentf0464d81308956a67680c5bf5ebe36e7f5c0d01c (diff)
added threads.safe()
-rw-r--r--CMakeLists.txt1
-rw-r--r--README.md14
-rw-r--r--init.lua1
-rw-r--r--safe.lua67
-rw-r--r--test/test-threads-safe.lua48
5 files changed, 131 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 355dc12..1e39a8e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -17,6 +17,7 @@ set(luasrc
serialize.lua
sharedserialize.lua
queue.lua
+ safe.lua
)
ADD_LIBRARY(threadsmain MODULE lib/thread-main.c)
diff --git a/README.md b/README.md
index c82bbbf..49d16d1 100644
--- a/README.md
+++ b/README.md
@@ -130,6 +130,7 @@ The library provides different low-level and high-level threading capabilities.
* [Threads](#threads.main): a thread pool ;
* [Queue](#queue): a thread-safe task queue ; and
* [serialize](#threads.serialize): functions for serialization and deserialization.
+ * [safe](#threads.safe): make a function thread-safe.
* [Low-level](#threads.lowlevel):
* [Thread](#thread): a single thread with no artifice ;
* [Mutex](#mutex): a thread mutex ;
@@ -388,6 +389,19 @@ This function unserializes the outputs of a [serialize.save](#threads.serialize.
The unserialized object `obj` is returned.
+<a name='threads.safe'/>
+### threads.safe(func, [mutex]) ###
+
+The function returns a new thread-safe function which embedds `func` (call
+arguments and returned arguments are the same). A mutex is created and
+locked before the execution of `func()`, and unlocked after. The mutex is
+destroyed at the garbage collection of `func`.
+
+If needed, one can specify the `mutex` to use as a second optional argument
+to threads.safe(). It is then up to the user to free this mutex when
+needed.
+
+
<a name='threads.lowlevel'/>
## Threads Low-Level Features
diff --git a/init.lua b/init.lua
index 30ca9ae..a76fec3 100644
--- a/init.lua
+++ b/init.lua
@@ -6,6 +6,7 @@ threads.Thread = C.Thread
threads.Mutex = C.Mutex
threads.Condition = C.Condition
threads.Threads = require 'threads.threads'
+threads.safe = require 'threads.safe'
-- only for backward compatibility (boo)
setmetatable(threads, getmetatable(threads.Threads))
diff --git a/safe.lua b/safe.lua
new file mode 100644
index 0000000..f23354b
--- /dev/null
+++ b/safe.lua
@@ -0,0 +1,67 @@
+-- utility for lua 5.2
+local setfenv = setfenv or
+ function(fn, env)
+ local i = 1
+ while true do
+ local name = debug.getupvalue(fn, i)
+ if name == "_ENV" then
+ debug.upvaluejoin(fn, i, (function()
+ return env
+ end), 1)
+ break
+ elseif not name then
+ break
+ end
+ i = i + 1
+ end
+ return fn
+ end
+
+local function newproxygc(func)
+ local proxy
+ if newproxy then -- 5.1
+ proxy = newproxy(true)
+ getmetatable(proxy).__gc = func
+ else -- 5.2
+ proxy = {}
+ setmetatable(proxy, {__gc=func})
+ end
+ return proxy
+end
+
+return function(func, mutex)
+ local threads = require 'threads'
+
+ assert(type(func) == 'function', 'function, [mutex] expected')
+ assert(mutex == nil or getmetatable(threads.Mutex).__index == getmetatable(mutex).__index, 'function, [mutex] expected')
+
+ -- make sure mutex is freed if it is our own
+ local proxy
+ if not mutex then
+ mutex = threads.Mutex()
+ proxy = newproxygc(
+ function()
+ mutex:free()
+ end
+ )
+ end
+
+ local mutexid = mutex:id()
+ local safe =
+ function(...)
+ local threads = require 'threads'
+ local mutex = threads.Mutex(mutexid)
+ local unpack = unpack or table.unpack
+ mutex:lock()
+ local res = {func(...)}
+ mutex:unlock()
+ return unpack(res)
+ end
+
+ -- make sure mutex is freed if it is our own
+ if proxy then
+ setfenv(safe, {require=require, unpack=unpack, table=table, proxy=proxy})
+ end
+
+ return safe
+end
diff --git a/test/test-threads-safe.lua b/test/test-threads-safe.lua
new file mode 100644
index 0000000..874e521
--- /dev/null
+++ b/test/test-threads-safe.lua
@@ -0,0 +1,48 @@
+require 'torch'
+
+if not (#arg == 0)
+ and not (#arg == 1 and tonumber(arg[1]))
+ and not (#arg == 2 and tonumber(arg[1]) and arg[2] == 'unsafe')
+then
+ error(string.format('usage: %s [number of runs] ["unsafe"]', arg[0]))
+end
+
+local N = tonumber(arg[1]) or 1000
+local issafe = (arg[2] ~= 'unsafe')
+
+local threads = require 'threads'
+
+threads.Threads.serialization('threads.sharedserialize')
+
+local tensor = torch.zeros(10000000)
+
+local pool = threads.Threads(10)
+
+local run =
+ function()
+ tensor:add(1)
+ end
+
+if issafe then
+ run = threads.safe(run)
+end
+
+for i=1,N do
+ pool:addjob(
+ run,
+ function()
+ if i % (N/100) == 0 then
+ io.write('.')
+ io.flush()
+ end
+ end
+ )
+end
+
+pool:synchronize()
+print()
+
+assert(tensor:min() == N)
+assert(tensor:max() == N)
+
+print('PASSED')