author     Clement Farabet <clement.farabet@gmail.com>   2012-07-06 03:57:43 +0400
committer  Clement Farabet <clement.farabet@gmail.com>   2012-07-06 03:57:43 +0400
commit     4081495b78ceb1b102e82cb642e31bc9f04182be (patch)
tree       a9e3ee51ff6153ff31f9272c8b1b9e54c7d6453c
parent     ba152e884d7c521a663a819131d029938a7023aa (diff)
Oops: missing file.
-rw-r--r--  hessian.lua | 374
1 file changed, 374 insertions, 0 deletions
diff --git a/hessian.lua b/hessian.lua
new file mode 100644
index 0000000..493f7d9
--- /dev/null
+++ b/hessian.lua
@@ -0,0 +1,374 @@
+----------------------------------------------------------------------
+-- hessian.lua: this file appends extra methods to modules in nn,
+-- to estimate diagonal elements of the Hessian. This is useful
+-- to condition learning rates individually.
+----------------------------------------------------------------------
+nn.hessian = {}
+
+----------------------------------------------------------------------
+-- Hessian code is still experimental,
+-- and deactivated by default
+----------------------------------------------------------------------
+function nn.hessian.activate()
+
+   local function accDiagHessianParameters(module, input, diagHessianOutput, gw, hw)
+      if #gw ~= #hw then
+         error('Number of gradients is not equal to number of hessians')
+      end
+      module.inputSq = module.inputSq or input.new()
+      module.inputSq:resizeAs(input)
+      torch.cmul(module.inputSq, input, input)
+      -- replace gradients with hessian
+      for i=1,#gw do
+         local gwname = gw[i]
+         local hwname = hw[i]
+         local gwval = module[gwname]
+         local hwval = module[hwname]
+         if hwval == nil then
+            module[hwname] = gwval.new():resizeAs(gwval)
+            hwval = module[hwname]
+         end
+         module[gwname] = hwval
+         module[hwname] = gwval
+      end
+      module.accGradParameters(module, module.inputSq, diagHessianOutput, 1)
+      -- put back gradients
+      for i=1,#gw do
+         local gwname = gw[i]
+         local hwname = hw[i]
+         local gwval = module[gwname]
+         local hwval = module[hwname]
+         module[gwname] = hwval
+         module[hwname] = gwval
+      end
+   end
+   nn.hessian.accDiagHessianParameters = accDiagHessianParameters
+
+   local function updateDiagHessianInput(module, input, diagHessianOutput, w, wsq)
+      if #w ~= #wsq then
+         error('Number of weights is not equal to number of weights squares')
+      end
+      module.diagHessianInput = module.diagHessianInput or input.new()
+      module.diagHessianInput:resizeAs(input)
+
+      local gi = module.gradInput
+      module.gradInput = module.diagHessianInput
+      for i=1,#w do
+         local wname = w[i]
+         local wsqname = wsq[i]
+         local wval = module[wname]
+         local wsqval = module[wsqname]
+         if wsqval == nil then
+            module[wsqname] = wval.new()
+            wsqval = module[wsqname]
+         end
+         wsqval:resizeAs(wval)
+         torch.cmul(wsqval, wval, wval)
+         module[wsqname] = wval
+         module[wname] = wsqval
+      end
+      module.updateGradInput(module,input,diagHessianOutput)
+      for i=1,#w do
+         local wname = w[i]
+         local wsqname = wsq[i]
+         local wval = module[wname]
+         local wsqval = module[wsqname]
+         module[wname] = wsqval
+         module[wsqname] = wval
+      end
+      module.gradInput = gi
+   end
+   nn.hessian.updateDiagHessianInput = updateDiagHessianInput
+
+   local function updateDiagHessianInputPointWise(module, input, diagHessianOutput)
+      local tdh = diagHessianOutput.new():resizeAs(diagHessianOutput):fill(1)
+      updateDiagHessianInput(module,input,tdh,{},{})
+      module.diagHessianInput:cmul(module.diagHessianInput)
+      module.diagHessianInput:cmul(diagHessianOutput)
+   end
+   nn.hessian.updateDiagHessianInputPointWise = updateDiagHessianInputPointWise
+
+   local function initDiagHessianParameters(module,gw,hw)
+      module.diagHessianInput = module.diagHessianInput or module.gradInput.new();
+      for i=1,#gw do
+         module[hw[i]] = module[hw[i]] or module[gw[i]].new():resizeAs(module[gw[i]])
+      end
+   end
+   nn.hessian.initDiagHessianParameters = initDiagHessianParameters
+
+   ----------------------------------------------------------------------
+   -- MODULE
+   ----------------------------------------------------------------------
+   function nn.Module.updateDiagHessianInput(self, input, diagHessianOutput)
+      self.diagHessianInput = self.diagHessianInput or diagHessianOutput
+      return self.diagHessianInput
+   end
+
+   function nn.Module.accDiagHessianParameters(self, input, diagHessianOutput)
+   end
+
+   function nn.Module.initDiagHessianParameters()
+   end
+
+   ----------------------------------------------------------------------
+   -- SEQUENTIAL
+   ----------------------------------------------------------------------
+   function nn.Sequential.initDiagHessianParameters(self)
+      for i=1,#self.modules do
+         self.modules[i]:initDiagHessianParameters()
+      end
+   end
+
+   function nn.Sequential.updateDiagHessianInput(self, input, diagHessianOutput)
+      local currentDiagHessianOutput = diagHessianOutput
+      local currentModule = self.modules[#self.modules]
+      for i=#self.modules-1,1,-1 do
+         local previousModule = self.modules[i]
+         currentDiagHessianOutput = currentModule:updateDiagHessianInput(previousModule.output, currentDiagHessianOutput)
+         currentModule = previousModule
+      end
+      currentDiagHessianOutput = currentModule:updateDiagHessianInput(input, currentDiagHessianOutput)
+      self.diagHessianInput = currentDiagHessianOutput
+      return currentDiagHessianOutput
+   end
+
+   function nn.Sequential.accDiagHessianParameters(self, input, diagHessianOutput)
+      local currentDiagHessianOutput = diagHessianOutput
+      local currentModule = self.modules[#self.modules]
+      for i=#self.modules-1,1,-1 do
+         local previousModule = self.modules[i]
+         currentModule:accDiagHessianParameters(previousModule.output, currentDiagHessianOutput)
+         currentDiagHessianOutput = currentModule.diagHessianInput
+         currentModule = previousModule
+      end
+      currentModule:accDiagHessianParameters(input, currentDiagHessianOutput)
+   end
+
+   ----------------------------------------------------------------------
+   -- CRITERION
+   ----------------------------------------------------------------------
+   function nn.Criterion.updateDiagHessianInput(self, input, diagHessianOutput)
+      self.diagHessianInput = self.diagHessianInput or self.output.new()
+      return self.diagHessianInput
+   end
+
+   ----------------------------------------------------------------------
+   -- MSECRITERION
+   ----------------------------------------------------------------------
+   function nn.MSECriterion.updateDiagHessianInput(self, input, target)
+      self.diagHessianInput = self.diagHessianInput or input.new()
+      local val = 2
+      if self.sizeAverage then
+         val = val / input:nElement()
+      end
+      self.diagHessianInput:resizeAs(input):fill(val)
+      return self.diagHessianInput
+   end
+
+   function nn.WeightedMSECriterion.updateDiagHessianInput(self,input,target)
+      return nn.MSECriterion.updateDiagHessianInput(self,input,target)
+   end
+
+   ----------------------------------------------------------------------
+   -- LINEAR
+   ----------------------------------------------------------------------
+   function nn.Linear.updateDiagHessianInput(self, input, diagHessianOutput)
+      updateDiagHessianInput(self, input, diagHessianOutput, {'weight'}, {'weightSq'})
+      return self.diagHessianInput
+   end
+
+   function nn.Linear.accDiagHessianParameters(self, input, diagHessianOutput)
+      accDiagHessianParameters(self,input, diagHessianOutput, {'gradWeight','gradBias'}, {'diagHessianWeight','diagHessianBias'})
+   end
+
+   function nn.Linear.initDiagHessianParameters(self)
+      initDiagHessianParameters(self,{'gradWeight','gradBias'},{'diagHessianWeight','diagHessianBias'})
+   end
+
+   ----------------------------------------------------------------------
+   -- SpatialConvolution
+   ----------------------------------------------------------------------
+   function nn.SpatialConvolution.updateDiagHessianInput(self, input, diagHessianOutput)
+      updateDiagHessianInput(self, input, diagHessianOutput, {'weight'}, {'weightSq'})
+      return self.diagHessianInput
+   end
+
+   function nn.SpatialConvolution.accDiagHessianParameters(self, input, diagHessianOutput)
+      accDiagHessianParameters(self,input, diagHessianOutput, {'gradWeight'}, {'diagHessianWeight'})
+   end
+   function nn.SpatialConvolution.initDiagHessianParameters(self)
+      initDiagHessianParameters(self,{'gradWeight'},{'diagHessianWeight'})
+   end
+
+   ----------------------------------------------------------------------
+   -- SpatialConvolutionMap
+   ----------------------------------------------------------------------
+   function nn.SpatialConvolutionMap.updateDiagHessianInput(self, input, diagHessianOutput)
+      updateDiagHessianInput(self, input, diagHessianOutput, {'weight','bias'}, {'weightSq','biasSq'})
+      return self.diagHessianInput
+   end
+
+   function nn.SpatialConvolutionMap.accDiagHessianParameters(self, input, diagHessianOutput)
+      accDiagHessianParameters(self,input, diagHessianOutput, {'gradWeight','gradBias'}, {'diagHessianWeight','diagHessianBias'})
+   end
+   function nn.SpatialConvolutionMap.initDiagHessianParameters(self)
+      initDiagHessianParameters(self,{'gradWeight','gradBias'},{'diagHessianWeight','diagHessianBias'})
+   end
+
+   ----------------------------------------------------------------------
+   -- TANH
+   ----------------------------------------------------------------------
+   function nn.Tanh.updateDiagHessianInput(self, input, diagHessianOutput)
+      updateDiagHessianInputPointWise(self,input, diagHessianOutput)
+      return self.diagHessianInput
+   end
+
+   function nn.TanhShrink.updateDiagHessianInput(self, input, diagHessianOutput)
+      updateDiagHessianInputPointWise(self.tanh,input, diagHessianOutput)
+      self.diagHessianInput = self.diagHessianInput or input.new():resizeAs(input)
+      torch.add(self.diagHessianInput, self.tanh.diagHessianInput, diagHessianOutput)
+      return self.diagHessianInput
+   end
+
+   function nn.Diag.updateDiagHessianInput(self, input, diagHessianOutput)
+      updateDiagHessianInput(self, input, diagHessianOutput, {'weight'}, {'weightSq'})
+      return self.diagHessianInput
+   end
+
+   function nn.Diag.accDiagHessianParameters(self, input, diagHessianOutput)
+      accDiagHessianParameters(self,input, diagHessianOutput, {'gradWeight'}, {'diagHessianWeight'})
+   end
+
+   function nn.Diag.initDiagHessianParameters(self)
+      initDiagHessianParameters(self,{'gradWeight'},{'diagHessianWeight'})
+   end
+
+   ----------------------------------------------------------------------
+   -- Parameters manipulation:
+   -- we modify these functions such that they return hessian coefficients
+   ----------------------------------------------------------------------
+   function nn.Module.parameters(self)
+      if self.weight and self.bias then
+         return {self.weight, self.bias}, {self.gradWeight, self.gradBias}, {self.diagHessianWeight, self.diagHessianBias}
+      elseif self.weight then
+         return {self.weight}, {self.gradWeight}, {self.diagHessianWeight}
+      elseif self.bias then
+         return {self.bias}, {self.gradBias}, {self.diagHessianBias}
+      else
+         return
+      end
+   end
+
+   function nn.Module.getParameters(self)
+      -- get parameters
+      local parameters,gradParameters,hessianParameters = self:parameters()
+
+      -- this function flattens arbitrary lists of parameters,
+      -- even complex shared ones
+      local function flatten(parameters)
+         -- already flat ?
+         local flat = true
+         for k = 2,#parameters do
+            if parameters[k]:storage() ~= parameters[k-1]:storage() then
+               flat = false
+               break
+            end
+         end
+         if flat then
+            local nParameters = 0
+            for k,param in ipairs(parameters) do
+               nParameters = nParameters + param:nElement()
+            end
+            local flatParameters = parameters[1].new(parameters[1]:storage())
+            if nParameters ~= flatParameters:nElement() then
+               error('flattenParameters(): weird parameters')
+            end
+            return flatParameters
+         end
+         -- compute offsets of each parameter
+         local offsets = {}
+         local sizes = {}
+         local strides = {}
+         local elements = {}
+         local storageOffsets = {}
+         local params = {}
+         local nParameters = 0
+         for k,param in ipairs(parameters) do
+            table.insert(offsets, nParameters+1)
+            table.insert(sizes, param:size())
+            table.insert(strides, param:stride())
+            table.insert(elements, param:nElement())
+            table.insert(storageOffsets, param:storageOffset())
+            local isView = false
+            for i = 1,k-1 do
+               if param:storage() == parameters[i]:storage() then
+                  offsets[k] = offsets[i]
+                  if storageOffsets[k] ~= storageOffsets[i] or elements[k] ~= elements[i] then
+                     error('flattenParameters(): cannot flatten shared weights with different structures')
+                  end
+                  isView = true
+                  break
+               end
+            end
+            if not isView then
+               nParameters = nParameters + param:nElement()
+            end
+         end
+         -- create flat vector
+         local flatParameters = parameters[1].new(nParameters)
+         local storage = flatParameters:storage()
+         -- reallocate all parameters in flat vector
+         for i = 1,#parameters do
+            local data = parameters[i]:clone()
+            parameters[i]:set(storage, offsets[i], elements[i]):resize(sizes[i],strides[i]):copy(data)
+            data = nil
+            collectgarbage()
+         end
+         -- cleanup
+         collectgarbage()
+         -- return flat param
+         return flatParameters
+      end
+
+      -- flatten parameters and gradients
+      local flatParameters = flatten(parameters)
+      local flatGradParameters = flatten(gradParameters)
+      local flatHessianParameters
+      if hessianParameters[1] then
+         flatHessianParameters = flatten(hessianParameters)
+      end
+
+      -- return new flat vector that contains all discrete parameters
+      return flatParameters, flatGradParameters, flatHessianParameters
+   end
+
+   function nn.Sequential.parameters(self)
+      local function tinsert(to, from)
+         if type(from) == 'table' then
+            for i=1,#from do
+               tinsert(to,from[i])
+            end
+         else
+            table.insert(to,from)
+         end
+      end
+      local w = {}
+      local gw = {}
+      local ggw = {}
+      for i=1,#self.modules do
+         local mw,mgw,mggw = self.modules[i]:parameters()
+         if mw then
+            tinsert(w,mw)
+            tinsert(gw,mgw)
+            tinsert(ggw,mggw)
+         end
+      end
+      return w,gw,ggw
+   end
+
+   ----------------------------------------------------------------------
+   -- Avoid multiple calls to activate()
+   ----------------------------------------------------------------------
+   function nn.hessian.activate()
+   end
+end
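
The buffer swapping above implements the standard Gauss-Newton-style
approximation of the diagonal Hessian: for a linear map y = W*x,
d2E/dW[i][j]^2 ~ d2E/dy[i]^2 * x[j]^2 and
d2E/dx[j]^2 ~ sum_i W[i][j]^2 * d2E/dy[i]^2,
i.e. an ordinary gradient pass run with squared inputs and squared weights.
A quick numerical check of the weight formula against nn.Linear
(a hypothetical sketch, not part of the commit, assuming the file above is
loaded as part of nn):

    require 'nn'
    nn.hessian.activate()

    local lin = nn.Linear(3, 2)
    lin:initDiagHessianParameters()
    -- the diagHessian buffers are accumulated into, and this commit does
    -- not zero them on creation, so clear them explicitly
    lin.diagHessianWeight:zero()
    lin.diagHessianBias:zero()

    local x = torch.randn(3)
    local d2Edy2 = torch.rand(2)            -- curvature coming from above
    lin:accDiagHessianParameters(x, d2Edy2) -- accumulates d2Edy2 outer x^2

    local expected = torch.ger(d2Edy2, torch.cmul(x, x))
    print((lin.diagHessianWeight - expected):abs():max()) -- prints ~0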
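
End to end, the second-order pass mirrors the usual forward/backward, with
the criterion supplying d2E/dy^2 instead of dE/dy, and getParameters()
returning a third flat vector of Hessian estimates. A minimal usage sketch
under the same assumptions (the network shape and the damping constant mu
are illustrative, not from the commit):

    require 'nn'
    nn.hessian.activate()

    local net = nn.Sequential()
    net:add(nn.Linear(10, 5))
    net:add(nn.Tanh())
    net:add(nn.Linear(5, 1))
    local criterion = nn.MSECriterion()

    -- allocate Hessian buffers, then flatten: h views the same storage
    -- the modules accumulate into, so zeroing h clears the buffers
    net:initDiagHessianParameters()
    local p, g, h = net:getParameters()
    h:zero()

    local input  = torch.randn(10)
    local target = torch.randn(1)

    -- regular forward pass
    local output = net:forward(input)
    criterion:forward(output, target)

    -- second-order backward pass
    local d2Edy2 = criterion:updateDiagHessianInput(output, target)
    net:updateDiagHessianInput(input, d2Edy2)
    net:accDiagHessianParameters(input, d2Edy2)

    -- condition per-parameter learning rates, e.g. 1/(|h| + mu)
    local mu = 0.01
    local rates = h:clone():abs():add(mu):pow(-1)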