Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-03-27 20:38:39 +0300
committerSoumith Chintala <soumith@gmail.com>2015-03-27 20:38:39 +0300
commit950c67076477b70ecf904ce899fda98fbada965d (patch)
tree8c86622fa9129af5feaf447b82517632b4b04e52
parent78c3649f7903454aa9043dcb8d639537c6d5f604 (diff)
parentf75010e827ad359b90e40dabe0e80318a99c7e6f (diff)
Merge pull request #45 from fidlej/topic_types
Showed the non-default tensor types.
-rw-r--r--gmodule.lua5
-rw-r--r--node.lua9
2 files changed, 13 insertions, 1 deletions
diff --git a/gmodule.lua b/gmodule.lua
index 0f0f461..8dc100d 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -352,3 +352,8 @@ function gModule:parameters()
end
return p,gp
end
+
+function gModule:__tostring__()
+ return self.name or torch.type(self)
+end
+
diff --git a/node.lua b/node.lua
index c01bdae..b620456 100644
--- a/node.lua
+++ b/node.lua
@@ -93,6 +93,9 @@ end
local function getNanFlag(data)
+ if data:nElement() == 0 then
+ return ''
+ end
local isNan = (data:ne(data):sum() > 0)
if isNan then
return 'NaN'
@@ -114,7 +117,11 @@ function nnNode:label()
if not data then return '' end
if istensor(data) then
local nanFlag = getNanFlag(data)
- return 'Tensor[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag
+ local tensorType = 'Tensor'
+ if data:type() ~= torch.Tensor():type() then
+ tensorType = data:type()
+ end
+ return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag
elseif istable(data) then
local tstr = {}
for i,v in ipairs(data) do