Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options
context:
space:
mode:
author Clement Farabet <clement.farabet@gmail.com> 2012-10-21 02:43:53 +0400
committer Clement Farabet <clement.farabet@gmail.com> 2012-10-21 02:43:53 +0400
commit 26c45850a83e6956129429550e13bb582e405740 (patch)
tree 248851af8e562d7eaba3d1ceaef9b0d3bd618ed1
parent ffa6395c00d5cf66467733bd0b19364a503704be (diff)
Fixed getParameters() for CUDA. When did that break?
-rw-r--r-- Module.lua 20
1 file changed, 11 insertions, 9 deletions
diff --git a/Module.lua b/Module.lua
index 03d62fb..4423aea 100644
--- a/Module.lua
+++ b/Module.lua
@@ -150,6 +150,8 @@ function Module:getParameters()
-- this function flattens arbitrary lists of parameters,
-- even complex shared ones
local function flatten(parameters)
+ local Tensor = parameters[1].new
+
local storages = {}
local nParameters = 0
for k = 1,#parameters do
@@ -159,7 +161,7 @@ function Module:getParameters()
end
end
- local flatParameters = torch.Tensor(nParameters):fill(1)
+ local flatParameters = Tensor(nParameters):fill(1)
local flatStorage = flatParameters:storage()
for k = 1,#parameters do
@@ -171,21 +173,21 @@ function Module:getParameters()
parameters[k]:zero()
end
- local cumSumOfHoles = flatParameters:cumsum(1)
+ local cumSumOfHoles = flatParameters:float():cumsum(1)
local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
- local flatUsedParameters = torch.Tensor(nUsedParameters)
+ local flatUsedParameters = Tensor(nUsedParameters)
local flatUsedStorage = flatUsedParameters:storage()
for k = 1,#parameters do
- local offset = cumSumOfHoles[parameters[k]:storageOffset()]
- parameters[k]:set(flatUsedStorage,
- parameters[k]:storageOffset() - offset,
- parameters[k]:size(),
- parameters[k]:stride())
+ local offset = cumSumOfHoles[parameters[k]:storageOffset()]
+ parameters[k]:set(flatUsedStorage,
+ parameters[k]:storageOffset() - offset,
+ parameters[k]:size(),
+ parameters[k]:stride())
end
for k, v in pairs(storages) do
- flatParameters[{{v+1,v+k:size()}}]:copy(torch.Tensor():set(k))
+ flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k))
end
for k = 1,flatUsedParameters:nElement() do
flatUsedParameters[k] = flatParameters[k+cumSumOfHoles[k] ]