author     Marco Scoffier <github@metm.org>   2011-10-10 15:03:17 +0400
committer  Marco Scoffier <github@metm.org>   2011-10-10 15:03:17 +0400
commit     c3e369fb9c0beba3dacaef5fd0c59c60e33d6dff (patch)
tree       421cc920f04ebb3bac7957dc887d87ef388523c1
parent     32096b2fe0f94e405f3611febca94453ff9ad70b (diff)
parent     c42ff17b16bca9c4aa6d264cc2a79b3347a4b81f (diff)
Merge branch 'master' of github.com:clementfarabet/lua---nnx
-rw-r--r--   CMakeLists.txt        |   1
-rw-r--r--   DiagHessian.lua       | 103
-rw-r--r--   init.lua              |  13
-rw-r--r--   test/test-hessian.lua | 108
4 files changed, 225 insertions, 0 deletions
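A condensed Lua sketch of the workflow this commit enables (it mirrors the new test/test-hessian.lua shown further down; `module`, `criterion`, `inputs`, `targets` and `diagHessianEpsilon` are assumed to be set up as in that script):

   -- estimate the diagonal Hessian over the dataset, then derive per-parameter learning rates
   module:initDiagHessianParameters()
   local diagHessianParameters = nnx.flattenParameters(nnx.getDiagHessianParameters(module))
   diagHessianParameters:zero()
   for i = 1,#inputs do
      local output = module:forward(inputs[i])
      local critDiagHessian = criterion:backwardDiagHessian(output, targets[i])
      module:backwardDiagHessian(inputs[i], critDiagHessian)
      module:accDiagHessianParameters(inputs[i], critDiagHessian)
   end
   diagHessianParameters:div(#inputs)              -- average the curvature estimate over the dataset
   diagHessianParameters:add(diagHessianEpsilon)   -- guard against near-zero curvature
   local learningRates = torch.Tensor(diagHessianParameters:size()):fill(1)
   learningRates:cdiv(diagHessianParameters)       -- rate_i = 1 / (h_i + eps)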
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 745f2e1..d0739ef 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -124,6 +124,7 @@ install_files(${INSTALL_PREFIX} SGDOptimization.lua)
install_files(${INSTALL_PREFIX} GeneticSGDOptimization.lua)
install_files(${INSTALL_PREFIX} BatchOptimization.lua)
install_files(${INSTALL_PREFIX} SNESOptimization.lua)
+install_files(${INSTALL_PREFIX} DiagHessian.lua)
install_files(${INSTALL_PREFIX} BatchTrainer.lua)
add_subdirectory (test)
install_targets(${CINSTALL_PREFIX} nnx)
diff --git a/DiagHessian.lua b/DiagHessian.lua
new file mode 100644
index 0000000..d163051
--- /dev/null
+++ b/DiagHessian.lua
@@ -0,0 +1,103 @@
+
+-- Module
+function nn.Module.backwardDiagHessian(self, input, diagHessianOutput)
+   self.diagHessianInput = self.diagHessianInput or self.output.new()
+   return self.diagHessianInput
+end
+
+function nn.Module.accDiagHessianParameters(self, input, diagHessianOutput, scale)
+end
+
+function nn.Module.initDiagHessianParameters(self)
+end
+
+-- Criterion
+function nn.Criterion.backwardDiagHessian(self, input, diagHessianOutput)
+   self.diagHessianInput = self.diagHessianInput or self.output.new()
+   return self.diagHessianInput
+end
+
+-- MSECriterion
+function nn.MSECriterion.backwardDiagHessian(self, input, diagHessianOutput)
+   self.diagHessianInput = self.diagHessianInput or input.new()
+   self.diagHessianInput:resizeAs(input):fill(1)
+   return self.diagHessianInput
+end
+
+-- Linear
+function nn.Linear.backwardDiagHessian(self, input, diagHessianOutput)
+   self.diagHessianInput = self.diagHessianInput or self.output.new()
+   self.weightSq = self.weightSq or self.output.new():resizeAs(self.weight)
+   self.weightSq:copy(self.weight):cmul(self.weightSq)
+   if input:dim() == 1 then
+      self.diagHessianInput:resizeAs(input)
+      self.diagHessianInput:addmv(0, 1, self.weightSq:t(), diagHessianOutput)
+   elseif input:dim() == 2 then
+      self.diagHessianInput:resizeAs(input)
+      self.diagHessianInput:addmm(0, 1, diagHessianOutput, self.weightSq)
+   end
+   return self.diagHessianInput
+end
+
+function nn.Linear.initDiagHessianParameters(self)
+   self.diagHessianWeight = self.diagHessianWeight or self.output.new():resizeAs(self.weight)
+   self.diagHessianBias = self.diagHessianBias or self.output.new():resizeAs(self.bias)
+end
+
+function nn.Linear.accDiagHessianParameters(self, input, diagHessianOutput, scale)
+   scale = scale or 1
+   self.inputSq = self.inputSq or self.output.new()
+   self.inputSq:resizeAs(input):copy(input):cmul(self.inputSq)
+   if input:dim() == 1 then
+      self.diagHessianWeight:addr(scale, diagHessianOutput, self.inputSq)
+      self.diagHessianBias:add(scale, diagHessianOutput)
+   elseif input:dim() == 2 then
+      local nframe = input:size(1)
+      local nunit = self.bias:size(1)
+      self.diagHessianWeight:addmm(scale, diagHessianOutput:t(), self.inputSq)
+      self.diagHessianBias:addmv(scale, diagHessianOutput:t(), self.output.new(nframe):fill(1))
+   end
+end
+
+-- Tanh
+function nn.Tanh.backwardDiagHessian(self, input, diagHessianOutput)
+   self.diagHessianInput = self.diagHessianInput or self.output.new()
+   self.derivativeSq = self.derivativeSq or self.output.new()
+   self.derivativeSq:resizeAs(self.output):copy(self.output):cmul(self.output):mul(-1):add(1)
+   self.derivativeSq:cmul(self.derivativeSq)
+   self.diagHessianInput:resizeAs(input):copy(diagHessianOutput):cmul(self.derivativeSq)
+   return self.diagHessianInput
+end
+
+-- Sequential
+function nn.Sequential.backwardDiagHessian(self, input, diagHessianOutput)
+   local currentDiagHessianOutput = diagHessianOutput
+   local currentModule = self.modules[#self.modules]
+   for i=#self.modules-1,1,-1 do
+      local previousModule = self.modules[i]
+      currentDiagHessianOutput = currentModule:backwardDiagHessian(previousModule.output, currentDiagHessianOutput)
+      currentModule = previousModule
+   end
+   currentDiagHessianOutput = currentModule:backwardDiagHessian(input, currentDiagHessianOutput)
+   self.diagHessianInput = currentDiagHessianOutput
+   return currentDiagHessianOutput
+end
+
+function nn.Sequential.initDiagHessianParameters(self)
+   for i=1,#self.modules do
+      self.modules[i]:initDiagHessianParameters()
+   end
+end
+
+function nn.Sequential.accDiagHessianParameters(self, input, diagHessianOutput, scale)
+   scale = scale or 1
+   local currentDiagHessianOutput = diagHessianOutput
+   local currentModule = self.modules[#self.modules]
+   for i=#self.modules-1,1,-1 do
+      local previousModule = self.modules[i]
+      currentModule:accDiagHessianParameters(previousModule.output, currentDiagHessianOutput, scale)
+      currentDiagHessianOutput = currentModule.diagHessianInput
+      currentModule = previousModule
+   end
+   currentModule:accDiagHessianParameters(input, currentDiagHessianOutput, scale)
+end
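The per-layer rules above implement a diagonal Gauss-Newton (Levenberg-Marquardt style) approximation: off-diagonal Hessian terms are dropped and each squared first derivative propagates the output curvature backwards. In LaTeX, for a linear layer $y = Wx + b$ the code computes

   \frac{\partial^2 E}{\partial x_j^2} \approx \sum_i W_{ij}^2 \, \frac{\partial^2 E}{\partial y_i^2}, \qquad
   \frac{\partial^2 E}{\partial W_{ij}^2} \approx x_j^2 \, \frac{\partial^2 E}{\partial y_i^2}, \qquad
   \frac{\partial^2 E}{\partial b_i^2} \approx \frac{\partial^2 E}{\partial y_i^2},

for Tanh ($y = \tanh(x)$, so $y' = 1 - y^2$)

   \frac{\partial^2 E}{\partial x^2} \approx (1 - y^2)^2 \, \frac{\partial^2 E}{\partial y^2},

and nn.MSECriterion seeds the recursion with a constant $\partial^2 E / \partial y_i^2 = 1$ per output, a fixed positive curvature for the squared error up to a constant factor.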
diff --git a/init.lua b/init.lua
--- a/init.lua
+++ b/init.lua
@@ -108,6 +108,7 @@ torch.include('nnx', 'SGDOptimization.lua')
torch.include('nnx', 'LBFGSOptimization.lua')
torch.include('nnx', 'CGOptimization.lua')
torch.include('nnx', 'GeneticSGDOptimization.lua')
+torch.include('nnx', 'DiagHessian.lua')

-- trainers:
torch.include('nnx', 'Trainer.lua')
@@ -196,6 +197,18 @@ function nnx.getGradParameters(...)
   return holder
end

+function nnx.getDiagHessianParameters(...)
+   -- to hold all parameters found
+   local holder = {}
+   -- call recursive call
+   local modules = {...}
+   for _,module in ipairs(modules) do
+      get(module, holder, {'diagHessianWeight', 'diagHessianBias'})
+   end
+   -- return all parameters found
+   return holder
+end
+
function nnx.flattenParameters(parameters)
   -- already flat ?
   local flat = true
diff --git a/test/test-hessian.lua b/test/test-hessian.lua
new file mode 100644
index 0000000..cf6e42f
--- /dev/null
+++ b/test/test-hessian.lua
@@ -0,0 +1,108 @@
+------------------------------------------------------------
+-- this simple script demonstrates the use of
+-- approximate second-order information to calibrate
+-- the learning rates individually
+--
+-- given an input vector X, we want to learn a mapping
+-- f(X) = \sum_i X_i
+--
+-- we use a two-layer perceptron, just to validate
+-- the tanh+linear hessian
+-- (of course learning such a function is much more
+-- trivial using a single linear layer :-)
+--
+
+-- libs
+require 'nnx'
+
+-- fix random seed
+random.manualSeed(1)
+
+-- SGD params
+learningRate = 1e-3
+diagHessianEpsilon = 1e-3
+computeDiagHessian = true -- SET THIS FLAG TO FALSE TO SEE THE EFFECT OF THE DIAG HESSIAN
+
+-- fake data
+inputs = {}
+targets = {}
+for i = 1,1000 do
+   inputs[i] = lab.randn(10)
+   targets[i] = torch.Tensor(1):fill(inputs[i]:sum())
+end
+
+-- create module
+module = nn.Sequential()
+module:add(nn.Linear(10,10))
+module:add(nn.Tanh())
+module:add(nn.Linear(10,1))
+
+-- loss
+criterion = nn.MSECriterion()
+
+-- get params
+parameters = nnx.flattenParameters(nnx.getParameters(module))
+gradParameters = nnx.flattenParameters(nnx.getGradParameters(module))
+
+-- compute learning rates
+learningRates = torch.Tensor(parameters:size()):fill(1)
+if computeDiagHessian then
+   -- init diag hessian
+   module:initDiagHessianParameters()
+   diagHessianParameters = nnx.flattenParameters(nnx.getDiagHessianParameters(module))
+
+   -- estimate diag hessian over dataset
+   diagHessianParameters:zero()
+   for i = 1,#inputs do
+      local output = module:forward(inputs[i])
+      local critDiagHessian = criterion:backwardDiagHessian(output, targets[i])
+      module:backwardDiagHessian(inputs[i], critDiagHessian)
+      module:accDiagHessianParameters(inputs[i], critDiagHessian)
+   end
+   diagHessianParameters:div(#inputs)
+
+   -- protect diag hessian (the proper way of doing it is the commented code,
+   -- but for speed reasons, the uncommented code just works)
+   --diagHessianParameters:apply(function(x) return math.max(x, diagHessianEpsilon) end)
+   diagHessianParameters:add(diagHessianEpsilon)
+
+   -- now learning rates are obtained like this:
+   learningRates:cdiv(diagHessianParameters)
+
+   -- print info
+   print('learning rates calculated to')
+   print(learningRates)
+end
+
+-- regular SGD
+for epoch = 1,100 do
+   error = 0
+   for i = 1,#inputs do
+      -- backprop gradients
+      local output = module:forward(inputs[i])
+      local critGradInput = criterion:backward(output, targets[i])
+      module:backward(inputs[i], critGradInput)
+
+      -- print current error
+      error = error + criterion:forward(output, targets[i])
+
+      -- gradients wrt parameters
+      gradParameters:zero()
+      module:accGradParameters(inputs[i], critGradInput)
+
+      -- given a parameter vector, and a gradParameter vector, the update goes like this:
+      deltaParameters = deltaParameters or parameters.new()
+      deltaParameters:resizeAs(gradParameters):copy(learningRates):cmul(gradParameters)
+      parameters:add(-learningRate, deltaParameters)
+   end
+   error = error / #inputs
+   print('current average error: ' .. error)
+end
+
+-- test vector
+input = lab.randn(10)
+groundtruth = input:sum()
+output = module:forward(input)
+print('test input:') print(input)
+print('predicted output:', output[1])
+print('groundtruth (\sum_i X_i):', groundtruth)
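The calibration performed by the test script amounts to giving each parameter $\theta_i$ its own step size derived from the dataset-averaged curvature estimate $\bar h_i$:

   \bar h_i = \frac{1}{N} \sum_{n=1}^{N} \frac{\partial^2 E_n}{\partial \theta_i^2}, \qquad
   \eta_i = \frac{\eta}{\bar h_i + \mu}, \qquad
   \theta_i \leftarrow \theta_i - \eta_i \, \frac{\partial E_n}{\partial \theta_i},

with $\eta$ = learningRate, $\mu$ = diagHessianEpsilon and $N$ = #inputs. Flat directions (small $\bar h_i$) take larger steps and sharp directions take smaller ones, which is the per-parameter calibration described in the script's header comment.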