Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJames Kirkpatrick <kirkpatrick@google.com>2014-10-01 12:16:27 +0400
committerSoumith Chintala <soumith@gmail.com>2014-10-27 17:05:16 +0300
commit8581f84f45c2baef72dd32f6336410480b617ff7 (patch)
treee590661f77f4780dee406b44cc6e11e3fdce09cf /Module.lua
parent8d7d03ebe72ae507ae716f292778123dc34e04b1 (diff)
Corrected getParamaters for partial views
Module:getParameters was incorrectly overwriting parameters that were partial views on larger storages.
Diffstat (limited to 'Module.lua')
-rw-r--r--Module.lua11
1 files changed, 9 insertions, 2 deletions
diff --git a/Module.lua b/Module.lua
index 1d7d732..d9410c9 100644
--- a/Module.lua
+++ b/Module.lua
@@ -205,6 +205,7 @@ function Module:getParameters()
parameters[k]:zero()
end
+ local maskParameters= flatParameters:float():clone()
local cumSumOfHoles = flatParameters:float():cumsum(1)
local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
local flatUsedParameters = Tensor(nUsedParameters)
@@ -222,12 +223,18 @@ function Module:getParameters()
local k, v = unpack(storageAndOffset)
flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k))
end
+
if cumSumOfHoles:sum() == 0 then
flatUsedParameters:copy(flatParameters)
else
- for k = 1,flatUsedParameters:nElement() do
- flatUsedParameters[k] = flatParameters[k+cumSumOfHoles[k]]
+ local counter = 0
+ for k = 1,flatParameters:nElement() do
+ if maskParameters[k] == 0 then
+ counter = counter + 1
+ flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]]
+ end
end
+ assert (counter == nUsedParameters)
end
return flatUsedParameters
end