diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-09-30 03:34:40 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-09-30 03:34:40 +0400 |
commit | 600996159830cada12bce8c07ba1c9427f23d505 (patch) | |
tree | aedd65d6afdd7339a54525e5ff63046118c4a8f9 | |
parent | ba5cb832916589daead581756e489a58bae851b1 (diff) |
Added packed minibatch to BatchOptimization.
-rw-r--r-- | BatchOptimization.lua | 67 |
1 files changed, 46 insertions, 21 deletions
diff --git a/BatchOptimization.lua b/BatchOptimization.lua index f25cdc9..d827dd4 100644 --- a/BatchOptimization.lua +++ b/BatchOptimization.lua @@ -42,6 +42,14 @@ function Batch:forward(inputs, targets, options) end function Batch:forward_sequential(inputs, targets, options) + -- (0) batch size + local batchsize = 1 + if type(inputs) == 'table' then + batchsize = #inputs + else + batchsize = inputs:size(1) + end + -- (1) construct a closure that compute f(inputs) + df/dW -- after each call to that function: -- + self.parameters contains the current X vector @@ -54,44 +62,61 @@ function Batch:forward_sequential(inputs, targets, options) print('<BatchOptimization> evaluating f(X) + df/dX') end local _t_ = sys.clock() + -- reset gradients self.gradParameters:zero() + -- f is the average of all criterions self.output = 0 - -- given all inputs, evaluate gradients - for i = 1,#inputs do - -- user hook - if self.prehook then - self.prehook(self, {inputs[i], targets[i], options[i]}) + + -- minibatch + if type(inputs) == 'table' then + -- given all inputs, evaluate gradients + for i = 1,#inputs do + -- user hook + if self.prehook then + self.prehook(self, {inputs[i], targets[i], options[i]}) + end + -- estimate f + local output = self.module:forward(inputs[i]) + local err = self.criterion:forward(output, targets[i]) + self.output = self.output + err + -- estimate df/dW + local df_do = self.criterion:backward(output, targets[i]) + self.module:backward(inputs[i], df_do) + self.module:accGradParameters(inputs[i], df_do) + -- user hook + if self.posthook then + self.posthook(self, {inputs[i], targets[i], options[i]}) + end + -- update evaluation counter + self.evalCounter = self.evalCounter + 1 end + + else -- minibatch is assumed to be a BatchSize x ... tensor + -- estimate f - local output = self.module:forward(inputs[i]) - local err = self.criterion:forward(output, targets[i]) - self.output = self.output + err + local output = self.module:forward(inputs) + self.output = self.criterion:forward(output, targets) -- estimate df/dW - local df_do = self.criterion:backward(output, targets[i]) - self.module:backward(inputs[i], df_do) - self.module:accGradParameters(inputs[i], df_do) - -- user hook - if self.posthook then - self.posthook(self, {inputs[i], targets[i], options[i]}) - end + local df_do = self.criterion:backward(output, targets) + self.module:backward(inputs, df_do) + self.module:accGradParameters(inputs, df_do) -- update evaluation counter - self.evalCounter = self.evalCounter + 1 + self.evalCounter = self.evalCounter + inputs:size(1) end -- update evaluation counter self.batchCounter = self.batchCounter + 1 - -- normalize gradients - self.gradParameters:div(#inputs) + -- normalize gradients and f(X) + self.gradParameters:div(batchsize) + self.output = self.output/batchsize -- verbose if self.verbose >= 2 then print('<BatchOptimization> ' .. self.batchCounter .. 'th batch took ' .. (sys.clock() - _t_) .. ' sec') end - -- return average f(X) - self.output = self.output/#inputs return self.output end @@ -101,7 +126,7 @@ function Batch:forward_sequential(inputs, targets, options) end -- (3) update sample counter - self.sampleCounter = self.sampleCounter + #inputs + self.sampleCounter = self.sampleCounter + batchsize -- (4) return current output after optimization return self.output |