diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-08-29 04:08:23 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-08-29 04:08:23 +0400 |
commit | ce7043e49884a2aeb2bc7489aa6e50d21682aba5 (patch) | |
tree | 12f41824d439875eea055b36d2abf07d4dd8ba7c | |
parent | 357829243d61169f719144231bb0260cbde52404 (diff) |
Added support for user hooks in parallel BFGS. To be tested.
-rw-r--r-- | LBFGSOptimization.lua | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/LBFGSOptimization.lua b/LBFGSOptimization.lua index 643e2b1..8373cfd 100644 --- a/LBFGSOptimization.lua +++ b/LBFGSOptimization.lua @@ -90,6 +90,33 @@ function LBFGS:forward_mapreduce(inputs, targets, options) -- parameters local P = self.parallelize + -- transmit user hooks, if defined + if not self.hooksets then + if self.prehook then + if type(self.prehook) == 'string' then + parallel.children:send(self.prehook) + else + print('\r<LBFGSOptimization> WARNING: when using para||el mode, hooks should be') + print('\r<LBFGSOptimization> WARNING: defined as strings. User prehook ignored.') + parallel.children:send('') + end + else + parallel.children:send('') + end + if self.posthook then + if type(self.posthook) == 'string' then + parallel.children:send(self.posthook) + else + print('\r<LBFGSOptimization> WARNING: when using para||el mode, hooks should be') + print('<\rLBFGSOptimization> WARNING: defined as strings. User posthook ignored.') + parallel.children:send('') + end + else + parallel.children:send('') + end + self.hooksets = true + end + -- (0a) replicate output and gradParameters local outputs = {} local gradParameters = {} @@ -97,12 +124,15 @@ function LBFGS:forward_mapreduce(inputs, targets, options) -- (0b) divide input/target batch into N batches local inputss = {} local targetss = {} + local optionss = {} for t = 1,P do inputss[t] = {} targetss[t] = {} + optionss[t] = {} for i = t,#inputs,P do table.insert(inputss[t], inputs[i]) table.insert(targetss[t], targets[i]) + if options then table.insert(optionss[t], options[i]) end end end @@ -110,6 +140,7 @@ function LBFGS:forward_mapreduce(inputs, targets, options) for t = 1,P do parallel.children[t]:send(inputss[t]) parallel.children[t]:send(targetss[t]) + parallel.children[t]:send(optionss[t]) end -- (1) construct a closure that compute f(inputs) + df/dW @@ -199,6 +230,15 @@ function LBFGS:setup_mapreduce () module = parallel.parent:receive() criterion = parallel.parent:receive() + -- create fake optimizer, for hooks + optimizer = {module=module, criterion=criterion} + + -- retrieve optional prehook/posthook + prehook = parallel.parent:receive() + posthook = parallel.parent:receive() + if prehook ~= '' then loadstring(prehook)() else prehook = nil end + if posthook ~= '' then loadstring(posthook)() else posthook = nil end + -- get pointer to parameter and gradParameter vectors parameters = nnx.getParameters(module) gradParameters = nnx.getGradParameters(module) @@ -209,6 +249,7 @@ function LBFGS:setup_mapreduce () inputs = parallel.parent:receive() if type(inputs) == 'string' and inputs == 'break' then break end targets = parallel.parent:receive() + options = parallel.parent:receive() -- inner loop: evaluations while true do @@ -225,6 +266,10 @@ function LBFGS:setup_mapreduce () local f_x = 0 -- evaluate gradients on inputs for this thread for i = 1,#inputs do + -- user hook + if prehook then + prehook(optimizer, {inputs[i], targets[i], options[i]}) + end -- estimate f local output = module:forward(inputs[i]) local err = criterion:forward(output, targets[i]) @@ -232,6 +277,10 @@ function LBFGS:setup_mapreduce () -- estimate df/dW local df_do = criterion:backward(output, targets[i]) module:backward(inputs[i], df_do) + -- user hook + if posthook then + posthook(optimizer, {inputs[i], targets[i], options[i]}) + end end -- now send back gradParameters + partial output |