diff options
-rw-r--r-- | node.lua | 20 |
1 files changed, 15 insertions, 5 deletions
@@ -89,6 +89,20 @@ function nnNode:graphNodeAttributes() end +local function getNanFlag(data) + local isNan = (data:ne(data):sum() > 0) + if isNan then + return 'NaN' + end + if data:max() == math.huge then + return 'inf' + end + if data:min() == -math.huge then + return '-inf' + end + return '' +end + function nnNode:label() local lbl = {} @@ -96,11 +110,7 @@ function nnNode:label() local function getstr(data) if not data then return '' end if istensor(data) then - local nanFlag = '' - local isNan = (data:ne(data):sum() > 0) - if isNan then - nanFlag = 'NaN' - end + local nanFlag = getNanFlag(data) return 'Tensor[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag elseif istable(data) then local tstr = {} |