diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-09-16 23:34:55 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-09-16 23:34:55 +0400 |
commit | 56d85b2eedd81078b6950179179b2797be482fbc (patch) | |
tree | e27f1dc3c48a57ee5aa010038a94abb1b58f6e4f | |
parent | 6b1eb6102d9e65605c2b2cefd11af8bb5553f3e8 (diff) |
Added pre-hook code in BatchOptimization.
-rw-r--r-- | BatchOptimization.lua | 43 |
1 files changed, 18 insertions, 25 deletions
diff --git a/BatchOptimization.lua b/BatchOptimization.lua index 02a54e2..5eb5fb7 100644 --- a/BatchOptimization.lua +++ b/BatchOptimization.lua @@ -14,6 +14,8 @@ function Batch:__init(...) help='a criterion to estimate the error', req=true}, {arg='parallelize', type='number', help='parallelize onto N cores (experimental!)', default=1}, + {arg='precode', type='function', + help='optional code to be run by each parallel worker at their init'}, {arg='verbose', type='number', help='verbose level during training [0-2]', default=0} ) @@ -103,28 +105,8 @@ function Batch:forward_mapreduce(inputs, targets, options) -- transmit user hooks, if defined if not self.hooksets then - if self.prehook then - if type(self.prehook) == 'string' then - parallel.children:send(self.prehook) - else - print('\r<BatchOptimization> WARNING: when using para||el mode,'.. - ' hooks should be defined as strings. User prehook ignored.') - parallel.children:send('') - end - else - parallel.children:send('') - end - if self.posthook then - if type(self.posthook) == 'string' then - parallel.children:send(self.posthook) - else - print('\r<BatchOptimization> WARNING: when using para||el mode,'.. - ' hooks should be defined as strings. User posthook ignored.') - parallel.children:send('') - end - else - parallel.children:send('') - end + parallel.children:send(self.prehook or '') + parallel.children:send(self.posthook or '') self.hooksets = true end @@ -247,6 +229,10 @@ function Batch:setup_mapreduce () -- require packages require 'nnx' + -- retrieve optional code to setup worker + precode = parallel.parent:receive() + if type(precode) == 'function' then precode() end + -- retrieve module + criterion at startup parallel.yield() module = parallel.parent:receive() @@ -258,8 +244,8 @@ function Batch:setup_mapreduce () -- retrieve optional prehook/posthook prehook = parallel.parent:receive() posthook = parallel.parent:receive() - if prehook ~= '' then loadstring(prehook)() else prehook = nil end - if posthook ~= '' then loadstring(posthook)() else posthook = nil end + if type(prehook) ~= 'function' then prehook = nil end + if type(posthook) ~= 'function' then posthook = nil end -- get pointer to parameter and gradParameter vectors parameters = nnx.flattenParameters(nnx.getParameters(module)) @@ -326,7 +312,14 @@ function Batch:setup_mapreduce () self.children = parallel.sfork(self.parallelize) self.children:exec(worker_code) - -- (3) and send them the module + criterion architecture + -- (3) send them optional config code + if self.precode then + self.children:send(self.precode) + else + self.children:send('') + end + + -- (4) and send them the module + criterion architecture self.children:join() self.children:send(self.module) self.children:send(self.criterion) |