Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Kirkpatrick <kirkpatrick@google.com>2014-10-01 12:16:27 +0400
committerSoumith Chintala <soumith@gmail.com>2014-10-27 17:05:16 +0300
commit8581f84f45c2baef72dd32f6336410480b617ff7 (patch)
treee590661f77f4780dee406b44cc6e11e3fdce09cf
parent8d7d03ebe72ae507ae716f292778123dc34e04b1 (diff)
Corrected getParamaters for partial views
Module:getParameters was incorrectly overwriting parameters that were partial views on larger storages.
-rw-r--r--Module.lua11
-rw-r--r--test/test.lua31
2 files changed, 40 insertions, 2 deletions
diff --git a/Module.lua b/Module.lua
index 1d7d732..d9410c9 100644
--- a/Module.lua
+++ b/Module.lua
@@ -205,6 +205,7 @@ function Module:getParameters()
parameters[k]:zero()
end
+ local maskParameters= flatParameters:float():clone()
local cumSumOfHoles = flatParameters:float():cumsum(1)
local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
local flatUsedParameters = Tensor(nUsedParameters)
@@ -222,12 +223,18 @@ function Module:getParameters()
local k, v = unpack(storageAndOffset)
flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k))
end
+
if cumSumOfHoles:sum() == 0 then
flatUsedParameters:copy(flatParameters)
else
- for k = 1,flatUsedParameters:nElement() do
- flatUsedParameters[k] = flatParameters[k+cumSumOfHoles[k]]
+ local counter = 0
+ for k = 1,flatParameters:nElement() do
+ if maskParameters[k] == 0 then
+ counter = counter + 1
+ flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]]
+ end
end
+ assert (counter == nUsedParameters)
end
return flatUsedParameters
end
diff --git a/test/test.lua b/test/test.lua
index b035928..aa1d062 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1745,6 +1745,37 @@ function nntest.Module_getParameters_7()
mytester:asserteq(p:nElement(), 121, 'error: incorrect number of elements in flat vector')
end
+function nntest.Module_getParameters_8()
+ local function makeMLP(nin, ns)
+ local net = nn.Sequential()
+
+ for k,v in ipairs(ns) do
+ net:add(nn.Linear(nin, v))
+ nin = v
+ end
+ _,_ = net:getParameters()
+ return net
+ end
+
+ local mlp1 = makeMLP(10, {10,10})
+ local mlp2 = makeMLP(10, {10,10})
+
+ local net = nn.Sequential():add(mlp1:get(1))
+ :add(mlp2:get(1))
+
+ -- clone the second MLP to ensure that the weights before calling getParameters are preserved
+ mlp2 = mlp2:clone()
+
+ local p, gp = net:getParameters()
+
+ mytester:asserteq((p[{ {1,100} }] - net.modules[1].weight):norm(), 0, 'error when using partial realloc')
+ mytester:asserteq((p[{ {111,210} }] - net.modules[2].weight):norm(), 0, 'error when using partial realloc')
+ -- check that the weights have the same values as before get Parameters was called
+ mytester:asserteq((net.modules[1].weight - mlp1.modules[1].weight):norm(), 0, ' error when using partial realloc')
+ mytester:asserteq((net.modules[2].weight - mlp2.modules[1].weight):norm(), 0, ' error when using partial realloc')
+
+end
+
function nntest.PairwiseDistance()
-- Note: testJacobian doesn't support table inputs, and rather than re-write
-- it so that it does, I'll just use a split table module on the input.