diff options
author | koray kavukcuoglu <koray@kavukcuoglu.org> | 2015-03-12 18:37:52 +0300 |
---|---|---|
committer | koray kavukcuoglu <koray@kavukcuoglu.org> | 2015-03-12 18:37:52 +0300 |
commit | 989d5e51ba527f51ef71bd005950a94ce02566aa (patch) | |
tree | fd4bee58adb5e237b199403afcc0a5573fabf690 | |
parent | 53e4bedf82ece70c91dabe9b81a780e288fddecb (diff) | |
parent | 3a8726b8605ee82320db5684961b44b5d4892eb1 (diff) |
Merge pull request #41 from yozw/fix-nngraph-cuda
Include all data in type conversion when type() is called on an nn.gModule.
-rw-r--r-- | gmodule.lua | 27 | ||||
-rw-r--r-- | test/test_nngraph.lua | 12 |
2 files changed, 38 insertions, 1 deletions
diff --git a/gmodule.lua b/gmodule.lua index e69e79f..0f0f461 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -121,7 +121,34 @@ function gModule:evaluate() self:apply(function(module) module:evaluate() end) end +--[[ Recursively applies type(type_str) to any tensors in the argument. If the +argument is a tensor, type(type_str) is applied; if the argument is an array, +this function recurses into it. ]] +local function recursiveType(param, type_str) + if torch.type(param) == 'table' then + for i = 1, #param do + param[i] = recursiveType(param[i], type_str) + end + elseif torch.typename(param) and + torch.typename(param):find('torch%..+Tensor') then + param = param:type(type_str) + end + return param +end + function gModule:type(type) + local function applyTypeToTable(table) + for key, value in pairs(table) do + table[key] = recursiveType(table[key], type) + end + end + + -- Convert any stored data in self, and in the in and out nodes + applyTypeToTable(self) + if self.innode then applyTypeToTable(self.innode.data) end + if self.outnode then applyTypeToTable(self.outnode.data) end + + -- Loop through modules and convert data self:apply(function(module) module:type(type) end) return self end diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua index 4062c5a..b709cda 100644 --- a/test/test_nngraph.lua +++ b/test/test_nngraph.lua @@ -203,14 +203,24 @@ function test.test_type() local in1 = nn.Linear(20,10)() local out1 = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(in1)))) local module = nn.gModule({in1}, {out1}) - local input = torch.rand(20) local output = module:forward(input) + module:backward(input, output) tester:eq(torch.typename(output), "torch.DoubleTensor") + tester:eq(torch.typename(module.output), "torch.DoubleTensor") + tester:eq(torch.typename(module.gradInput), "torch.DoubleTensor") + tester:eq(torch.typename(module.innode.data.input[1]), "torch.DoubleTensor") + tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor") + tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor") module:float() local output = module:forward(input:float()) tester:eq(torch.typename(output), "torch.FloatTensor") + tester:eq(torch.typename(module.output), "torch.FloatTensor") + tester:eq(torch.typename(module.gradInput), "torch.FloatTensor") + tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor") + tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor") + tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor") end function test.test_nestedGradInput() |