diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2012-10-22 22:54:07 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2012-10-22 22:54:07 +0400 |
commit | 1703ce686332ecf4f98f28e94ef1820c5c2a5e63 (patch) | |
tree | 6d55547294732e98724ec330259b906430de14ed | |
parent | b1bd980dc1ea466a640e4af87535fed1044c978c (diff) | |
parent | 26c45850a83e6956129429550e13bb582e405740 (diff) |
Merge branch 'master' of github.com:andresy/torch
-rw-r--r-- | Module.lua | 20 |
1 files changed, 11 insertions, 9 deletions
@@ -150,6 +150,8 @@ function Module:getParameters() -- this function flattens arbitrary lists of parameters, -- even complex shared ones local function flatten(parameters) + local Tensor = parameters[1].new + local storages = {} local nParameters = 0 for k = 1,#parameters do @@ -159,7 +161,7 @@ function Module:getParameters() end end - local flatParameters = torch.Tensor(nParameters):fill(1) + local flatParameters = Tensor(nParameters):fill(1) local flatStorage = flatParameters:storage() for k = 1,#parameters do @@ -171,21 +173,21 @@ function Module:getParameters() parameters[k]:zero() end - local cumSumOfHoles = flatParameters:cumsum(1) + local cumSumOfHoles = flatParameters:float():cumsum(1) local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] - local flatUsedParameters = torch.Tensor(nUsedParameters) + local flatUsedParameters = Tensor(nUsedParameters) local flatUsedStorage = flatUsedParameters:storage() for k = 1,#parameters do - local offset = cumSumOfHoles[parameters[k]:storageOffset()] - parameters[k]:set(flatUsedStorage, - parameters[k]:storageOffset() - offset, - parameters[k]:size(), - parameters[k]:stride()) + local offset = cumSumOfHoles[parameters[k]:storageOffset()] + parameters[k]:set(flatUsedStorage, + parameters[k]:storageOffset() - offset, + parameters[k]:size(), + parameters[k]:stride()) end for k, v in pairs(storages) do - flatParameters[{{v+1,v+k:size()}}]:copy(torch.Tensor():set(k)) + flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) end for k = 1,flatUsedParameters:nElement() do flatUsedParameters[k] = flatParameters[k+cumSumOfHoles[k] ] |