From 4635a92c8d0064a88ee208290828dc279a3ae648 Mon Sep 17 00:00:00 2001 From: Koray Kavukcuoglu Date: Fri, 11 Sep 2015 12:55:59 +0100 Subject: make sure forward/backward runs can deal with parameter nodes since they do not have any inputs coming in. add a display function that does not use qt, but browser --- gmodule.lua | 13 +++++++++++++ graphinspecting.lua | 16 ++++++++++++++++ init.lua | 2 ++ 3 files changed, 31 insertions(+) diff --git a/gmodule.lua b/gmodule.lua index 3aba8a3..1d569b2 100644 --- a/gmodule.lua +++ b/gmodule.lua @@ -264,6 +264,11 @@ function gModule:runForwardFunction(func,input) propagate(node,input) else local input = node.data.input + + -- a parameter node is captured + if input == nil and node.data.module ~= nil then + input = {} + end if #input == 1 then input = input[1] end @@ -340,6 +345,10 @@ function gModule:updateGradInput(input,gradOutput) gradInput = gradOutput else local input = node.data.input + -- a parameter node is captured + if input == nil and node.data.module ~= nil then + input = {} + end if #input == 1 then input = input[1] end @@ -395,6 +404,10 @@ function gModule:accGradParameters(input,gradOutput,lr) gradOutput = node.data.gradOutputBuffer end local input = node.data.input + -- a parameter node is captured + if input == nil and node.data.module ~= nil then + input = {} + end if #input == 1 then input = input[1] end diff --git a/graphinspecting.lua b/graphinspecting.lua index 0ccb168..e0676c4 100644 --- a/graphinspecting.lua +++ b/graphinspecting.lua @@ -140,3 +140,19 @@ function nngraph.annotateNodes(stackLevel) end end +--[[ + SVG visualization for gmodule + TODO: add custom coloring with node types +]] +function nngraph.display(gmodule) + local ffi = require 'ffi' + local cmd + if ffi.os == 'Linux' then + cmd = 'xdg-open' + elseif ffi.os == 'OSX' then + cmd = 'open -a Safari' + end + local fname = os.tmpname() + graph.dot(gmodule.fg, fname, fname) + os.execute(cmd .. ' ' .. fname .. '.svg') +end diff --git a/init.lua b/init.lua index a009eb0..a76b80d 100644 --- a/init.lua +++ b/init.lua @@ -46,3 +46,5 @@ local Criterion = torch.getmetatable('nn.Criterion') function Criterion:__call__(...) return nn.ModuleFromCriterion(self)(...) end + +return nngraph -- cgit v1.2.3