diff options
author | Ronan Collobert <ronan@collobert.com> | 2016-10-19 00:21:00 +0300 |
---|---|---|
committer | Ronan Collobert <ronan@collobert.com> | 2016-10-19 00:23:51 +0300 |
commit | 7e14ba31d0b1d2e48907d5bf5b2f876e865669ae (patch) | |
tree | 4c6e53d888be5b08d1f56195fc70e245b593da62 | |
parent | f0464d81308956a67680c5bf5ebe36e7f5c0d01c (diff) |
added threads.safe()
-rw-r--r-- | CMakeLists.txt | 1 | ||||
-rw-r--r-- | README.md | 14 | ||||
-rw-r--r-- | init.lua | 1 | ||||
-rw-r--r-- | safe.lua | 67 | ||||
-rw-r--r-- | test/test-threads-safe.lua | 48 |
5 files changed, 131 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 355dc12..1e39a8e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,6 +17,7 @@ set(luasrc serialize.lua sharedserialize.lua queue.lua + safe.lua ) ADD_LIBRARY(threadsmain MODULE lib/thread-main.c) @@ -130,6 +130,7 @@ The library provides different low-level and high-level threading capabilities. * [Threads](#threads.main): a thread pool ; * [Queue](#queue): a thread-safe task queue ; and * [serialize](#threads.serialize): functions for serialization and deserialization. + * [safe](#threads.safe): make a function thread-safe. * [Low-level](#threads.lowlevel): * [Thread](#thread): a single thread with no artifice ; * [Mutex](#mutex): a thread mutex ; @@ -388,6 +389,19 @@ This function unserializes the outputs of a [serialize.save](#threads.serialize. The unserialized object `obj` is returned. +<a name='threads.safe'/> +### threads.safe(func, [mutex]) ### + +The function returns a new thread-safe function which embedds `func` (call +arguments and returned arguments are the same). A mutex is created and +locked before the execution of `func()`, and unlocked after. The mutex is +destroyed at the garbage collection of `func`. + +If needed, one can specify the `mutex` to use as a second optional argument +to threads.safe(). It is then up to the user to free this mutex when +needed. + + <a name='threads.lowlevel'/> ## Threads Low-Level Features @@ -6,6 +6,7 @@ threads.Thread = C.Thread threads.Mutex = C.Mutex threads.Condition = C.Condition threads.Threads = require 'threads.threads' +threads.safe = require 'threads.safe' -- only for backward compatibility (boo) setmetatable(threads, getmetatable(threads.Threads)) diff --git a/safe.lua b/safe.lua new file mode 100644 index 0000000..f23354b --- /dev/null +++ b/safe.lua @@ -0,0 +1,67 @@ +-- utility for lua 5.2 +local setfenv = setfenv or + function(fn, env) + local i = 1 + while true do + local name = debug.getupvalue(fn, i) + if name == "_ENV" then + debug.upvaluejoin(fn, i, (function() + return env + end), 1) + break + elseif not name then + break + end + i = i + 1 + end + return fn + end + +local function newproxygc(func) + local proxy + if newproxy then -- 5.1 + proxy = newproxy(true) + getmetatable(proxy).__gc = func + else -- 5.2 + proxy = {} + setmetatable(proxy, {__gc=func}) + end + return proxy +end + +return function(func, mutex) + local threads = require 'threads' + + assert(type(func) == 'function', 'function, [mutex] expected') + assert(mutex == nil or getmetatable(threads.Mutex).__index == getmetatable(mutex).__index, 'function, [mutex] expected') + + -- make sure mutex is freed if it is our own + local proxy + if not mutex then + mutex = threads.Mutex() + proxy = newproxygc( + function() + mutex:free() + end + ) + end + + local mutexid = mutex:id() + local safe = + function(...) + local threads = require 'threads' + local mutex = threads.Mutex(mutexid) + local unpack = unpack or table.unpack + mutex:lock() + local res = {func(...)} + mutex:unlock() + return unpack(res) + end + + -- make sure mutex is freed if it is our own + if proxy then + setfenv(safe, {require=require, unpack=unpack, table=table, proxy=proxy}) + end + + return safe +end diff --git a/test/test-threads-safe.lua b/test/test-threads-safe.lua new file mode 100644 index 0000000..874e521 --- /dev/null +++ b/test/test-threads-safe.lua @@ -0,0 +1,48 @@ +require 'torch' + +if not (#arg == 0) + and not (#arg == 1 and tonumber(arg[1])) + and not (#arg == 2 and tonumber(arg[1]) and arg[2] == 'unsafe') +then + error(string.format('usage: %s [number of runs] ["unsafe"]', arg[0])) +end + +local N = tonumber(arg[1]) or 1000 +local issafe = (arg[2] ~= 'unsafe') + +local threads = require 'threads' + +threads.Threads.serialization('threads.sharedserialize') + +local tensor = torch.zeros(10000000) + +local pool = threads.Threads(10) + +local run = + function() + tensor:add(1) + end + +if issafe then + run = threads.safe(run) +end + +for i=1,N do + pool:addjob( + run, + function() + if i % (N/100) == 0 then + io.write('.') + io.flush() + end + end + ) +end + +pool:synchronize() +print() + +assert(tensor:min() == N) +assert(tensor:max() == N) + +print('PASSED') |