github.com/torch/nn.git
author     Clement Farabet <clement.farabet@gmail.com>  2012-09-21 23:33:25 +0400
committer  Clement Farabet <clement.farabet@gmail.com>  2012-09-21 23:33:25 +0400
commit     c1cb8b84d6f7a9ab3d76dee256e79667616fa2ac (patch)
tree       630e9d93cc45bf56da84917744c2ba512b730577
parent     c4ec7d3e275074df635b78e2c1558931c5931919 (diff)
Added an extra corner case to getParameters().
This might finally fix all the possible corner cases. Fix by Michael Matthieu.
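
The corner case in question, as exercised by the new Module_getParameters_7 test below, arises when a module that has already been flattened by one getParameters() call is reused inside another container and flattened again: part of the old flat storage becomes unused, leaving "holes" that the new code compacts away. A minimal sketch of that scenario (it mirrors the new test; the exact modules are illustrative):

require 'nn'

local shared = nn.Linear(10, 10)

local n1 = nn.Sequential()
n1:add(shared)
n1:getParameters()                 -- first flattening: shared now views a flat storage

local nf = nn.Sequential()
nf:add(shared)                     -- reuse the already-flattened module
nf:add(nn.Linear(10, 1))
local p = nf:getParameters()       -- second flattening must not leave unused slots

assert(p:nElement() == (10*10 + 10) + (10*1 + 1))   -- 121 used parameters, no holes
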
-rw-r--r--  Module.lua     | 26
-rw-r--r--  hessian.lua    | 26
-rw-r--r--  test/test.lua  | 42
3 files changed, 80 insertions, 14 deletions
diff --git a/Module.lua b/Module.lua
index a296db7..9df96c8 100644
--- a/Module.lua
+++ b/Module.lua
@@ -170,17 +170,29 @@ function Module:getParameters()
parameters[k]:stride())
parameters[k]:zero()
end
- if (flatParameters:sum() ~= 0) then
- print("<getParameters()> WARNING: found "
- .. flatParameters:sum() .. " holes in the parameters vector (i.e. "
- .. flatParameters:sum() .. " storage elements that are unused, this "
- .. "might be an issue for your optimization procedure)")
+
+ local cumSumOfHoles = flatParameters:cumsum(1)
+ local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
+ local flatUsedParameters = torch.Tensor(nUsedParameters)
+ local flatUsedStorage = flatUsedParameters:storage()
+
+ for k = 1,#parameters do
+ local offset = cumSumOfHoles[parameters[k]:storageOffset()]
+ parameters[k]:set(flatUsedStorage,
+ parameters[k]:storageOffset() - offset,
+ parameters[k]:size(),
+ parameters[k]:stride())
end
- for k, v in pairs(storages) do
+ for k, v in pairs(storages) do -- we could remove this loop
flatParameters[{{v+1,v+k:size()}}]:copy(torch.Tensor():set(k))
end
- return flatParameters
+ print('crap')
+ for k = 1,flatUsedParameters:nElement() do
+ flatUsedParameters[k] = flatParameters[k+cumSumOfHoles[k] ]
+ end
+ print('0')
+ return flatUsedParameters
end
-- flatten parameters and gradients
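
For reference, a standalone sketch of the cumulative-sum compaction idea behind the change above (illustrative values, not the patch code): a 0/1 mask marks unused "hole" slots, its cumulative sum tells how many holes precede each position, and every used element shifts left by that amount.

require 'torch'

local flat  = torch.Tensor{10, 0, 20, 30, 0, 40}   -- hypothetical flat storage with holes
local holes = torch.Tensor{ 0, 1,  0,  0, 1,  0}   -- 1 marks an unused storage element

local cumHoles = holes:cumsum(1)
local nUsed    = flat:nElement() - cumHoles[flat:nElement()]
local compact  = torch.Tensor(nUsed)

for i = 1, flat:nElement() do
   if holes[i] == 0 then
      -- each used element moves left by the number of holes before it
      compact[i - cumHoles[i]] = flat[i]
   end
end
-- compact now holds {10, 20, 30, 40}
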
diff --git a/hessian.lua b/hessian.lua
index 422e1d9..9c79b30 100644
--- a/hessian.lua
+++ b/hessian.lua
@@ -296,17 +296,29 @@ function nn.hessian.enable()
parameters[k]:stride())
parameters[k]:zero()
end
- if (flatParameters:sum() ~= 0) then
- print("<getParameters()> WARNING: found "
- .. flatParameters:sum() .. " holes in the parameters vector (i.e. "
- .. flatParameters:sum() .. " storage elements that are unused, this "
- .. "might be an issue for your optimization procedure)")
+
+ local cumSumOfHoles = flatParameters:cumsum(1)
+ local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
+ local flatUsedParameters = torch.Tensor(nUsedParameters)
+ local flatUsedStorage = flatUsedParameters:storage()
+
+ for k = 1,#parameters do
+ local offset = cumSumOfHoles[parameters[k]:storageOffset()]
+ parameters[k]:set(flatUsedStorage,
+ parameters[k]:storageOffset() - offset,
+ parameters[k]:size(),
+ parameters[k]:stride())
end
- for k, v in pairs(storages) do
+ for k, v in pairs(storages) do -- we could remove this loop
flatParameters[{{v+1,v+k:size()}}]:copy(torch.Tensor():set(k))
end
- return flatParameters
+ print('crap')
+ for k = 1,flatUsedParameters:nElement() do
+ flatUsedParameters[k] = flatParameters[k+cumSumOfHoles[k] ]
+ end
+ print('0')
+ return flatUsedParameters
end
-- flatten parameters and gradients
diff --git a/test/test.lua b/test/test.lua
index 0d4a979..583deb2 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1137,6 +1137,8 @@ function nntest.Module_getParameters_4()
mytester:asserteq((p[{ {221,320} }] - n.modules[3].weight):norm(), 0, 'error when using cloning')
mytester:asserteq((p[{ {321,330} }] - n.modules[3].bias):norm(), 0, 'error when using cloning')
+
+ mytester:asserteq(p:nElement(), 3*(10*10+10), 'error: incorrect number of elements in flat vector')
end
function nntest.Module_getParameters_5()
@@ -1155,6 +1157,8 @@ function nntest.Module_getParameters_5()
mytester:asserteq((p[{ {1,100} }] - n.modules[2].weight):norm(), 0, 'error when using cloning+sharing')
mytester:asserteq((p[{ {101,110} }] - n.modules[2].bias):norm(), 0, 'error when using cloning+sharing')
+
+ mytester:asserteq(p:nElement(), (10*10+10), 'error: incorrect number of elements in flat vector')
end
function nntest.Module_getParameters_6()
@@ -1174,6 +1178,44 @@ function nntest.Module_getParameters_6()
mytester:asserteq((p[{ {111,210} }] - n.modules[3].weight):norm(), 0, 'error when using cloning+sharing')
mytester:asserteq((p[{ {211,220} }] - n.modules[3].bias):norm(), 0, 'error when using cloning+sharing')
+
+ mytester:asserteq(p:nElement(), 2*(10*10+10), 'error: incorrect number of elements in flat vector')
+end
+
+function nntest.Module_getParameters_7()
+ local n = nn.Sequential()
+ n:add( nn.Linear(10,10) )
+ n:add( n.modules[1]:clone('weight','bias') )
+ local p = n:getParameters()
+
+ n:add(nn.Linear(10,10))
+ p = n:getParameters()
+
+ local n1 = nn.Sequential()
+ n1:add( nn.Linear(10,10) )
+
+ local n2 = nn.Sequential()
+ n2:add( nn.Linear(10,10) )
+
+ local n = nn.Sequential()
+ n:add( n1 )
+ n:add( n2 )
+
+ local p = n:getParameters()
+
+ local nf = nn.Sequential()
+ nf:add( n1 )
+ nf:add( nn.Linear(10,1) )
+
+ local p = nf:getParameters()
+
+ mytester:asserteq((p[{ {1,100} }] - n1.modules[1].weight):norm(), 0, 'error when using cloning+partial realloc')
+ mytester:asserteq((p[{ {101,110} }] - n1.modules[1].bias):norm(), 0, 'error when using cloning+partial realloc')
+
+ mytester:asserteq((p[{ {111,120} }] - nf.modules[2].weight):norm(), 0, 'error when using cloning+partial realloc')
+ mytester:asserteq((p[{ {121,121} }] - nf.modules[2].bias):norm(), 0, 'error when using cloning+partial realloc')
+
+ mytester:asserteq(p:nElement(), 121, 'error: incorrect number of elements in flat vector')
end
mytester:add(nntest)
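
A hypothetical follow-up check, not part of this commit: beyond counting elements, one can verify that after the second getParameters() call both the reused and the newly added module still view the returned flat tensor, so writes through it are visible in the modules.

require 'nn'

local n1 = nn.Sequential()
n1:add(nn.Linear(10, 10))
n1:getParameters()

local nf = nn.Sequential()
nf:add(n1)
nf:add(nn.Linear(10, 1))
local p = nf:getParameters()

p:fill(0.5)                                -- write through the flat vector
assert(n1.modules[1].weight[1][1] == 0.5)  -- reused module sees the write
assert(nf.modules[2].bias[1] == 0.5)       -- newly added module sees it too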