Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKoray Kavukcuoglu <koray@kavukcuoglu.org>2015-09-11 14:55:59 +0300
committerKoray Kavukcuoglu <koray@kavukcuoglu.org>2015-09-11 14:55:59 +0300
commit4635a92c8d0064a88ee208290828dc279a3ae648 (patch)
tree1592d2773aacbbe0e0f389f85fc3608edc70ca8e
parentfdf8c99c59959ebec7310c33ea185091f59bb818 (diff)
make sure forward/backward runs can deal with parameter nodes since theynnop
do not have any inputs coming in. add a display function that does not use qt, but browser
-rw-r--r--gmodule.lua13
-rw-r--r--graphinspecting.lua16
-rw-r--r--init.lua2
3 files changed, 31 insertions, 0 deletions
diff --git a/gmodule.lua b/gmodule.lua
index 3aba8a3..1d569b2 100644
--- a/gmodule.lua
+++ b/gmodule.lua
@@ -264,6 +264,11 @@ function gModule:runForwardFunction(func,input)
propagate(node,input)
else
local input = node.data.input
+
+ -- a parameter node is captured
+ if input == nil and node.data.module ~= nil then
+ input = {}
+ end
if #input == 1 then
input = input[1]
end
@@ -340,6 +345,10 @@ function gModule:updateGradInput(input,gradOutput)
gradInput = gradOutput
else
local input = node.data.input
+ -- a parameter node is captured
+ if input == nil and node.data.module ~= nil then
+ input = {}
+ end
if #input == 1 then
input = input[1]
end
@@ -395,6 +404,10 @@ function gModule:accGradParameters(input,gradOutput,lr)
gradOutput = node.data.gradOutputBuffer
end
local input = node.data.input
+ -- a parameter node is captured
+ if input == nil and node.data.module ~= nil then
+ input = {}
+ end
if #input == 1 then
input = input[1]
end
diff --git a/graphinspecting.lua b/graphinspecting.lua
index 0ccb168..e0676c4 100644
--- a/graphinspecting.lua
+++ b/graphinspecting.lua
@@ -140,3 +140,19 @@ function nngraph.annotateNodes(stackLevel)
end
end
+--[[
+ SVG visualization for gmodule
+ TODO: add custom coloring with node types
+]]
+function nngraph.display(gmodule)
+ local ffi = require 'ffi'
+ local cmd
+ if ffi.os == 'Linux' then
+ cmd = 'xdg-open'
+ elseif ffi.os == 'OSX' then
+ cmd = 'open -a Safari'
+ end
+ local fname = os.tmpname()
+ graph.dot(gmodule.fg, fname, fname)
+ os.execute(cmd .. ' ' .. fname .. '.svg')
+end
diff --git a/init.lua b/init.lua
index a009eb0..a76b80d 100644
--- a/init.lua
+++ b/init.lua
@@ -46,3 +46,5 @@ local Criterion = torch.getmetatable('nn.Criterion')
function Criterion:__call__(...)
return nn.ModuleFromCriterion(self)(...)
end
+
+return nngraph