
github.com/torch/nn.git
author     Andrey Golovizin <ag@sologoc.com>    2015-07-01 11:27:41 +0300
committer  Andrey Golovizin <ag@sologoc.com>    2015-07-10 00:20:11 +0300
commit     a53cb3cce78cb218eee7fe2f3c67a8d0fc411dcc (patch)
tree       e488da6d0fbf97d47a633c9bc2c3861e4ec6e5d7 /Module.lua
parent     b7aa53d96fbb6c0f2eaa1976b28c5cf12edf1ced (diff)
Add unit tests for hessian.lua, fix bugs detected by the tests.
* Fix initialization of diagHessianBias for nn.SpatialConvolution.
* Fix computing diagHessianBias for nn.SpatialFullConvolution.
* Call module:forward() with the proper input before calling accGradParameters(). Without that, accDiagHessianParameters() produces incorrect results for some convolution classes.
* Move duplicate code from Module.getParameters() to Module.flatten(), which is now used by both the original Module.getParameters() in Module.lua and the replacement Module.getParameters() in hessian.lua.
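As the diff below shows, the new Module.flatten() takes a plain list of tensors and returns one flat vector backed by their deduplicated storages. A minimal sketch of the shared code path, assuming a Torch7 environment with nn installed (the choice of nn.Linear here is purely illustrative):

require 'nn'

local m = nn.Linear(4, 3)
local params, gradParams = m:parameters()  -- lists of weight/bias tensors

-- flatten each list; the module's tensors are re-set in place to view
-- slices of the returned flat vector
local flatParams = nn.Module.flatten(params)
collectgarbage()
local flatGradParams = nn.Module.flatten(gradParams)
collectgarbage()

print(flatParams:nElement())  -- 4*3 weights + 3 biases = 15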
Diffstat (limited to 'Module.lua')
-rw-r--r--  Module.lua  | 129
1 file changed, 64 insertions(+), 65 deletions(-)
diff --git a/Module.lua b/Module.lua
index c2075b8..f5029e9 100644
--- a/Module.lua
+++ b/Module.lua
@@ -137,92 +137,91 @@ end
function Module:reset()
end
-function Module:getParameters()
- -- get parameters
- local parameters,gradParameters = self:parameters()
-
+-- this function flattens arbitrary lists of parameters,
+-- even complex shared ones
+function Module.flatten(parameters)
local function storageInSet(set, storage)
local storageAndOffset = set[torch.pointer(storage)]
if storageAndOffset == nil then
- return nil
+ return nil
end
local _, offset = table.unpack(storageAndOffset)
return offset
end
- -- this function flattens arbitrary lists of parameters,
- -- even complex shared ones
- local function flatten(parameters)
- if not parameters or #parameters == 0 then
- return torch.Tensor()
+ if not parameters or #parameters == 0 then
+ return torch.Tensor()
+ end
+ local Tensor = parameters[1].new
+ local dtype = parameters[1]:type()
+
+ local storages = {}
+ local nParameters = 0
+ for k = 1,#parameters do
+ if parameters[k]:type() ~= dtype then
+ error("Inconsistent parameter types. " .. parameters[k]:type() ..
+ " ~= " .. dtype)
end
- local Tensor = parameters[1].new
- local dtype = parameters[1]:type()
-
- local storages = {}
- local nParameters = 0
- for k = 1,#parameters do
- if parameters[k]:type() ~= dtype then
- error("Inconsistent parameter types. " .. parameters[k]:type() ..
- " ~= " .. dtype)
- end
- local storage = parameters[k]:storage()
- if not storageInSet(storages, storage) then
- storages[torch.pointer(storage)] = {storage, nParameters}
- nParameters = nParameters + storage:size()
- end
+ local storage = parameters[k]:storage()
+ if not storageInSet(storages, storage) then
+ storages[torch.pointer(storage)] = {storage, nParameters}
+ nParameters = nParameters + storage:size()
end
+ end
- local flatParameters = Tensor(nParameters):fill(1)
- local flatStorage = flatParameters:storage()
+ local flatParameters = Tensor(nParameters):fill(1)
+ local flatStorage = flatParameters:storage()
- for k = 1,#parameters do
- local storageOffset = storageInSet(storages, parameters[k]:storage())
- parameters[k]:set(flatStorage,
- storageOffset + parameters[k]:storageOffset(),
- parameters[k]:size(),
- parameters[k]:stride())
- parameters[k]:zero()
- end
+ for k = 1,#parameters do
+ local storageOffset = storageInSet(storages, parameters[k]:storage())
+ parameters[k]:set(flatStorage,
+ storageOffset + parameters[k]:storageOffset(),
+ parameters[k]:size(),
+ parameters[k]:stride())
+ parameters[k]:zero()
+ end
- local maskParameters = flatParameters:float():clone()
- local cumSumOfHoles = flatParameters:float():cumsum(1)
- local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
- local flatUsedParameters = Tensor(nUsedParameters)
- local flatUsedStorage = flatUsedParameters:storage()
-
- for k = 1,#parameters do
- local offset = cumSumOfHoles[parameters[k]:storageOffset()]
- parameters[k]:set(flatUsedStorage,
- parameters[k]:storageOffset() - offset,
- parameters[k]:size(),
- parameters[k]:stride())
- end
+ local maskParameters = flatParameters:float():clone()
+ local cumSumOfHoles = flatParameters:float():cumsum(1)
+ local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
+ local flatUsedParameters = Tensor(nUsedParameters)
+ local flatUsedStorage = flatUsedParameters:storage()
+
+ for k = 1,#parameters do
+ local offset = cumSumOfHoles[parameters[k]:storageOffset()]
+ parameters[k]:set(flatUsedStorage,
+ parameters[k]:storageOffset() - offset,
+ parameters[k]:size(),
+ parameters[k]:stride())
+ end
- for _, storageAndOffset in pairs(storages) do
- local k, v = table.unpack(storageAndOffset)
- flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k))
- end
+ for _, storageAndOffset in pairs(storages) do
+ local k, v = table.unpack(storageAndOffset)
+ flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k))
+ end
- if cumSumOfHoles:sum() == 0 then
- flatUsedParameters:copy(flatParameters)
- else
- local counter = 0
- for k = 1,flatParameters:nElement() do
- if maskParameters[k] == 0 then
- counter = counter + 1
- flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]]
- end
+ if cumSumOfHoles:sum() == 0 then
+ flatUsedParameters:copy(flatParameters)
+ else
+ local counter = 0
+ for k = 1,flatParameters:nElement() do
+ if maskParameters[k] == 0 then
+ counter = counter + 1
+ flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]]
end
- assert (counter == nUsedParameters)
end
- return flatUsedParameters
+ assert (counter == nUsedParameters)
end
+ return flatUsedParameters
+end
+function Module:getParameters()
+ -- get parameters
+ local parameters,gradParameters = self:parameters()
-- flatten parameters and gradients
- local flatParameters = flatten(parameters)
+ local flatParameters = Module.flatten(parameters)
collectgarbage()
- local flatGradParameters = flatten(gradParameters)
+ local flatGradParameters = Module.flatten(gradParameters)
collectgarbage()
-- return new flat vector that contains all discrete parameters
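The storage bookkeeping in flatten() exists so that shared parameters end up in the flat vector exactly once. A small illustrative check, again assuming Torch7 with nn (share() is the standard nn.Module parameter-sharing API):

require 'nn'

local a = nn.Linear(10, 10)
local b = nn.Linear(10, 10)
b:share(a, 'weight', 'bias', 'gradWeight', 'gradBias')

local net = nn.Sequential():add(a):add(b)
local params = net:getParameters()

-- shared storages are counted once: 10*10 weights + 10 biases
print(params:nElement())  -- 110, not 220

-- writing through the flat vector updates both modules
params:fill(0.5)
assert(a.weight[1][1] == 0.5 and b.weight[1][1] == 0.5)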