author     Clement Farabet <clement.farabet@gmail.com>    2012-09-21 23:33:25 +0400
committer  Clement Farabet <clement.farabet@gmail.com>    2012-09-21 23:33:25 +0400
commit     c1cb8b84d6f7a9ab3d76dee256e79667616fa2ac
tree       630e9d93cc45bf56da84917744c2ba512b730577
parent     c4ec7d3e275074df635b78e2c1558931c5931919
Added an extra corner case to getParameters().
This might finally fix all the possible corner cases.
Fix by Michael Matthieu.
-rw-r--r--   Module.lua      | 26
-rw-r--r--   hessian.lua     | 26
-rw-r--r--   test/test.lua   | 42
3 files changed, 80 insertions, 14 deletions
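The corner case at issue: a module can already point into a flat storage created by an earlier getParameters() call and then be reused next to freshly allocated modules, leaving unused "holes" in the combined storage. Below is a minimal sketch of roughly the scenario exercised by the new Module_getParameters_7 test in this commit (assumes a working Torch7 install with the nn package; the variable names are illustrative):

    require 'nn'

    -- flatten a small network once; its weights now live in a shared flat storage
    local n1 = nn.Sequential()
    n1:add( nn.Linear(10,10) )
    n1:getParameters()

    -- reuse n1 next to a brand-new layer and flatten again: the storage inherited
    -- from the first call contains unused elements that must be compacted away
    local nf = nn.Sequential()
    nf:add( n1 )
    nf:add( nn.Linear(10,1) )
    local p = nf:getParameters()

    print(p:nElement())   -- expected 10*10 + 10 + 10*1 + 1 = 121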
diff --git a/Module.lua b/Module.lua
--- a/Module.lua
+++ b/Module.lua
@@ -170,17 +170,29 @@ function Module:getParameters()
                           parameters[k]:stride())
          parameters[k]:zero()
       end
-      if (flatParameters:sum() ~= 0) then
-         print("<getParameters()> WARNING: found "
-               .. flatParameters:sum() .. " holes in the parameters vector (i.e. "
-               .. flatParameters:sum() .. " storage elements that are unused, this "
-               .. "might be an issue for your optimization procedure)")
+
+      local cumSumOfHoles = flatParameters:cumsum(1)
+      local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
+      local flatUsedParameters = torch.Tensor(nUsedParameters)
+      local flatUsedStorage = flatUsedParameters:storage()
+
+      for k = 1,#parameters do
+         local offset = cumSumOfHoles[parameters[k]:storageOffset()]
+         parameters[k]:set(flatUsedStorage,
+                           parameters[k]:storageOffset() - offset,
+                           parameters[k]:size(),
+                           parameters[k]:stride())
       end
-      for k, v in pairs(storages) do
+      for k, v in pairs(storages) do -- we could remove this loop
          flatParameters[{{v+1,v+k:size()}}]:copy(torch.Tensor():set(k))
       end
-      return flatParameters
+      print('crap')
+      for k = 1,flatUsedParameters:nElement() do
+         flatUsedParameters[k] = flatParameters[k+cumSumOfHoles[k] ]
+      end
+      print('0')
+      return flatUsedParameters
    end

    -- flatten parameters and gradients
diff --git a/hessian.lua b/hessian.lua
index 422e1d9..9c79b30 100644
--- a/hessian.lua
+++ b/hessian.lua
@@ -296,17 +296,29 @@ function nn.hessian.enable()
                           parameters[k]:stride())
          parameters[k]:zero()
       end
-      if (flatParameters:sum() ~= 0) then
-         print("<getParameters()> WARNING: found "
-               .. flatParameters:sum() .. " holes in the parameters vector (i.e. "
-               .. flatParameters:sum() .. " storage elements that are unused, this "
-               .. "might be an issue for your optimization procedure)")
+
+      local cumSumOfHoles = flatParameters:cumsum(1)
+      local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
+      local flatUsedParameters = torch.Tensor(nUsedParameters)
+      local flatUsedStorage = flatUsedParameters:storage()
+
+      for k = 1,#parameters do
+         local offset = cumSumOfHoles[parameters[k]:storageOffset()]
+         parameters[k]:set(flatUsedStorage,
+                           parameters[k]:storageOffset() - offset,
+                           parameters[k]:size(),
+                           parameters[k]:stride())
       end
-      for k, v in pairs(storages) do
+      for k, v in pairs(storages) do -- we could remove this loop
          flatParameters[{{v+1,v+k:size()}}]:copy(torch.Tensor():set(k))
       end
-      return flatParameters
+      print('crap')
+      for k = 1,flatUsedParameters:nElement() do
+         flatUsedParameters[k] = flatParameters[k+cumSumOfHoles[k] ]
+      end
+      print('0')
+      return flatUsedParameters
    end

    -- flatten parameters and gradients
diff --git a/test/test.lua b/test/test.lua
index 0d4a979..583deb2 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1137,6 +1137,8 @@ function nntest.Module_getParameters_4()
    mytester:asserteq((p[{ {221,320} }] - n.modules[3].weight):norm(), 0, 'error when using cloning')
    mytester:asserteq((p[{ {321,330} }] - n.modules[3].bias):norm(), 0, 'error when using cloning')
+
+   mytester:asserteq(p:nElement(), 3*(10*10+10), 'error: incorrect number of elements in flat vector')
 end

 function nntest.Module_getParameters_5()
@@ -1155,6 +1157,8 @@ function nntest.Module_getParameters_5()
    mytester:asserteq((p[{ {1,100} }] - n.modules[2].weight):norm(), 0, 'error when using cloning+sharing')
    mytester:asserteq((p[{ {101,110} }] - n.modules[2].bias):norm(), 0, 'error when using cloning+sharing')
+
+   mytester:asserteq(p:nElement(), (10*10+10), 'error: incorrect number of elements in flat vector')
 end

 function nntest.Module_getParameters_6()
@@ -1174,6 +1178,44 @@ function nntest.Module_getParameters_6()
    mytester:asserteq((p[{ {111,210} }] - n.modules[3].weight):norm(), 0, 'error when using cloning+sharing')
    mytester:asserteq((p[{ {211,220} }] - n.modules[3].bias):norm(), 0, 'error when using cloning+sharing')
+
+   mytester:asserteq(p:nElement(), 2*(10*10+10), 'error: incorrect number of elements in flat vector')
+end
+
+function nntest.Module_getParameters_7()
+   local n = nn.Sequential()
+   n:add( nn.Linear(10,10) )
+   n:add( n.modules[1]:clone('weight','bias') )
+   local p = n:getParameters()
+
+   n:add(nn.Linear(10,10))
+   p = n:getParameters()
+
+   local n1 = nn.Sequential()
+   n1:add( nn.Linear(10,10) )
+
+   local n2 = nn.Sequential()
+   n2:add( nn.Linear(10,10) )
+
+   local n = nn.Sequential()
+   n:add( n1 )
+   n:add( n2 )
+
+   local p = n:getParameters()
+
+   local nf = nn.Sequential()
+   nf:add( n1 )
+   nf:add( nn.Linear(10,1) )
+
+   local p = nf:getParameters()
+
+   mytester:asserteq((p[{ {1,100} }] - n1.modules[1].weight):norm(), 0, 'error when using cloning+partial realloc')
+   mytester:asserteq((p[{ {101,110} }] - n1.modules[1].bias):norm(), 0, 'error when using cloning+partial realloc')
+
+   mytester:asserteq((p[{ {111,120} }] - nf.modules[2].weight):norm(), 0, 'error when using cloning+partial realloc')
+   mytester:asserteq((p[{ {121,121} }] - nf.modules[2].bias):norm(), 0, 'error when using cloning+partial realloc')
+
+   mytester:asserteq(p:nElement(), 121, 'error: incorrect number of elements in flat vector')
 end

 mytester:add(nntest)
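For readers unfamiliar with the compaction step in the diff above: flatParameters is used as a 0/1 mask in which 1 marks a storage element no parameter tensor uses (a "hole"), so its running sum gives, at each position, how many holes precede it, which is exactly how far a used element must shift left in the compacted vector. The snippet below is a small standalone illustration of that idea, not the library code itself; the tensors and values are made up for the example:

    require 'torch'

    -- 1 marks an unused ("hole") element of the old flat storage, 0 a used one
    local holes = torch.Tensor({0, 0, 1, 0, 1, 1, 0, 0})
    local old   = torch.Tensor({10, 20, 0, 30, 0, 0, 40, 50})

    local cumSumOfHoles = holes:cumsum(1)                      -- 0 0 1 1 2 3 3 3
    local nUsed = old:nElement() - cumSumOfHoles[old:nElement()]
    local compacted = torch.Tensor(nUsed)

    for i = 1, old:nElement() do
       if holes[i] == 0 then
          -- a used element at old position i lands at i minus the holes before it
          compacted[i - cumSumOfHoles[i]] = old[i]
       end
    end

    print(compacted)   -- 10 20 30 40 50

The same offset arithmetic is what the patch applies to each parameter tensor: each one is re-set() onto the new, smaller storage at its old storageOffset() minus the number of holes that precede it.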