diff options
-rw-r--r-- | Jacobian.lua | 131 | ||||
-rw-r--r-- | Module.lua | 129 | ||||
-rw-r--r-- | hessian.lua | 75 | ||||
-rw-r--r-- | test.lua | 83 |
4 files changed, 290 insertions, 128 deletions
diff --git a/Jacobian.lua b/Jacobian.lua index 25e8cf0..51ca139 100644 --- a/Jacobian.lua +++ b/Jacobian.lua @@ -92,6 +92,110 @@ function nn.Jacobian.forward(module, input, param, perturbation) return jacobian end +function nn.Jacobian.backwardDiagHessian(module, input, diagHessianParamName) + -- Compute the second derivatives (diagonal Hessian elements) + -- by backpropagation (using the code from hessian.lua). + -- + -- This function computes the diagonal Hessian elements of the following function: + -- + -- F(x_1, x_2, ..., x_n) = y_1^2/2 + y_2^2/2 + ... + y_m^2/2, + -- + -- where + -- x_1, ..., x_n are the input values and parameters of the given module, + -- y_1, ..., y_m are the output values of the given module. + -- + -- All x_i and y_i values are scalars here. In other words, + -- x_1, ..., x_n denote the scalar elements of the module input tensor, + -- the scalar elements of module.weight, + -- and the scalar elements of module.bias; + -- y_1, ..., y_m are the scalar elements of the module output tensor. + -- + -- The diagonal Hessian elements of F are computed with respect to + -- the module input values and parameters (x_1, .., x_n). + -- + -- The function F is chosen for its convenient properties: + -- + -- dF / dy_i = y_i, + -- d^2F / dy_i^2 = 1. + -- + -- In other words, the diagonal Hessian elements of F with respect + -- to the module OUTPUT values (y_1, ... y_m) are equal to 1. + -- + -- Because of that, computing the diagonal Hessian elements of F + -- with respect to the module INPUT values and PARAMETERS (x_1, ..., x_n) + -- can be done by calling updateDiagHessianInput() and accDiagHessianParameters() + -- using a tensor of ones as diagHessianOutput. + + module:forward(input) + local diagHessianOutput = module.output.new():resizeAs(module.output):fill(1) + + module.diagHessianWeight:zero() + module.diagHessianBias:zero() + module:updateDiagHessianInput(input, diagHessianOutput) + module:accDiagHessianParameters(input, diagHessianOutput) + + return module[diagHessianParamName] +end + +function nn.Jacobian.linearModuleDiagHessian(module, input, gradParamName) + -- Compute the second derivatives (diagonal Hessian elements) + -- from the first derivatives for the given module + -- (without using the code from hessian.lua). + -- + -- The given module is assumed to be linear with respect to its inputs and weights + -- (like nn.Linear, nn.SpatialConvolution, etc.) + -- + -- This function computes the diagonal Hessian elements of the following function: + -- + -- F(x_1, x_2, ..., x_n) = y_1^2/2 + y_2^2/2 + ... + y_m^2/2. + -- + -- (See the the comment for nn.Jacobian.backwardDiagHessian() for explanation.) + -- + -- The first derivatives of F with respect to + -- the module inputs and parameters (x_1, ..., x_n) are: + -- + -- dF / dx_i = \sum_k (dF / dy_k) (dy_k / dx_i). + -- + -- The second derivatives are: + -- + -- d^2F / dx_i = \sum_k [(d^2F / dy_k^2) (dy_k / dx_i)^2 + (dF / dy_k) (d^2y_k / dx_i^2)]. + -- + -- The second derivatives of F with respect to the module outputs (y_1, ..., y_m) + -- are equal to 1, so: + -- + -- d^2F / dx_i = \sum_k [(dy_k / dx_i)^2 + (dF / dy_k) (d^2y_k / dx_i^2)]. + -- + -- Assuming the linearity of module outputs (y_1, ..., y_m) + -- with respect to module inputs and parameters (x_1, ..., x_n), + -- we have (d^2y_k / dx_i^2) = 0, + -- and the expression finally becomes: + -- + -- d^2F / dx_i = \sum_k (dy_k / dx_i)^2. + -- + -- The first derivatives (dy_k / dx_i) are computed by normal backpropagation, + -- using updateGradInput() and accGradParameters(). + + local gradParam = module[gradParamName] + + local diagHessian = gradParam.new():resize(gradParam:nElement()):zero() + + module:forward(input) + local gradOutput = module.output.new():resizeAs(module.output) + local gradOutput1D = gradOutput:view(gradOutput:nElement()) + + for i=1,gradOutput:nElement() do + gradOutput1D:zero() + gradOutput1D[i] = 1 + module.gradWeight:zero() + module.gradBias:zero() + module:updateGradInput(input, gradOutput) + module:accGradParameters(input, gradOutput) + diagHessian:addcmul(gradParam, gradParam) + end + + return diagHessian +end + function nn.Jacobian.forwardUpdate(module, input, param, perturbation) -- perturbation amount perturbation = perturbation or 1e-6 @@ -156,6 +260,33 @@ function nn.Jacobian.testJacobianUpdateParameters(module, input, param, minval, return error:abs():max() end +function nn.Jacobian.testDiagHessian(module, input, gradParamName, diagHessianParamName, minval, maxval) + -- Compute the diagonal Hessian elements for the same function in two different ways, + -- then compare the results and return the difference. + + minval = minval or -2 + maxval = maxval or 2 + local inrange = maxval - minval + input:copy(torch.rand(input:nElement()):mul(inrange):add(minval)) + module:initDiagHessianParameters() + local h_bprop = nn.Jacobian.backwardDiagHessian(module, input, diagHessianParamName) + local h_linearmodule = nn.Jacobian.linearModuleDiagHessian(module, input, gradParamName) + local error = h_bprop - h_linearmodule + return error:abs():max() +end + +function nn.Jacobian.testDiagHessianInput(module, input, minval, maxval) + return nn.Jacobian.testDiagHessian(module, input, 'gradInput', 'diagHessianInput', minval, maxval) +end + +function nn.Jacobian.testDiagHessianWeight(module, input, minval, maxval) + return nn.Jacobian.testDiagHessian(module, input, 'gradWeight', 'diagHessianWeight', minval, maxval) +end + +function nn.Jacobian.testDiagHessianBias(module, input, minval, maxval) + return nn.Jacobian.testDiagHessian(module, input, 'gradBias', 'diagHessianBias', minval, maxval) +end + function nn.Jacobian.testIO(module,input, minval, maxval) minval = minval or -2 maxval = maxval or 2 @@ -137,92 +137,91 @@ end function Module:reset() end -function Module:getParameters() - -- get parameters - local parameters,gradParameters = self:parameters() - +-- this function flattens arbitrary lists of parameters, +-- even complex shared ones +function Module.flatten(parameters) local function storageInSet(set, storage) local storageAndOffset = set[torch.pointer(storage)] if storageAndOffset == nil then - return nil + return nil end local _, offset = table.unpack(storageAndOffset) return offset end - -- this function flattens arbitrary lists of parameters, - -- even complex shared ones - local function flatten(parameters) - if not parameters or #parameters == 0 then - return torch.Tensor() + if not parameters or #parameters == 0 then + return torch.Tensor() + end + local Tensor = parameters[1].new + local dtype = parameters[1]:type() + + local storages = {} + local nParameters = 0 + for k = 1,#parameters do + if parameters[k]:type() ~= dtype then + error("Inconsistent parameter types. " .. parameters[k]:type() .. + " ~= " .. dtype) end - local Tensor = parameters[1].new - local dtype = parameters[1]:type() - - local storages = {} - local nParameters = 0 - for k = 1,#parameters do - if parameters[k]:type() ~= dtype then - error("Inconsistent parameter types. " .. parameters[k]:type() .. - " ~= " .. dtype) - end - local storage = parameters[k]:storage() - if not storageInSet(storages, storage) then - storages[torch.pointer(storage)] = {storage, nParameters} - nParameters = nParameters + storage:size() - end + local storage = parameters[k]:storage() + if not storageInSet(storages, storage) then + storages[torch.pointer(storage)] = {storage, nParameters} + nParameters = nParameters + storage:size() end + end - local flatParameters = Tensor(nParameters):fill(1) - local flatStorage = flatParameters:storage() + local flatParameters = Tensor(nParameters):fill(1) + local flatStorage = flatParameters:storage() - for k = 1,#parameters do - local storageOffset = storageInSet(storages, parameters[k]:storage()) - parameters[k]:set(flatStorage, - storageOffset + parameters[k]:storageOffset(), - parameters[k]:size(), - parameters[k]:stride()) - parameters[k]:zero() - end + for k = 1,#parameters do + local storageOffset = storageInSet(storages, parameters[k]:storage()) + parameters[k]:set(flatStorage, + storageOffset + parameters[k]:storageOffset(), + parameters[k]:size(), + parameters[k]:stride()) + parameters[k]:zero() + end - local maskParameters = flatParameters:float():clone() - local cumSumOfHoles = flatParameters:float():cumsum(1) - local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] - local flatUsedParameters = Tensor(nUsedParameters) - local flatUsedStorage = flatUsedParameters:storage() - - for k = 1,#parameters do - local offset = cumSumOfHoles[parameters[k]:storageOffset()] - parameters[k]:set(flatUsedStorage, - parameters[k]:storageOffset() - offset, - parameters[k]:size(), - parameters[k]:stride()) - end + local maskParameters = flatParameters:float():clone() + local cumSumOfHoles = flatParameters:float():cumsum(1) + local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] + local flatUsedParameters = Tensor(nUsedParameters) + local flatUsedStorage = flatUsedParameters:storage() + + for k = 1,#parameters do + local offset = cumSumOfHoles[parameters[k]:storageOffset()] + parameters[k]:set(flatUsedStorage, + parameters[k]:storageOffset() - offset, + parameters[k]:size(), + parameters[k]:stride()) + end - for _, storageAndOffset in pairs(storages) do - local k, v = table.unpack(storageAndOffset) - flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) - end + for _, storageAndOffset in pairs(storages) do + local k, v = table.unpack(storageAndOffset) + flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) + end - if cumSumOfHoles:sum() == 0 then - flatUsedParameters:copy(flatParameters) - else - local counter = 0 - for k = 1,flatParameters:nElement() do - if maskParameters[k] == 0 then - counter = counter + 1 - flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]] - end + if cumSumOfHoles:sum() == 0 then + flatUsedParameters:copy(flatParameters) + else + local counter = 0 + for k = 1,flatParameters:nElement() do + if maskParameters[k] == 0 then + counter = counter + 1 + flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]] end - assert (counter == nUsedParameters) end - return flatUsedParameters + assert (counter == nUsedParameters) end + return flatUsedParameters +end +function Module:getParameters() + -- get parameters + local parameters,gradParameters = self:parameters() -- flatten parameters and gradients - local flatParameters = flatten(parameters) + local flatParameters = Module.flatten(parameters) collectgarbage() - local flatGradParameters = flatten(gradParameters) + local flatGradParameters = Module.flatten(gradParameters) collectgarbage() -- return new flat vector that contains all discrete parameters diff --git a/hessian.lua b/hessian.lua index d63c6a8..ce35749 100644 --- a/hessian.lua +++ b/hessian.lua @@ -31,6 +31,9 @@ function nn.hessian.enable() module[gwname] = hwval module[hwname] = gwval end + local oldOutput = module.output + module.output = module.output.new():resizeAs(oldOutput) + module.forward(module, module.inputSq) module.accGradParameters(module, module.inputSq, diagHessianOutput, 1) -- put back gradients for i=1,#gw do @@ -41,6 +44,7 @@ function nn.hessian.enable() module[gwname] = hwval module[hwname] = gwval end + module.output = oldOutput end nn.hessian.accDiagHessianParameters = accDiagHessianParameters @@ -210,7 +214,7 @@ function nn.hessian.enable() end function nn.SpatialConvolution.initDiagHessianParameters(self) - initDiagHessianParameters(self,{'gradWeight'},{'diagHessianWeight'}) + initDiagHessianParameters(self,{'gradWeight','gradBias'},{'diagHessianWeight','diagHessianBias'}) end ---------------------------------------------------------------------- @@ -222,7 +226,7 @@ function nn.hessian.enable() end function nn.SpatialFullConvolution.accDiagHessianParameters(self, input, diagHessianOutput) - accDiagHessianParameters(self,input, diagHessianOutput, {'gradWeight'}, {'diagHessianWeight'}) + accDiagHessianParameters(self,input, diagHessianOutput, {'gradWeight','gradBias'}, {'diagHessianWeight','diagHessianBias'}) end function nn.SpatialFullConvolution.initDiagHessianParameters(self) @@ -324,70 +328,15 @@ function nn.hessian.enable() function nn.Module.getParameters(self) -- get parameters local parameters,gradParameters,hessianParameters = self:parameters() - - local function storageInSet(set, storage) - local storageAndOffset = set[torch.pointer(storage)] - if storageAndOffset == nil then - return nil - end - local _, offset = table.unpack(storageAndOffset) - return offset - end - - -- this function flattens arbitrary lists of parameters, - -- even complex shared ones - local function flatten(parameters) - local storages = {} - local nParameters = 0 - for k = 1,#parameters do - local storage = parameters[k]:storage() - if not storageInSet(storages, storage) then - storages[torch.pointer(storage)] = {storage, nParameters} - nParameters = nParameters + storage:size() - end - end - - local flatParameters = torch.Tensor(nParameters):fill(1) - local flatStorage = flatParameters:storage() - - for k = 1,#parameters do - local storageOffset = storageInSet(storages, parameters[k]:storage()) - parameters[k]:set(flatStorage, - storageOffset + parameters[k]:storageOffset(), - parameters[k]:size(), - parameters[k]:stride()) - parameters[k]:zero() - end - - local cumSumOfHoles = flatParameters:cumsum(1) - local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] - local flatUsedParameters = torch.Tensor(nUsedParameters) - local flatUsedStorage = flatUsedParameters:storage() - - for k = 1,#parameters do - local offset = cumSumOfHoles[parameters[k]:storageOffset()] - parameters[k]:set(flatUsedStorage, - parameters[k]:storageOffset() - offset, - parameters[k]:size(), - parameters[k]:stride()) - end - - for _, storageAndOffset in pairs(storages) do - local k, v = table.unpack(storageAndOffset) - flatParameters[{{v+1,v+k:size()}}]:copy(torch.Tensor():set(k)) - end - for k = 1,flatUsedParameters:nElement() do - flatUsedParameters[k] = flatParameters[k+cumSumOfHoles[k] ] - end - return flatUsedParameters - end - -- flatten parameters and gradients - local flatParameters = flatten(parameters) - local flatGradParameters = flatten(gradParameters) + local flatParameters = nn.Module.flatten(parameters) + collectgarbage() + local flatGradParameters = nn.Module.flatten(gradParameters) + collectgarbage() local flatHessianParameters if hessianParameters and hessianParameters[1] then - flatHessianParameters = flatten(hessianParameters) + flatHessianParameters = nn.Module.flatten(hessianParameters) + collectgarbage() end -- return new flat vector that contains all discrete parameters @@ -2,6 +2,8 @@ -- th -lnn -e "nn.test{'LookupTable'}" -- th -lnn -e "nn.test{'LookupTable', 'Add'}" +nn.hessian.enable() + local mytester = torch.Tester() local jac local sjac @@ -484,6 +486,15 @@ function nntest.Linear() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err,precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diagHessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -513,6 +524,15 @@ function nntest.Linear() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err,precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1384,6 +1404,15 @@ function nntest.SpatialConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1424,6 +1453,15 @@ function nntest.SpatialConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'batch error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1555,6 +1593,15 @@ function nntest.SpatialConvolutionMap() local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias) mytester:assertlt(err , precision, 'error on bias ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1593,6 +1640,15 @@ function nntest.SpatialConvolutionMap() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'batch error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1637,6 +1693,15 @@ function nntest.SpatialFullConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1669,6 +1734,15 @@ function nntest.SpatialFullConvolution() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'batch error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) @@ -1714,6 +1788,15 @@ function nntest.SpatialFullConvolutionMap() local err = jac.testJacobianUpdateParameters(module, input, module.bias) mytester:assertlt(err , precision, 'error on bias [direct update] ') + local err = jac.testDiagHessianInput(module, input) + mytester:assertlt(err , precision, 'error on diagHessianInput') + + local err = jac.testDiagHessianWeight(module, input) + mytester:assertlt(err , precision, 'error on diagHessianWeight') + + local err = jac.testDiagHessianBias(module, input) + mytester:assertlt(err , precision, 'error on diag HessianBias') + for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do mytester:assertlt(err, precision, string.format( 'error on weight [%s]', t)) |