diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-10-05 05:59:41 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-10-05 05:59:41 +0400 |
commit | d4f1c18061ac763aef1db7970bb52c3dd0e8ec93 (patch) | |
tree | 8f82c42fce8f17a85b35b36c3e18e7a91d9c3207 | |
parent | 56880e999ec6b9f4bac1413bb4626fd9d95e7472 (diff) |
Working SNES ?
-rw-r--r-- | SNESOptimization.lua | 45 |
1 files changed, 27 insertions, 18 deletions
diff --git a/SNESOptimization.lua b/SNESOptimization.lua index f9d77f9..8f391dd 100644 --- a/SNESOptimization.lua +++ b/SNESOptimization.lua @@ -5,7 +5,7 @@ function SNES:__init(...) xlua.unpack_class(self, {...}, 'SNESOptimization', nil, {arg='lambda', type='number', help='number of parallel samples', default=100}, - {arg='eta_mu', type='number', help='learning rate for mu', default=1e-2}, + {arg='eta_mu', type='number', help='learning rate for mu', default=1}, {arg='eta_sigma', type='number', help='learning rate for sigma', default=1e-2} ) -- original parameters @@ -25,6 +25,8 @@ function SNES:__init(...) -- SNES gradient vectors self.gradmu = torch.Tensor():resizeAs(self.mu) self.gradsigma = torch.Tensor():resizeAs(self.sigma) + -- SNES utilities + self:utilities() end function SNES:f(th, X, inputs, targets) @@ -40,19 +42,18 @@ function SNES:f(th, X, inputs, targets) return f end -function SNES:utilities(fitness) - -- sort fitness tables - table.sort(fitness, function(a,b) if a.f < b.f then return a end end) +function SNES:utilities() -- compute utilities local sum = 0 - for i,fit in ipairs(fitness) do - local x = (i-1)/#fitness -- x in [0..1] - fit.u = math.max(0, x-0.5) - sum = sum + fit.u + self.u = {} + for i = 1,self.lambda do + local x = (i-1)/self.lambda -- x in [0..1] + self.u[i] = math.exp(1-x*10)-1 + sum = sum + self.u[i] end - -- normalize us - for i,fitness in ipairs(fitness) do - fitness.u = fitness.u / sum + -- normalize u + for i = 1,self.lambda do + self.u[i] = self.u[i] / sum end end @@ -73,8 +74,8 @@ function SNES:optimize(inputs, targets) fitness[i] = {f=f_X, s=s_k, z=z_k} end - -- compute utilities - self:utilities(fitness) + -- sort fitness tables + table.sort(fitness, function(a,b) if a.f < b.f then return a end end) -- set current output to best f_X (lowest) self.output = fitness[1].f @@ -83,19 +84,27 @@ function SNES:optimize(inputs, targets) self.gradmu:zero() self.gradsigma:zero() for i = 1,self.lambda do - local fitness = fitness[i] - self.gradmu:add(fitness.u, fitness.s) - self.gradsigma:add(fitness.u, fitness.s:clone():pow(2):add(-1)) + self.gradmu:add(self.u[i], fitness[i].s) + self.gradsigma:add(self.u[i], fitness[i].s:clone():pow(2):add(-1)) end -- update parameters for i = 1,self.lambda do self.mu:add( self.sigma * self.gradmu * self.eta_mu ) - self.sigma:add( (self.gradsigma * self.eta_sigma/2):exp() ) + self.sigma:cmul( lab.exp(self.gradsigma * self.eta_sigma/2) ) end -- optimization done, copy back best parameter vector - self.parameter:copy(fitness[1].z) + self.parameter:copy(self.mu) + + -- verbose + self.batchCounter = self.batchCounter or 0 + self.batchCounter = self.batchCounter + 1 + if self.verbose >= 2 then + print('<SNESOptimization> evaluated f(X) on ' .. self.lambda .. ' random points') + print(' + batches seen: ' .. self.batchCounter) + print(' + lowest eval f(X) = ' .. self.output) + end -- for now call GC collectgarbage() |