github.com/torch/nngraph.git
author    koray kavukcuoglu <koray@kavukcuoglu.org>  2015-03-12 18:37:52 +0300
committer koray kavukcuoglu <koray@kavukcuoglu.org>  2015-03-12 18:37:52 +0300
commit    989d5e51ba527f51ef71bd005950a94ce02566aa (patch)
tree      fd4bee58adb5e237b199403afcc0a5573fabf690
parent    53e4bedf82ece70c91dabe9b81a780e288fddecb (diff)
parent    3a8726b8605ee82320db5684961b44b5d4892eb1 (diff)

Merge pull request #41 from yozw/fix-nngraph-cuda

Include all data in type conversion when type() is called on an nn.gModule.
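In other words: before this patch, calling :type() (or :float()/:cuda()) on an nn.gModule converted only the wrapped modules, while tensors cached in the graph nodes (innode.data, outnode.data, and so on) kept their old type, leaving the graph in a mixed-type state. A minimal sketch of the guarantee the fix adds, assuming torch and nngraph are installed (prior to this commit the final assert would fail):

require 'nngraph'

local in1 = nn.Linear(20, 10)()
local out1 = nn.Linear(10, 1)(in1)
local net = nn.gModule({in1}, {out1})

net:forward(torch.rand(20))  -- caches DoubleTensors in the graph nodes
net:float()                  -- with this patch, also converts the cached node data

assert(torch.typename(net.innode.data.input[1]) == 'torch.FloatTensor')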
-rw-r--r--  gmodule.lua            27
-rw-r--r--  test/test_nngraph.lua  12
2 files changed, 38 insertions(+), 1 deletion(-)
diff --git a/gmodule.lua b/gmodule.lua
index e69e79f..0f0f461 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -121,7 +121,34 @@ function gModule:evaluate()
    self:apply(function(module) module:evaluate() end)
 end
 
+--[[ Recursively applies type(type_str) to any tensors in the argument. If the
+argument is a tensor, type(type_str) is applied; if the argument is an array,
+this function recurses into it. ]]
+local function recursiveType(param, type_str)
+   if torch.type(param) == 'table' then
+      for i = 1, #param do
+         param[i] = recursiveType(param[i], type_str)
+      end
+   elseif torch.typename(param) and
+      torch.typename(param):find('torch%..+Tensor') then
+      param = param:type(type_str)
+   end
+   return param
+end
+
 function gModule:type(type)
+   local function applyTypeToTable(table)
+      for key, value in pairs(table) do
+         table[key] = recursiveType(table[key], type)
+      end
+   end
+
+   -- Convert any stored data in self, and in the in and out nodes
+   applyTypeToTable(self)
+   if self.innode then applyTypeToTable(self.innode.data) end
+   if self.outnode then applyTypeToTable(self.outnode.data) end
+
+   -- Loop through modules and convert data
    self:apply(function(module) module:type(type) end)
    return self
 end
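recursiveType is file-local, so it is restated below purely for illustration of what it does to nested data: it walks the array part of nested tables and converts any torch tensor it finds, leaving other values untouched. This is a sketch, not part of the public nngraph API:

local function recursiveType(param, type_str)
   if torch.type(param) == 'table' then
      for i = 1, #param do
         param[i] = recursiveType(param[i], type_str)
      end
   elseif torch.typename(param) and
      torch.typename(param):find('torch%..+Tensor') then
      param = param:type(type_str)
   end
   return param
end

local data = { torch.rand(2), { torch.rand(3) }, 'not a tensor' }
recursiveType(data, 'torch.FloatTensor')
print(torch.typename(data[1]))     -- torch.FloatTensor
print(torch.typename(data[2][1]))  -- torch.FloatTensor
print(data[3])                     -- unchanged

Note the asymmetry in the patch: recursiveType recurses only over the array part (#param), while applyTypeToTable iterates the full key space of a node's data table with pairs before handing each value to recursiveType.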
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index 4062c5a..b709cda 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -203,14 +203,24 @@ function test.test_type()
    local in1 = nn.Linear(20,10)()
    local out1 = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(in1))))
    local module = nn.gModule({in1}, {out1})
-
    local input = torch.rand(20)
    local output = module:forward(input)
+   module:backward(input, output)
    tester:eq(torch.typename(output), "torch.DoubleTensor")
+   tester:eq(torch.typename(module.output), "torch.DoubleTensor")
+   tester:eq(torch.typename(module.gradInput), "torch.DoubleTensor")
+   tester:eq(torch.typename(module.innode.data.input[1]), "torch.DoubleTensor")
+   tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor")
+   tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor")
    module:float()
    local output = module:forward(input:float())
    tester:eq(torch.typename(output), "torch.FloatTensor")
+   tester:eq(torch.typename(module.output), "torch.FloatTensor")
+   tester:eq(torch.typename(module.gradInput), "torch.FloatTensor")
+   tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor")
+   tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor")
+   tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor")
 end
 
 function test.test_nestedGradInput()
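The strengthened test can also be replayed outside the tester harness. The standalone sketch below mirrors the diff above, with tester:eq swapped for plain assert; it assumes torch and nngraph are installed:

require 'nngraph'

local in1 = nn.Linear(20, 10)()
local out1 = nn.Linear(10, 1)(nn.Tanh()(nn.Linear(10, 10)(nn.Tanh()(in1))))
local module = nn.gModule({in1}, {out1})

local input = torch.rand(20)
local output = module:forward(input)
module:backward(input, output)

module:float()
output = module:forward(input:float())

-- after conversion, the cached tensors should all be FloatTensors
assert(torch.typename(output) == 'torch.FloatTensor')
assert(torch.typename(module.gradInput) == 'torch.FloatTensor')
assert(torch.typename(module.innode.data.input[1]) == 'torch.FloatTensor')
assert(torch.typename(module.outnode.data.input[1]) == 'torch.FloatTensor')
assert(torch.typename(module.forwardnodes[1].data.input[1]) == 'torch.FloatTensor')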