Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkoray kavukcuoglu <koray@kavukcuoglu.org>2015-06-18 20:44:22 +0300
committerkoray kavukcuoglu <koray@kavukcuoglu.org>2015-06-18 20:44:22 +0300
commit045c17eab6ec10e6d90b19216c52058afb731656 (patch)
tree6122286066a45a4b19eb37e0f8d78eb67173b373
parenteb65b3602cdf3214ff671d5422521ad414f1c3b0 (diff)
big cleanup, move graphviz related bits to graphviz.lua filescleanup
-rw-r--r--doc/example6.lua31
-rw-r--r--graphviz.lua82
-rw-r--r--init.lua4
-rw-r--r--node.lua146
-rw-r--r--rocks/nngraph-scm-1.rockspec (renamed from nngraph-scm-1.rockspec)0
-rw-r--r--test/test_nngraph.lua502
6 files changed, 366 insertions, 399 deletions
diff --git a/doc/example6.lua b/doc/example6.lua
new file mode 100644
index 0000000..39b440f
--- /dev/null
+++ b/doc/example6.lua
@@ -0,0 +1,31 @@
+require 'nngraph'
+
+-- generate SVG of the graph with the problem node highlighted
+-- and hover over the nodes in svg to see the filename:line_number info
+-- nodes will be annotated with local variable names even if debug mode is not enabled.
+nngraph.setDebug(true)
+
+local function get_net(from, to)
+ local from = from or 10
+ local to = to or 10
+ local input_x = nn.Identity()()
+ local linear_module = nn.Linear(from, to)(input_x)
+
+ -- Annotate nodes with local variable names
+ nngraph.annotateNodes()
+ return nn.gModule({input_x},{linear_module})
+end
+
+local net = get_net(10,10)
+
+-- if you give a name to the net, it will use that name to produce the
+-- svg in case of error, if not, it will come up with a name
+-- that is derived from number of inputs and outputs to the graph
+net.name = 'my_bad_linear_net'
+
+-- prepare an input that is of the wrong size to force an error
+local input = torch.rand(11)
+pcall(function() net:updateOutput(input) end)
+-- it should have produced an error and spit out a graph
+-- just run Safari to display the svg
+os.execute('open -a Safari my_bad_linear_net.svg')
diff --git a/graphviz.lua b/graphviz.lua
new file mode 100644
index 0000000..092877a
--- /dev/null
+++ b/graphviz.lua
@@ -0,0 +1,82 @@
+-- handy functions
+local utils = paths.dofile('utils.lua')
+local istensor = utils.istensor
+local istable = utils.istable
+local istorchclass = utils.istorchclass
+
+local function getNanFlag(data)
+ if data:nElement() == 0 then
+ return ''
+ end
+ local isNan = (data:ne(data):sum() > 0)
+ if isNan then
+ return 'NaN'
+ end
+ if data:max() == math.huge then
+ return 'inf'
+ end
+ if data:min() == -math.huge then
+ return '-inf'
+ end
+ return ''
+end
+local function getstr(data)
+ if not data then return '' end
+ if istensor(data) then
+ local nanFlag = getNanFlag(data)
+ local tensorType = 'Tensor'
+ if data:type() ~= torch.Tensor():type() then
+ tensorType = data:type()
+ end
+ return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag
+ elseif istable(data) then
+ local tstr = {}
+ for i,v in ipairs(data) do
+ table.insert(tstr, getstr(v))
+ end
+ return '{' .. table.concat(tstr,',') .. '}'
+ else
+ return tostring(data):gsub('\n','\\l')
+ end
+end
+local function getmapindexstr(mapindex)
+ local tstr = {}
+ for i,data in ipairs(mapindex) do
+ local inputId = 'Node' .. (data.forwardNodeId or '')
+ table.insert(tstr, inputId)
+ end
+ return '{' .. table.concat(tstr,',') .. '}'
+end
+
+local Node = torch.getmetatable('nngraph.Node')
+
+
+--[[
+Returns a textual representation of the Node that can be used by graphviz library visualization.
+]]
+function Node:label()
+
+ local lbl = {}
+
+ for k,v in pairs(self.data) do
+ local vstr = ''
+ if k == 'mapindex' then
+ if #v > 1 then
+ vstr = getmapindexstr(v)
+ table.insert(lbl, k .. ' = ' .. vstr)
+ end
+ elseif k == 'forwardNodeId' or k == 'annotations' then
+ -- the forwardNodeId is not displayed in the label.
+ else
+ vstr = getstr(v)
+ table.insert(lbl, k .. ' = ' .. vstr)
+ end
+ end
+
+ local desc = ''
+ if self.data.annotations.description then
+ desc = 'desc = ' .. self.data.annotations.description .. '\\n'
+ end
+ return desc .. table.concat(lbl,"\\l")
+end
+
diff --git a/init.lua b/init.lua
index 4d340f3..e0a966d 100644
--- a/init.lua
+++ b/init.lua
@@ -6,6 +6,7 @@ nngraph = {}
torch.include('nngraph','node.lua')
torch.include('nngraph','gmodule.lua')
+torch.include('nngraph','graphviz.lua')
torch.include('nngraph','graphinspecting.lua')
torch.include('nngraph','ModuleFromCriterion.lua')
@@ -36,7 +37,7 @@ function Module:__call__(...)
for i,dnode in ipairs(input) do
if torch.typename(dnode) ~= 'nngraph.Node' then
- error('what is this in the input? ' .. tostring(dnode))
+ error('Expected nngraph.Node type, what is this in the input? ' .. tostring(dnode))
end
mnode:add(dnode,true)
end
@@ -44,6 +45,7 @@ function Module:__call__(...)
return mnode
end
+-- Modify the __call function to hack into nn.Criterion
local Criterion = torch.getmetatable('nn.Criterion')
function Criterion:__call__(...)
return nn.ModuleFromCriterion(self)(...)
diff --git a/node.lua b/node.lua
index b620456..1c68ba7 100644
--- a/node.lua
+++ b/node.lua
@@ -1,30 +1,23 @@
-local utils = paths.dofile('utils.lua')
-local istensor = utils.istensor
-local istable = utils.istable
-local istorchclass = utils.istorchclass
-require 'debug'
-
-
-local nnNode,parent = torch.class('nngraph.Node','graph.Node')
-
+--[[
+This file implements the nngraph.Node. In addition to graph.Node this class
+provides some additional functionality for handling neural networks in a graph
+]]
+local nnNode,parent = torch.class('nngraph.Node','graph.AnnotatedNode')
+
+
+--[[
+nngraph.Node
+Args:
+* `data` - the same as graph.Node(data). Any object type that will be stored as data
+in the graph node.
+]]
function nnNode:__init(data)
- parent.__init(self,data)
- self.data.annotations = self.data.annotations or {}
- self.data.mapindex = self.data.mapindex or {}
- if not self.data.annotations._debugLabel then
- self:_makeDebugLabel(debug.getinfo(6, 'Sl'))
- end
-end
-
-
---[[ Build a string label which will be used a tooltip when
- making a graph.]]
-function nnNode:_makeDebugLabel(dinfo)
- if dinfo then
- self.data.annotations._debugLabel = string.format('[%s]:%d',
- dinfo.short_src, dinfo.currentline, dinfo.name)
- end
+ -- level 7 corresponds to level with the nngraph usage of nnNode's
+ -- inside Module:__call() syntax
+ parent.__init(self,data, 7)
+ -- decorate the data with additional info to keep track of order of connected nodes
+ self.data.mapindex = data.mapindex or {}
end
@@ -61,106 +54,3 @@ function nnNode:split(noutput)
return unpack(selectnodes)
end
-
-function nnNode:annotate(annotations)
- for k, v in pairs(annotations) do
- self.data.annotations[k] = v
- end
-
- return self
-end
-
-
-function nnNode:graphNodeName()
- if self.data.annotations.name then
- return self.data.annotations.name .. ' (' .. self.id .. ')'
- else
- return 'Node' .. self.id
- end
-end
-
-
-function nnNode:graphNodeAttributes()
- self.data.annotations.graphAttributes =
- self.data.annotations.graphAttributes or {}
- if not self.data.annotations.graphAttributes.tooltip then
- self.data.annotations.graphAttributes.tooltip =
- self.data.annotations._debugLabel
- end
-
- return self.data.annotations.graphAttributes
-end
-
-
-local function getNanFlag(data)
- if data:nElement() == 0 then
- return ''
- end
- local isNan = (data:ne(data):sum() > 0)
- if isNan then
- return 'NaN'
- end
- if data:max() == math.huge then
- return 'inf'
- end
- if data:min() == -math.huge then
- return '-inf'
- end
- return ''
-end
-
-function nnNode:label()
-
- local lbl = {}
-
- local function getstr(data)
- if not data then return '' end
- if istensor(data) then
- local nanFlag = getNanFlag(data)
- local tensorType = 'Tensor'
- if data:type() ~= torch.Tensor():type() then
- tensorType = data:type()
- end
- return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag
- elseif istable(data) then
- local tstr = {}
- for i,v in ipairs(data) do
- table.insert(tstr, getstr(v))
- end
- return '{' .. table.concat(tstr,',') .. '}'
- else
- return tostring(data):gsub('\n','\\l')
- end
- end
- local function getmapindexstr(mapindex)
- local tstr = {}
- for i,data in ipairs(mapindex) do
- local inputId = 'Node' .. (data.forwardNodeId or '')
- table.insert(tstr, inputId)
- end
- return '{' .. table.concat(tstr,',') .. '}'
- end
-
- for k,v in pairs(self.data) do
- local vstr = ''
- if k== 'mapindex' then
- if #v > 1 then
- vstr = getmapindexstr(v)
- table.insert(lbl, k .. ' = ' .. vstr)
- end
- elseif k== 'forwardNodeId' or k== 'annotations' then
- -- the forwardNodeId is not displayed in the label.
- else
- vstr = getstr(v)
- table.insert(lbl, k .. ' = ' .. vstr)
- end
- end
-
- local desc
- if self.data.annotations.description then
- desc = 'desc = ' .. self.data.annotations.description .. '\\n'
- else
- desc = ''
- end
- return desc .. table.concat(lbl,"\\l")
-end
diff --git a/nngraph-scm-1.rockspec b/rocks/nngraph-scm-1.rockspec
index 4963ecd..4963ecd 100644
--- a/nngraph-scm-1.rockspec
+++ b/rocks/nngraph-scm-1.rockspec
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index ac47be2..8dac03a 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -5,358 +5,320 @@ local test = {}
local tester = totem.Tester()
local function checkGradients(...)
- totem.nn.checkGradients(tester, ...)
+ totem.nn.checkGradients(tester, ...)
end
function test.test_oneOutput()
- local in1 = nn.Identity()()
- local out1 = nn.Identity()(in1)
- local module = nn.gModule({in1}, {out1})
-
- local input = torch.Tensor({1})
- module:forward(input)
- tester:eq(module.output, torch.Tensor{1}, "output")
- local gradInput = module:backward(input, torch.Tensor({-123}))
- tester:eq(gradInput, torch.Tensor{-123}, "gradInput")
-
- local input2 = torch.Tensor({2})
- module:forward(input2)
- tester:eq(module.output, torch.Tensor{2}, "output for input2")
- gradInput = module:backward(input2, torch.Tensor({-2}))
- tester:eq(gradInput, torch.Tensor{-2}, "expecting a recomputed gradInput")
+ local in1 = nn.Identity()()
+ local out1 = nn.Identity()(in1)
+ local module = nn.gModule({in1}, {out1})
+
+ local input = torch.Tensor({1})
+ module:forward(input)
+ tester:eq(module.output, torch.Tensor{1}, "output")
+ local gradInput = module:backward(input, torch.Tensor({-123}))
+ tester:eq(gradInput, torch.Tensor{-123}, "gradInput")
+
+ local input2 = torch.Tensor({2})
+ module:forward(input2)
+ tester:eq(module.output, torch.Tensor{2}, "output for input2")
+ gradInput = module:backward(input2, torch.Tensor({-2}))
+ tester:eq(gradInput, torch.Tensor{-2}, "expecting a recomputed gradInput")
end
function test.test_twoOutputs()
- local in1 = nn.Identity()()
- local out1 = nn.Identity()(in1)
- local out2 = nn.Identity()(in1)
- local module = nn.gModule({in1}, {out1, out2})
-
- local input = torch.Tensor({1})
- module:forward(input)
- local gradInput = module:backward(input, {torch.Tensor({-2}), torch.Tensor({-3})})
- tester:eq(gradInput, torch.Tensor{-5}, "gradInput of a fork")
- checkGradients(module, input)
+ local in1 = nn.Identity()()
+ local out1 = nn.Identity()(in1)
+ local out2 = nn.Identity()(in1)
+ local module = nn.gModule({in1}, {out1, out2})
+
+ local input = torch.Tensor({1})
+ module:forward(input)
+ local gradInput = module:backward(input, {torch.Tensor({-2}), torch.Tensor({-3})})
+ tester:eq(gradInput, torch.Tensor{-5}, "gradInput of a fork")
+ checkGradients(module, input)
end
function test.test_twoGradOutputs()
- local in1 = nn.Sigmoid()()
- local splitTable = nn.SplitTable(1)({in1})
- local out1, out2 = splitTable:split(2)
- local module = nn.gModule({in1}, {out1, out2})
-
- local input = torch.randn(2, 3)
- local output = module:forward(input)
- assert(#output == 2, "wrong number of outputs")
- module:backward(input, {torch.randn(3), torch.randn(3)})
- checkGradients(module, input)
+ local in1 = nn.Sigmoid()()
+ local splitTable = nn.SplitTable(1)({in1})
+ local out1, out2 = splitTable:split(2)
+ local module = nn.gModule({in1}, {out1, out2})
+
+ local input = torch.randn(2, 3)
+ local output = module:forward(input)
+ assert(#output == 2, "wrong number of outputs")
+ module:backward(input, {torch.randn(3), torch.randn(3)})
+ checkGradients(module, input)
end
function test.test_twoInputs()
- local in1 = nn.Identity()()
- local in2 = nn.Identity()()
- local prevH, prevCell = in2:split(2)
-
- local out1 = nn.CMulTable()({in1, prevH, prevCell})
- local module = nn.gModule({in1, in2}, {out1})
-
- local input = {torch.randn(3), {torch.randn(3), torch.randn(3)}}
- module:forward(input)
- local gradInput = module:backward(input, torch.randn(3))
- assert(#gradInput == 2, "wrong number of gradInputs")
- assert(type(gradInput[2]) == "table", "wrong gradInput[2] type")
- checkGradients(module, input)
+ local in1 = nn.Identity()()
+ local in2 = nn.Identity()()
+ local prevH, prevCell = in2:split(2)
+
+ local out1 = nn.CMulTable()({in1, prevH, prevCell})
+ local module = nn.gModule({in1, in2}, {out1})
+
+ local input = {torch.randn(3), {torch.randn(3), torch.randn(3)}}
+ module:forward(input)
+ local gradInput = module:backward(input, torch.randn(3))
+ assert(#gradInput == 2, "wrong number of gradInputs")
+ assert(type(gradInput[2]) == "table", "wrong gradInput[2] type")
+ checkGradients(module, input)
end
function test.test_twoInputs2()
- local in1 = nn.Sigmoid()()
- local in2 = nn.Sigmoid()()
- local module = nn.gModule({in1, in2}, {in1, in2, nn.Sigmoid()(in1)})
-
- local input = {torch.randn(3), torch.randn(3)}
- module:forward(input)
- local gradInput = module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)})
- checkGradients(module, input)
+ local in1 = nn.Sigmoid()()
+ local in2 = nn.Sigmoid()()
+ local module = nn.gModule({in1, in2}, {in1, in2, nn.Sigmoid()(in1)})
+
+ local input = {torch.randn(3), torch.randn(3)}
+ module:forward(input)
+ local gradInput = module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)})
+ checkGradients(module, input)
end
function test.test_splitDebugLabels()
- local node = nn.Identity()()
- node.data.annotations._debugLabel = "node"
- local node1, node2 = node:split(2)
- assert(node1.data.annotations._debugLabel == "node-1")
- assert(node2.data.annotations._debugLabel == "node-2")
+ local node = nn.Identity()()
+ node.data.annotations._debugLabel = "node"
+ local node1, node2 = node:split(2)
+ assert(node1.data.annotations._debugLabel == "node-1")
+ assert(node2.data.annotations._debugLabel == "node-2")
end
function test.test_identity()
- local in1 = nn.Identity()()
- local in2 = nn.Identity()()
- local module = nn.gModule({in1, in2}, {in1, in2, nn.Identity()(in1)})
-
- local input = {torch.randn(3), torch.randn(3)}
- module:forward(input)
- module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)})
- checkGradients(module, input)
+ local in1 = nn.Identity()()
+ local in2 = nn.Identity()()
+ local module = nn.gModule({in1, in2}, {in1, in2, nn.Identity()(in1)})
+
+ local input = {torch.randn(3), torch.randn(3)}
+ module:forward(input)
+ module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)})
+ checkGradients(module, input)
end
function test.test_gradInputType()
- local xInput = torch.randn(3)
- local h = torch.randn(3)
-
- local x = nn.Identity()()
- local prevRnnState = nn.Identity()()
- local prevH1, prevCell = prevRnnState:split(2)
- local prevH = prevH1
-
- local cellOut = nn.CAddTable()({
- nn.CMulTable()({x, prevH}),
- nn.CMulTable()({prevH, prevCell})})
- local module = nn.gModule({x, prevRnnState}, {cellOut})
-
- local c = torch.randn(h:size())
- local prevRnnState = {h, c}
- local input = {xInput, prevRnnState}
- local output = module:forward(input)
-
- local gradOutput = torch.randn(h:size())
- local gradInput = module:backward(input, gradOutput)
-
- local gradX, gradPrevState = unpack(gradInput)
- local gradPrevH, gradPrevCell = unpack(gradPrevState)
- assert(type(gradPrevH) == type(h), "wrong gradPrevH type")
-
- tester:eq(type(gradPrevH), type(h), "wrong gradPrevH type")
- tester:eq(gradPrevH:size(), h:size(), "wrong gradPrevH size")
- checkGradients(module, input)
+ local xInput = torch.randn(3)
+ local h = torch.randn(3)
+
+ local x = nn.Identity()()
+ local prevRnnState = nn.Identity()()
+ local prevH1, prevCell = prevRnnState:split(2)
+ local prevH = prevH1
+
+ local cellOut = nn.CAddTable()({
+ nn.CMulTable()({x, prevH}),
+ nn.CMulTable()({prevH, prevCell})})
+ local module = nn.gModule({x, prevRnnState}, {cellOut})
+
+ local c = torch.randn(h:size())
+ local prevRnnState = {h, c}
+ local input = {xInput, prevRnnState}
+ local output = module:forward(input)
+
+ local gradOutput = torch.randn(h:size())
+ local gradInput = module:backward(input, gradOutput)
+
+ local gradX, gradPrevState = unpack(gradInput)
+ local gradPrevH, gradPrevCell = unpack(gradPrevState)
+ assert(type(gradPrevH) == type(h), "wrong gradPrevH type")
+
+ tester:eq(type(gradPrevH), type(h), "wrong gradPrevH type")
+ tester:eq(gradPrevH:size(), h:size(), "wrong gradPrevH size")
+ checkGradients(module, input)
end
function test.test_tabularInput()
- local in1 = nn.SplitTable(1)()
- local out1 = nn.CAddTable()(in1)
- local module = nn.gModule({in1}, {out1})
+ local in1 = nn.SplitTable(1)()
+ local out1 = nn.CAddTable()(in1)
+ local module = nn.gModule({in1}, {out1})
- local input = torch.randn(2, 3)
- checkGradients(module, input)
+ local input = torch.randn(2, 3)
+ checkGradients(module, input)
end
function test.test_extraTable()
- local in1 = nn.Identity()()
- local out1 = nn.Identity()(in1)
- local module = nn.gModule({in1}, {out1})
+ local in1 = nn.Identity()()
+ local out1 = nn.Identity()(in1)
+ local module = nn.gModule({in1}, {out1})
- local input = torch.Tensor({123})
- tester:eq(module:forward(input), input, "simple output")
- tester:eq(module:forward({input}), {input}, "tabular output")
+ local input = torch.Tensor({123})
+ tester:eq(module:forward(input), input, "simple output")
+ tester:eq(module:forward({input}), {input}, "tabular output")
end
function test.test_accGradParameters()
- local input = torch.randn(10)
+ local input = torch.randn(10)
- local in1 = nn.CMul(input:nElement())()
- local out1 = nn.Identity()(in1)
- local out2 = nn.Identity()(in1)
- local module = nn.gModule({in1}, {out1, out2})
- checkGradients(module, input)
+ local in1 = nn.CMul(input:nElement())()
+ local out1 = nn.Identity()(in1)
+ local out2 = nn.Identity()(in1)
+ local module = nn.gModule({in1}, {out1, out2})
+ checkGradients(module, input)
end
function test.test_example1()
- local x1 = nn.Linear(20,10)()
- local mout = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(x1))))
- local mlp = nn.gModule({x1},{mout})
+ local x1 = nn.Linear(20,10)()
+ local mout = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(x1))))
+ local mlp = nn.gModule({x1},{mout})
- local x = torch.rand(20)
- checkGradients(mlp, x)
+ local x = torch.rand(20)
+ checkGradients(mlp, x)
end
function test.test_example2()
- local x1=nn.Linear(20,20)()
- local x2=nn.Linear(10,10)()
- local m0=nn.Linear(20,1)(nn.Tanh()(x1))
- local m1=nn.Linear(10,1)(nn.Tanh()(x2))
- local madd=nn.CAddTable()({m0,m1})
- local m2=nn.Sigmoid()(madd)
- local m3=nn.Tanh()(madd)
- local gmod = nn.gModule({x1,x2},{m2,m3})
-
- local x = torch.rand(20)
- local y = torch.rand(10)
- checkGradients(gmod, {x, y})
+ local x1=nn.Linear(20,20)()
+ local x2=nn.Linear(10,10)()
+ local m0=nn.Linear(20,1)(nn.Tanh()(x1))
+ local m1=nn.Linear(10,1)(nn.Tanh()(x2))
+ local madd=nn.CAddTable()({m0,m1})
+ local m2=nn.Sigmoid()(madd)
+ local m3=nn.Tanh()(madd)
+ local gmod = nn.gModule({x1,x2},{m2,m3})
+
+ local x = torch.rand(20)
+ local y = torch.rand(10)
+ checkGradients(gmod, {x, y})
end
function test.test_example3()
- local m = nn.Sequential()
- m:add(nn.SplitTable(1))
- m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30)))
- local input = nn.Identity()()
- local input1,input2 = m(input):split(2)
- local m3 = nn.JoinTable(1)({input1,input2})
- local g = nn.gModule({input},{m3})
-
- local indata = torch.rand(2,10)
- checkGradients(g, indata)
+ local m = nn.Sequential()
+ m:add(nn.SplitTable(1))
+ m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30)))
+ local input = nn.Identity()()
+ local input1,input2 = m(input):split(2)
+ local m3 = nn.JoinTable(1)({input1,input2})
+ local g = nn.gModule({input},{m3})
+
+ local indata = torch.rand(2,10)
+ checkGradients(g, indata)
end
function test.test_example4()
- local input = nn.Identity()()
- local L1 = nn.Tanh()(nn.Linear(1,2)(input))
- local L2 = nn.Tanh()(nn.Linear(3,6)(nn.JoinTable(1)({input,L1})))
- local L3 = nn.Tanh()(nn.Linear(8,16)(nn.JoinTable(1)({L1,L2})))
- local g = nn.gModule({input},{L3})
-
- local indata = torch.rand(1)
- checkGradients(g, indata)
+ local input = nn.Identity()()
+ local L1 = nn.Tanh()(nn.Linear(1,2)(input))
+ local L2 = nn.Tanh()(nn.Linear(3,6)(nn.JoinTable(1)({input,L1})))
+ local L3 = nn.Tanh()(nn.Linear(8,16)(nn.JoinTable(1)({L1,L2})))
+ local g = nn.gModule({input},{L3})
+
+ local indata = torch.rand(1)
+ checkGradients(g, indata)
end
function test.test_type()
- local in1 = nn.Linear(20,10)()
- local out1 = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(in1))))
- local module = nn.gModule({in1}, {out1})
- local input = torch.rand(20)
- local output = module:forward(input)
- module:backward(input, output)
- tester:eq(torch.typename(output), "torch.DoubleTensor")
- tester:eq(torch.typename(module.output), "torch.DoubleTensor")
- tester:eq(torch.typename(module.gradInput), "torch.DoubleTensor")
- tester:eq(torch.typename(module.innode.data.input[1]), "torch.DoubleTensor")
- tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor")
- tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor")
-
- module:float()
- local output = module:forward(input:float())
- tester:eq(torch.typename(output), "torch.FloatTensor")
- tester:eq(torch.typename(module.output), "torch.FloatTensor")
- tester:eq(torch.typename(module.gradInput), "torch.FloatTensor")
- tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor")
- tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor")
- tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor")
+ local in1 = nn.Linear(20,10)()
+ local out1 = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(in1))))
+ local module = nn.gModule({in1}, {out1})
+ local input = torch.rand(20)
+ local output = module:forward(input)
+ module:backward(input, output)
+ tester:eq(torch.typename(output), "torch.DoubleTensor")
+ tester:eq(torch.typename(module.output), "torch.DoubleTensor")
+ tester:eq(torch.typename(module.gradInput), "torch.DoubleTensor")
+ tester:eq(torch.typename(module.innode.data.input[1]), "torch.DoubleTensor")
+ tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor")
+ tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor")
+
+ module:float()
+ local output = module:forward(input:float())
+ tester:eq(torch.typename(output), "torch.FloatTensor")
+ tester:eq(torch.typename(module.output), "torch.FloatTensor")
+ tester:eq(torch.typename(module.gradInput), "torch.FloatTensor")
+ tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor")
+ tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor")
+ tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor")
end
function test.test_nestedGradInput()
- local x = nn.Identity()()
- local h1 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Tanh())
- local h2 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Identity())
- local out = nn.CAddTable()({h1(x), h2(x)})
+ local x = nn.Identity()()
+ local h1 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Tanh())
+ local h2 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Identity())
+ local out = nn.CAddTable()({h1(x), h2(x)})
- local model = nn.gModule({x}, {out})
+ local model = nn.gModule({x}, {out})
- local input = {}
- input[1] = torch.randn(3, 3)
- input[2] = torch.randn(3, 3)
- input[3] = torch.randn(3, 3)
+ local input = {}
+ input[1] = torch.randn(3, 3)
+ input[2] = torch.randn(3, 3)
+ input[3] = torch.randn(3, 3)
- checkGradients(model, input)
+ checkGradients(model, input)
- local input = {}
- input[1] = torch.randn(2, 3)
- input[2] = torch.randn(2, 3)
- input[3] = torch.randn(2, 3)
+ local input = {}
+ input[1] = torch.randn(2, 3)
+ input[2] = torch.randn(2, 3)
+ input[3] = torch.randn(2, 3)
- checkGradients(model, input)
+ checkGradients(model, input)
end
function test.test_unusedInput()
- local x = nn.Identity()()
- local h = nn.Identity()()
- local h2 = nn.Identity()()
+ local x = nn.Identity()()
+ local h = nn.Identity()()
+ local h2 = nn.Identity()()
- local ok, result = pcall(nn.gModule, {x, h}, {x})
- assert(not ok, "the unused input should be detected")
+ local ok, result = pcall(nn.gModule, {x, h}, {x})
+ assert(not ok, "the unused input should be detected")
end
function test.test_unusedChild()
- local prevState = nn.Identity()()
- local h, cell = prevState:split(2)
+ local prevState = nn.Identity()()
+ local h, cell = prevState:split(2)
- local ok, result = pcall(nn.gModule, {prevState}, {h})
- assert(not ok, "the unused cell should be detected")
+ local ok, result = pcall(nn.gModule, {prevState}, {h})
+ assert(not ok, "the unused cell should be detected")
end
function test.test_nilInput()
- local ok, result = pcall(function() nn.Sigmoid()(nil) end)
- assert(not ok, "the nil input should be detected")
+ local ok, result = pcall(function() nn.Sigmoid()(nil) end)
+ assert(not ok, "the nil input should be detected")
end
function test.test_unusedNode()
- local in1 = nn.Identity()()
- local in2 = nn.Identity()()
- local middleResult = nn.Sigmoid()(in2)
- local out1 = nn.Sigmoid()(in1)
+ local in1 = nn.Identity()()
+ local in2 = nn.Identity()()
+ local middleResult = nn.Sigmoid()(in2)
+ local out1 = nn.Sigmoid()(in1)
- local ok, result = pcall(nn.gModule, {in1, in2}, {out1})
- assert(not ok, "the unused middleResult should be detected")
+ local ok, result = pcall(nn.gModule, {in1, in2}, {out1})
+ assert(not ok, "the unused middleResult should be detected")
end
function test.test_usageAfterSplit()
- local prevState = nn.Identity()()
- local h, cell = prevState:split(2)
- local nextState = nn.Identity()(prevState)
- local transformed = nn.Sigmoid()(cell)
-
- local model = nn.gModule({prevState}, {h, nextState, transformed})
- local nHidden = 10
- local input = {torch.randn(nHidden), torch.randn(nHidden)}
- checkGradients(model, input)
+ local prevState = nn.Identity()()
+ local h, cell = prevState:split(2)
+ local nextState = nn.Identity()(prevState)
+ local transformed = nn.Sigmoid()(cell)
+
+ local model = nn.gModule({prevState}, {h, nextState, transformed})
+ local nHidden = 10
+ local input = {torch.randn(nHidden), torch.randn(nHidden)}
+ checkGradients(model, input)
end
function test.test_resizeNestedAs()
- local in1 = nn.Identity()()
- local out1 = nn.Identity()(in1)
- local out2 = nn.Identity()(in1)
-
- local net = nn.gModule({in1}, {out1, out2})
- local input = {torch.randn(10), {torch.randn(3), torch.randn(4)}}
- net:forward(input)
- net:backward(input, net.output)
- checkGradients(net, input)
-
- input = {torch.randn(10), {torch.randn(3), torch.randn(4), torch.randn(5)}}
- net:forward(input)
- net:backward(input, net.output)
- checkGradients(net, input)
-
- input = {torch.randn(10), {torch.randn(3), torch.randn(4)}}
- net:forward(input)
- local gradInput = net:backward(input, net.output)
- tester:eq(#(gradInput[2]), 2, "gradInput[2] size")
- checkGradients(net, input)
+ local in1 = nn.Identity()()
+ local out1 = nn.Identity()(in1)
+ local out2 = nn.Identity()(in1)
+
+ local net = nn.gModule({in1}, {out1, out2})
+ local input = {torch.randn(10), {torch.randn(3), torch.randn(4)}}
+ net:forward(input)
+ net:backward(input, net.output)
+ checkGradients(net, input)
+
+ input = {torch.randn(10), {torch.randn(3), torch.randn(4), torch.randn(5)}}
+ net:forward(input)
+ net:backward(input, net.output)
+ checkGradients(net, input)
+
+ input = {torch.randn(10), {torch.randn(3), torch.randn(4)}}
+ net:forward(input)
+ local gradInput = net:backward(input, net.output)
+ tester:eq(#(gradInput[2]), 2, "gradInput[2] size")
+ checkGradients(net, input)
end
-
-function test.test_annotateGraph()
- local input = nn.Identity()():annotate(
- {name = 'Input', description = 'DescA',
- graphAttributes = {color = 'red'}})
-
- local hidden_a = nn.Linear(10, 10)(input):annotate(
- {name = 'Hidden A', description = 'DescB',
- graphAttributes = {color = 'blue', fontcolor='green', tooltip = 'I am green'}})
- local hidden_b = nn.Sigmoid()(hidden_a)
- local output = nn.Linear(10, 10)(hidden_b)
- local net = nn.gModule({input}, {output})
-
- tester:assert(hidden_a:label():match('DescB'))
- local fg_tmpfile = os.tmpname()
- local bg_tmpfile = os.tmpname()
- graph.dot(net.fg, 'Test', fg_tmpfile)
- graph.dot(net.fg, 'Test BG', bg_tmpfile)
-
- local function checkDotFile(tmpfile)
- local dotcontent = io.open(tmpfile .. '.dot', 'r'):read("*all")
- tester:assert(
- dotcontent:match('%[label=%"Input.*DescA.*%" color=red%]'))
- tester:assert(
- dotcontent:match(
- '%[label=%"Hidden A.*DescB.*%".*fontcolor=green.*%]'))
- tester:assert(
- dotcontent:match('%[label=%".*DescB.*%".*color=blue.*%]'))
- tester:assert(
- dotcontent:match(
- '%[label=%".*DescB.*%".*tooltip=%".*test_nngraph.lua.*%".*%]'))
- end
-
- checkDotFile(fg_tmpfile)
- checkDotFile(bg_tmpfile)
-end
-
-
tester:add(test):run()