diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2012-10-21 02:43:53 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2012-10-21 02:43:53 +0400 |
commit | 26c45850a83e6956129429550e13bb582e405740 (patch) | |
tree | 248851af8e562d7eaba3d1ceaef9b0d3bd618ed1 | |
parent | ffa6395c00d5cf66467733bd0b19364a503704be (diff) |
Fixed getParameters() for CUDA. When did that break?
-rw-r--r-- | Module.lua | 20 |
1 files changed, 11 insertions, 9 deletions
@@ -150,6 +150,8 @@ function Module:getParameters() -- this function flattens arbitrary lists of parameters, -- even complex shared ones local function flatten(parameters) + local Tensor = parameters[1].new + local storages = {} local nParameters = 0 for k = 1,#parameters do @@ -159,7 +161,7 @@ function Module:getParameters() end end - local flatParameters = torch.Tensor(nParameters):fill(1) + local flatParameters = Tensor(nParameters):fill(1) local flatStorage = flatParameters:storage() for k = 1,#parameters do @@ -171,21 +173,21 @@ function Module:getParameters() parameters[k]:zero() end - local cumSumOfHoles = flatParameters:cumsum(1) + local cumSumOfHoles = flatParameters:float():cumsum(1) local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles] - local flatUsedParameters = torch.Tensor(nUsedParameters) + local flatUsedParameters = Tensor(nUsedParameters) local flatUsedStorage = flatUsedParameters:storage() for k = 1,#parameters do - local offset = cumSumOfHoles[parameters[k]:storageOffset()] - parameters[k]:set(flatUsedStorage, - parameters[k]:storageOffset() - offset, - parameters[k]:size(), - parameters[k]:stride()) + local offset = cumSumOfHoles[parameters[k]:storageOffset()] + parameters[k]:set(flatUsedStorage, + parameters[k]:storageOffset() - offset, + parameters[k]:size(), + parameters[k]:stride()) end for k, v in pairs(storages) do - flatParameters[{{v+1,v+k:size()}}]:copy(torch.Tensor():set(k)) + flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k)) end for k = 1,flatUsedParameters:nElement() do flatUsedParameters[k] = flatParameters[k+cumSumOfHoles[k] ] |