author     koray kavukcuoglu <koray@kavukcuoglu.org>   2015-06-18 20:44:22 +0300
committer  koray kavukcuoglu <koray@kavukcuoglu.org>   2015-06-18 20:44:22 +0300
commit     045c17eab6ec10e6d90b19216c52058afb731656 (patch)
tree       6122286066a45a4b19eb37e0f8d78eb67173b373
parent     eb65b3602cdf3214ff671d5422521ad414f1c3b0 (diff)
big cleanup, move graphviz related bits to graphviz.lua files (cleanup)
-rw-r--r--   doc/example6.lua                                                     |  31
-rw-r--r--   graphviz.lua                                                         |  82
-rw-r--r--   init.lua                                                             |   4
-rw-r--r--   node.lua                                                             | 146
-rw-r--r--   rocks/nngraph-scm-1.rockspec (renamed from nngraph-scm-1.rockspec)   |   0
-rw-r--r--   test/test_nngraph.lua                                                | 502
6 files changed, 366 insertions, 399 deletions
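The graphviz-related bits being moved are the ones behind rendering a gModule with graph.dot: Node:label() (now in graphviz.lua) turns each node's data and annotations into the label text that ends up in the .dot output. A minimal sketch of that workflow, matching the annotate/graph.dot usage exercised in the tests touched below (the tiny network and the 'demo' file prefix are illustrative, not part of this commit):

require 'nngraph'

-- annotations become graphviz labels/attributes via Node:label()
local input = nn.Identity()():annotate{
   name = 'Input', description = 'network input',
   graphAttributes = {color = 'red'}}
local output = nn.Linear(10, 10)(input):annotate{name = 'Projection'}
local net = nn.gModule({input}, {output})

-- render the forward graph; writes demo.dot (and an image if graphviz is installed)
graph.dot(net.fg, 'Demo', 'demo')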
diff --git a/doc/example6.lua b/doc/example6.lua
new file mode 100644
index 0000000..39b440f
--- /dev/null
+++ b/doc/example6.lua
@@ -0,0 +1,31 @@
+require 'nngraph'
+
+-- generate SVG of the graph with the problem node highlighted
+-- and hover over the nodes in svg to see the filename:line_number info
+-- nodes will be annotated with local variable names even if debug mode is not enabled.
+nngraph.setDebug(true)
+
+local function get_net(from, to)
+   local from = from or 10
+   local to = to or 10
+   local input_x = nn.Identity()()
+   local linear_module = nn.Linear(from, to)(input_x)
+
+   -- Annotate nodes with local variable names
+   nngraph.annotateNodes()
+   return nn.gModule({input_x},{linear_module})
+end
+
+local net = get_net(10,10)
+
+-- if you give a name to the net, it will use that name to produce the
+-- svg in case of error, if not, it will come up with a name
+-- that is derived from number of inputs and outputs to the graph
+net.name = 'my_bad_linear_net'
+
+-- prepare an input that is of the wrong size to force an error
+local input = torch.rand(11)
+pcall(function() net:updateOutput(input) end)
+-- it should have produced an error and spit out a graph
+-- just run Safari to display the svg
+os.execute('open -a Safari my_bad_linear_net.svg')
diff --git a/graphviz.lua b/graphviz.lua
new file mode 100644
index 0000000..092877a
--- /dev/null
+++ b/graphviz.lua
@@ -0,0 +1,82 @@
+-- handy functions
+local utils = paths.dofile('utils.lua')
+local istensor = utils.istensor
+local istable = utils.istable
+local istorchclass = utils.istorchclass
+
+local function getNanFlag(data)
+   if data:nElement() == 0 then
+      return ''
+   end
+   local isNan = (data:ne(data):sum() > 0)
+   if isNan then
+      return 'NaN'
+   end
+   if data:max() == math.huge then
+      return 'inf'
+   end
+   if data:min() == -math.huge then
+      return '-inf'
+   end
+   return ''
+end
+local function getstr(data)
+   if not data then return '' end
+   if istensor(data) then
+      local nanFlag = getNanFlag(data)
+      local tensorType = 'Tensor'
+      if data:type() ~= torch.Tensor():type() then
+         tensorType = data:type()
+      end
+      return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag
+   elseif istable(data) then
+      local tstr = {}
+      for i,v in ipairs(data) do
+         table.insert(tstr, getstr(v))
+      end
+      return '{' .. table.concat(tstr,',') .. '}'
+   else
+      return tostring(data):gsub('\n','\\l')
+   end
+end
+local function getmapindexstr(mapindex)
+   local tstr = {}
+   for i,data in ipairs(mapindex) do
+      local inputId = 'Node' .. (data.forwardNodeId or '')
+      table.insert(tstr, inputId)
+   end
+   return '{' .. table.concat(tstr,',') .. '}'
+end
+
+local Node = torch.getmetatable('nngraph.Node')
+
+
+--[[
+Returns a textual representation of the Node that can be used by graphviz library visualization.
+]]
+function Node:label()
+
+   local lbl = {}
+
+   for k,v in pairs(self.data) do
+      local vstr = ''
+      if k == 'mapindex' then
+         if #v > 1 then
+            vstr = getmapindexstr(v)
+            table.insert(lbl, k .. ' = ' .. vstr)
+         end
+      elseif k == 'forwardNodeId' or k == 'annotations' then
+         -- the forwardNodeId is not displayed in the label.
+      else
+         vstr = getstr(v)
+         table.insert(lbl, k .. ' = ' .. vstr)
+      end
+   end
+
+   local desc = ''
+   if self.data.annotations.description then
+      desc = 'desc = ' .. self.data.annotations.description .. '\\n'
+   end
+   return desc .. table.concat(lbl,"\\l")
+end
+
diff --git a/init.lua b/init.lua
--- a/init.lua
+++ b/init.lua
@@ -6,6 +6,7 @@ nngraph = {}
 torch.include('nngraph','node.lua')
 torch.include('nngraph','gmodule.lua')
+torch.include('nngraph','graphviz.lua')
 torch.include('nngraph','graphinspecting.lua')
 torch.include('nngraph','ModuleFromCriterion.lua')
 
@@ -36,7 +37,7 @@ function Module:__call__(...)
 
    for i,dnode in ipairs(input) do
       if torch.typename(dnode) ~= 'nngraph.Node' then
-         error('what is this in the input? ' .. tostring(dnode))
+         error('Expected nngraph.Node type, what is this in the input? ' .. tostring(dnode))
       end
       mnode:add(dnode,true)
    end
@@ -44,6 +45,7 @@ function Module:__call__(...)
    return mnode
 end
 
+-- Modify the __call function to hack into nn.Criterion
 local Criterion = torch.getmetatable('nn.Criterion')
 function Criterion:__call__(...)
    return nn.ModuleFromCriterion(self)(...)
diff --git a/node.lua b/node.lua
--- a/node.lua
+++ b/node.lua
@@ -1,30 +1,23 @@
-local utils = paths.dofile('utils.lua')
-local istensor = utils.istensor
-local istable = utils.istable
-local istorchclass = utils.istorchclass
-require 'debug'
-
-
-local nnNode,parent = torch.class('nngraph.Node','graph.Node')
-
+--[[
+This file implements the nngraph.Node. In addition to graph.Node this class
+provides some additional functionality for handling neural networks in a graph
+]]
+local nnNode,parent = torch.class('nngraph.Node','graph.AnnotatedNode')
+
+
+--[[
+nngraph.Node
+Args:
+* `data` - the same as graph.Node(data). Any object type that will be stored as data
+in the graph node.
+]]
 function nnNode:__init(data)
-   parent.__init(self,data)
-   self.data.annotations = self.data.annotations or {}
-   self.data.mapindex = self.data.mapindex or {}
-   if not self.data.annotations._debugLabel then
-      self:_makeDebugLabel(debug.getinfo(6, 'Sl'))
-   end
-end
-
-
---[[ Build a string label which will be used a tooltip when
-   making a graph.]]
-function nnNode:_makeDebugLabel(dinfo)
-   if dinfo then
-      self.data.annotations._debugLabel = string.format('[%s]:%d',
-         dinfo.short_src, dinfo.currentline, dinfo.name)
-   end
+   -- level 7 corresponds to level with the nngraph usage of nnNode's
+   -- inside Module:__call() syntax
+   parent.__init(self,data, 7)
+   -- decorate the data with additional info to keep track of order of connected nodes
+   self.data.mapindex = data.mapindex or {}
 end
 
@@ -61,106 +54,3 @@ function nnNode:split(noutput)
    return unpack(selectnodes)
 end
-
-function nnNode:annotate(annotations)
-   for k, v in pairs(annotations) do
-      self.data.annotations[k] = v
-   end
-
-   return self
-end
-
-
-function nnNode:graphNodeName()
-   if self.data.annotations.name then
-      return self.data.annotations.name .. ' (' .. self.id .. ')'
-   else
-      return 'Node' .. self.id
-   end
-end
-
-
-function nnNode:graphNodeAttributes()
-   self.data.annotations.graphAttributes =
-      self.data.annotations.graphAttributes or {}
-   if not self.data.annotations.graphAttributes.tooltip then
-      self.data.annotations.graphAttributes.tooltip =
-         self.data.annotations._debugLabel
-   end
-
-   return self.data.annotations.graphAttributes
-end
-
-
-local function getNanFlag(data)
-   if data:nElement() == 0 then
-      return ''
-   end
-   local isNan = (data:ne(data):sum() > 0)
-   if isNan then
-      return 'NaN'
-   end
-   if data:max() == math.huge then
-      return 'inf'
-   end
-   if data:min() == -math.huge then
-      return '-inf'
-   end
-   return ''
-end
-
-function nnNode:label()
-
-   local lbl = {}
-
-   local function getstr(data)
-      if not data then return '' end
-      if istensor(data) then
-         local nanFlag = getNanFlag(data)
-         local tensorType = 'Tensor'
-         if data:type() ~= torch.Tensor():type() then
-            tensorType = data:type()
-         end
-         return tensorType .. '[' .. table.concat(data:size():totable(),'x') .. ']' .. nanFlag
-      elseif istable(data) then
-         local tstr = {}
-         for i,v in ipairs(data) do
-            table.insert(tstr, getstr(v))
-         end
-         return '{' .. table.concat(tstr,',') .. '}'
-      else
-         return tostring(data):gsub('\n','\\l')
-      end
-   end
-   local function getmapindexstr(mapindex)
-      local tstr = {}
-      for i,data in ipairs(mapindex) do
-         local inputId = 'Node' .. (data.forwardNodeId or '')
-         table.insert(tstr, inputId)
-      end
-      return '{' .. table.concat(tstr,',') .. '}'
-   end
-
-   for k,v in pairs(self.data) do
-      local vstr = ''
-      if k== 'mapindex' then
-         if #v > 1 then
-            vstr = getmapindexstr(v)
-            table.insert(lbl, k .. ' = ' .. vstr)
-         end
-      elseif k== 'forwardNodeId' or k== 'annotations' then
-         -- the forwardNodeId is not displayed in the label.
-      else
-         vstr = getstr(v)
-         table.insert(lbl, k .. ' = ' .. vstr)
-      end
-   end
-
-   local desc
-   if self.data.annotations.description then
-      desc = 'desc = ' .. self.data.annotations.description .. '\\n'
-   else
-      desc = ''
-   end
-   return desc .. table.concat(lbl,"\\l")
-end
diff --git a/nngraph-scm-1.rockspec b/rocks/nngraph-scm-1.rockspec
index 4963ecd..4963ecd 100644
--- a/nngraph-scm-1.rockspec
+++ b/rocks/nngraph-scm-1.rockspec
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index ac47be2..8dac03a 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -5,358 +5,320 @@ local test = {}
 local tester = totem.Tester()
 
 local function checkGradients(...)
-   totem.nn.checkGradients(tester, ...)
+   totem.nn.checkGradients(tester, ...)
end function test.test_oneOutput() - local in1 = nn.Identity()() - local out1 = nn.Identity()(in1) - local module = nn.gModule({in1}, {out1}) - - local input = torch.Tensor({1}) - module:forward(input) - tester:eq(module.output, torch.Tensor{1}, "output") - local gradInput = module:backward(input, torch.Tensor({-123})) - tester:eq(gradInput, torch.Tensor{-123}, "gradInput") - - local input2 = torch.Tensor({2}) - module:forward(input2) - tester:eq(module.output, torch.Tensor{2}, "output for input2") - gradInput = module:backward(input2, torch.Tensor({-2})) - tester:eq(gradInput, torch.Tensor{-2}, "expecting a recomputed gradInput") + local in1 = nn.Identity()() + local out1 = nn.Identity()(in1) + local module = nn.gModule({in1}, {out1}) + + local input = torch.Tensor({1}) + module:forward(input) + tester:eq(module.output, torch.Tensor{1}, "output") + local gradInput = module:backward(input, torch.Tensor({-123})) + tester:eq(gradInput, torch.Tensor{-123}, "gradInput") + + local input2 = torch.Tensor({2}) + module:forward(input2) + tester:eq(module.output, torch.Tensor{2}, "output for input2") + gradInput = module:backward(input2, torch.Tensor({-2})) + tester:eq(gradInput, torch.Tensor{-2}, "expecting a recomputed gradInput") end function test.test_twoOutputs() - local in1 = nn.Identity()() - local out1 = nn.Identity()(in1) - local out2 = nn.Identity()(in1) - local module = nn.gModule({in1}, {out1, out2}) - - local input = torch.Tensor({1}) - module:forward(input) - local gradInput = module:backward(input, {torch.Tensor({-2}), torch.Tensor({-3})}) - tester:eq(gradInput, torch.Tensor{-5}, "gradInput of a fork") - checkGradients(module, input) + local in1 = nn.Identity()() + local out1 = nn.Identity()(in1) + local out2 = nn.Identity()(in1) + local module = nn.gModule({in1}, {out1, out2}) + + local input = torch.Tensor({1}) + module:forward(input) + local gradInput = module:backward(input, {torch.Tensor({-2}), torch.Tensor({-3})}) + tester:eq(gradInput, torch.Tensor{-5}, "gradInput of a fork") + checkGradients(module, input) end function test.test_twoGradOutputs() - local in1 = nn.Sigmoid()() - local splitTable = nn.SplitTable(1)({in1}) - local out1, out2 = splitTable:split(2) - local module = nn.gModule({in1}, {out1, out2}) - - local input = torch.randn(2, 3) - local output = module:forward(input) - assert(#output == 2, "wrong number of outputs") - module:backward(input, {torch.randn(3), torch.randn(3)}) - checkGradients(module, input) + local in1 = nn.Sigmoid()() + local splitTable = nn.SplitTable(1)({in1}) + local out1, out2 = splitTable:split(2) + local module = nn.gModule({in1}, {out1, out2}) + + local input = torch.randn(2, 3) + local output = module:forward(input) + assert(#output == 2, "wrong number of outputs") + module:backward(input, {torch.randn(3), torch.randn(3)}) + checkGradients(module, input) end function test.test_twoInputs() - local in1 = nn.Identity()() - local in2 = nn.Identity()() - local prevH, prevCell = in2:split(2) - - local out1 = nn.CMulTable()({in1, prevH, prevCell}) - local module = nn.gModule({in1, in2}, {out1}) - - local input = {torch.randn(3), {torch.randn(3), torch.randn(3)}} - module:forward(input) - local gradInput = module:backward(input, torch.randn(3)) - assert(#gradInput == 2, "wrong number of gradInputs") - assert(type(gradInput[2]) == "table", "wrong gradInput[2] type") - checkGradients(module, input) + local in1 = nn.Identity()() + local in2 = nn.Identity()() + local prevH, prevCell = in2:split(2) + + local out1 = nn.CMulTable()({in1, prevH, prevCell}) 
+ local module = nn.gModule({in1, in2}, {out1}) + + local input = {torch.randn(3), {torch.randn(3), torch.randn(3)}} + module:forward(input) + local gradInput = module:backward(input, torch.randn(3)) + assert(#gradInput == 2, "wrong number of gradInputs") + assert(type(gradInput[2]) == "table", "wrong gradInput[2] type") + checkGradients(module, input) end function test.test_twoInputs2() - local in1 = nn.Sigmoid()() - local in2 = nn.Sigmoid()() - local module = nn.gModule({in1, in2}, {in1, in2, nn.Sigmoid()(in1)}) - - local input = {torch.randn(3), torch.randn(3)} - module:forward(input) - local gradInput = module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)}) - checkGradients(module, input) + local in1 = nn.Sigmoid()() + local in2 = nn.Sigmoid()() + local module = nn.gModule({in1, in2}, {in1, in2, nn.Sigmoid()(in1)}) + + local input = {torch.randn(3), torch.randn(3)} + module:forward(input) + local gradInput = module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)}) + checkGradients(module, input) end function test.test_splitDebugLabels() - local node = nn.Identity()() - node.data.annotations._debugLabel = "node" - local node1, node2 = node:split(2) - assert(node1.data.annotations._debugLabel == "node-1") - assert(node2.data.annotations._debugLabel == "node-2") + local node = nn.Identity()() + node.data.annotations._debugLabel = "node" + local node1, node2 = node:split(2) + assert(node1.data.annotations._debugLabel == "node-1") + assert(node2.data.annotations._debugLabel == "node-2") end function test.test_identity() - local in1 = nn.Identity()() - local in2 = nn.Identity()() - local module = nn.gModule({in1, in2}, {in1, in2, nn.Identity()(in1)}) - - local input = {torch.randn(3), torch.randn(3)} - module:forward(input) - module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)}) - checkGradients(module, input) + local in1 = nn.Identity()() + local in2 = nn.Identity()() + local module = nn.gModule({in1, in2}, {in1, in2, nn.Identity()(in1)}) + + local input = {torch.randn(3), torch.randn(3)} + module:forward(input) + module:backward(input, {torch.randn(3), torch.randn(3), torch.randn(3)}) + checkGradients(module, input) end function test.test_gradInputType() - local xInput = torch.randn(3) - local h = torch.randn(3) - - local x = nn.Identity()() - local prevRnnState = nn.Identity()() - local prevH1, prevCell = prevRnnState:split(2) - local prevH = prevH1 - - local cellOut = nn.CAddTable()({ - nn.CMulTable()({x, prevH}), - nn.CMulTable()({prevH, prevCell})}) - local module = nn.gModule({x, prevRnnState}, {cellOut}) - - local c = torch.randn(h:size()) - local prevRnnState = {h, c} - local input = {xInput, prevRnnState} - local output = module:forward(input) - - local gradOutput = torch.randn(h:size()) - local gradInput = module:backward(input, gradOutput) - - local gradX, gradPrevState = unpack(gradInput) - local gradPrevH, gradPrevCell = unpack(gradPrevState) - assert(type(gradPrevH) == type(h), "wrong gradPrevH type") - - tester:eq(type(gradPrevH), type(h), "wrong gradPrevH type") - tester:eq(gradPrevH:size(), h:size(), "wrong gradPrevH size") - checkGradients(module, input) + local xInput = torch.randn(3) + local h = torch.randn(3) + + local x = nn.Identity()() + local prevRnnState = nn.Identity()() + local prevH1, prevCell = prevRnnState:split(2) + local prevH = prevH1 + + local cellOut = nn.CAddTable()({ + nn.CMulTable()({x, prevH}), + nn.CMulTable()({prevH, prevCell})}) + local module = nn.gModule({x, prevRnnState}, {cellOut}) + + local 
c = torch.randn(h:size()) + local prevRnnState = {h, c} + local input = {xInput, prevRnnState} + local output = module:forward(input) + + local gradOutput = torch.randn(h:size()) + local gradInput = module:backward(input, gradOutput) + + local gradX, gradPrevState = unpack(gradInput) + local gradPrevH, gradPrevCell = unpack(gradPrevState) + assert(type(gradPrevH) == type(h), "wrong gradPrevH type") + + tester:eq(type(gradPrevH), type(h), "wrong gradPrevH type") + tester:eq(gradPrevH:size(), h:size(), "wrong gradPrevH size") + checkGradients(module, input) end function test.test_tabularInput() - local in1 = nn.SplitTable(1)() - local out1 = nn.CAddTable()(in1) - local module = nn.gModule({in1}, {out1}) + local in1 = nn.SplitTable(1)() + local out1 = nn.CAddTable()(in1) + local module = nn.gModule({in1}, {out1}) - local input = torch.randn(2, 3) - checkGradients(module, input) + local input = torch.randn(2, 3) + checkGradients(module, input) end function test.test_extraTable() - local in1 = nn.Identity()() - local out1 = nn.Identity()(in1) - local module = nn.gModule({in1}, {out1}) + local in1 = nn.Identity()() + local out1 = nn.Identity()(in1) + local module = nn.gModule({in1}, {out1}) - local input = torch.Tensor({123}) - tester:eq(module:forward(input), input, "simple output") - tester:eq(module:forward({input}), {input}, "tabular output") + local input = torch.Tensor({123}) + tester:eq(module:forward(input), input, "simple output") + tester:eq(module:forward({input}), {input}, "tabular output") end function test.test_accGradParameters() - local input = torch.randn(10) + local input = torch.randn(10) - local in1 = nn.CMul(input:nElement())() - local out1 = nn.Identity()(in1) - local out2 = nn.Identity()(in1) - local module = nn.gModule({in1}, {out1, out2}) - checkGradients(module, input) + local in1 = nn.CMul(input:nElement())() + local out1 = nn.Identity()(in1) + local out2 = nn.Identity()(in1) + local module = nn.gModule({in1}, {out1, out2}) + checkGradients(module, input) end function test.test_example1() - local x1 = nn.Linear(20,10)() - local mout = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(x1)))) - local mlp = nn.gModule({x1},{mout}) + local x1 = nn.Linear(20,10)() + local mout = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(x1)))) + local mlp = nn.gModule({x1},{mout}) - local x = torch.rand(20) - checkGradients(mlp, x) + local x = torch.rand(20) + checkGradients(mlp, x) end function test.test_example2() - local x1=nn.Linear(20,20)() - local x2=nn.Linear(10,10)() - local m0=nn.Linear(20,1)(nn.Tanh()(x1)) - local m1=nn.Linear(10,1)(nn.Tanh()(x2)) - local madd=nn.CAddTable()({m0,m1}) - local m2=nn.Sigmoid()(madd) - local m3=nn.Tanh()(madd) - local gmod = nn.gModule({x1,x2},{m2,m3}) - - local x = torch.rand(20) - local y = torch.rand(10) - checkGradients(gmod, {x, y}) + local x1=nn.Linear(20,20)() + local x2=nn.Linear(10,10)() + local m0=nn.Linear(20,1)(nn.Tanh()(x1)) + local m1=nn.Linear(10,1)(nn.Tanh()(x2)) + local madd=nn.CAddTable()({m0,m1}) + local m2=nn.Sigmoid()(madd) + local m3=nn.Tanh()(madd) + local gmod = nn.gModule({x1,x2},{m2,m3}) + + local x = torch.rand(20) + local y = torch.rand(10) + checkGradients(gmod, {x, y}) end function test.test_example3() - local m = nn.Sequential() - m:add(nn.SplitTable(1)) - m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30))) - local input = nn.Identity()() - local input1,input2 = m(input):split(2) - local m3 = nn.JoinTable(1)({input1,input2}) - local g = nn.gModule({input},{m3}) - - local indata = 
torch.rand(2,10) - checkGradients(g, indata) + local m = nn.Sequential() + m:add(nn.SplitTable(1)) + m:add(nn.ParallelTable():add(nn.Linear(10,20)):add(nn.Linear(10,30))) + local input = nn.Identity()() + local input1,input2 = m(input):split(2) + local m3 = nn.JoinTable(1)({input1,input2}) + local g = nn.gModule({input},{m3}) + + local indata = torch.rand(2,10) + checkGradients(g, indata) end function test.test_example4() - local input = nn.Identity()() - local L1 = nn.Tanh()(nn.Linear(1,2)(input)) - local L2 = nn.Tanh()(nn.Linear(3,6)(nn.JoinTable(1)({input,L1}))) - local L3 = nn.Tanh()(nn.Linear(8,16)(nn.JoinTable(1)({L1,L2}))) - local g = nn.gModule({input},{L3}) - - local indata = torch.rand(1) - checkGradients(g, indata) + local input = nn.Identity()() + local L1 = nn.Tanh()(nn.Linear(1,2)(input)) + local L2 = nn.Tanh()(nn.Linear(3,6)(nn.JoinTable(1)({input,L1}))) + local L3 = nn.Tanh()(nn.Linear(8,16)(nn.JoinTable(1)({L1,L2}))) + local g = nn.gModule({input},{L3}) + + local indata = torch.rand(1) + checkGradients(g, indata) end function test.test_type() - local in1 = nn.Linear(20,10)() - local out1 = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(in1)))) - local module = nn.gModule({in1}, {out1}) - local input = torch.rand(20) - local output = module:forward(input) - module:backward(input, output) - tester:eq(torch.typename(output), "torch.DoubleTensor") - tester:eq(torch.typename(module.output), "torch.DoubleTensor") - tester:eq(torch.typename(module.gradInput), "torch.DoubleTensor") - tester:eq(torch.typename(module.innode.data.input[1]), "torch.DoubleTensor") - tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor") - tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor") - - module:float() - local output = module:forward(input:float()) - tester:eq(torch.typename(output), "torch.FloatTensor") - tester:eq(torch.typename(module.output), "torch.FloatTensor") - tester:eq(torch.typename(module.gradInput), "torch.FloatTensor") - tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor") - tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor") - tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor") + local in1 = nn.Linear(20,10)() + local out1 = nn.Linear(10,1)(nn.Tanh()(nn.Linear(10,10)(nn.Tanh()(in1)))) + local module = nn.gModule({in1}, {out1}) + local input = torch.rand(20) + local output = module:forward(input) + module:backward(input, output) + tester:eq(torch.typename(output), "torch.DoubleTensor") + tester:eq(torch.typename(module.output), "torch.DoubleTensor") + tester:eq(torch.typename(module.gradInput), "torch.DoubleTensor") + tester:eq(torch.typename(module.innode.data.input[1]), "torch.DoubleTensor") + tester:eq(torch.typename(module.outnode.data.input[1]), "torch.DoubleTensor") + tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.DoubleTensor") + + module:float() + local output = module:forward(input:float()) + tester:eq(torch.typename(output), "torch.FloatTensor") + tester:eq(torch.typename(module.output), "torch.FloatTensor") + tester:eq(torch.typename(module.gradInput), "torch.FloatTensor") + tester:eq(torch.typename(module.innode.data.input[1]), "torch.FloatTensor") + tester:eq(torch.typename(module.outnode.data.input[1]), "torch.FloatTensor") + tester:eq(torch.typename(module.forwardnodes[1].data.input[1]), "torch.FloatTensor") end function test.test_nestedGradInput() - local x = nn.Identity()() - local h1 = 
nn.Sequential():add(nn.JoinTable(2)):add(nn.Tanh()) - local h2 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Identity()) - local out = nn.CAddTable()({h1(x), h2(x)}) + local x = nn.Identity()() + local h1 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Tanh()) + local h2 = nn.Sequential():add(nn.JoinTable(2)):add(nn.Identity()) + local out = nn.CAddTable()({h1(x), h2(x)}) - local model = nn.gModule({x}, {out}) + local model = nn.gModule({x}, {out}) - local input = {} - input[1] = torch.randn(3, 3) - input[2] = torch.randn(3, 3) - input[3] = torch.randn(3, 3) + local input = {} + input[1] = torch.randn(3, 3) + input[2] = torch.randn(3, 3) + input[3] = torch.randn(3, 3) - checkGradients(model, input) + checkGradients(model, input) - local input = {} - input[1] = torch.randn(2, 3) - input[2] = torch.randn(2, 3) - input[3] = torch.randn(2, 3) + local input = {} + input[1] = torch.randn(2, 3) + input[2] = torch.randn(2, 3) + input[3] = torch.randn(2, 3) - checkGradients(model, input) + checkGradients(model, input) end function test.test_unusedInput() - local x = nn.Identity()() - local h = nn.Identity()() - local h2 = nn.Identity()() + local x = nn.Identity()() + local h = nn.Identity()() + local h2 = nn.Identity()() - local ok, result = pcall(nn.gModule, {x, h}, {x}) - assert(not ok, "the unused input should be detected") + local ok, result = pcall(nn.gModule, {x, h}, {x}) + assert(not ok, "the unused input should be detected") end function test.test_unusedChild() - local prevState = nn.Identity()() - local h, cell = prevState:split(2) + local prevState = nn.Identity()() + local h, cell = prevState:split(2) - local ok, result = pcall(nn.gModule, {prevState}, {h}) - assert(not ok, "the unused cell should be detected") + local ok, result = pcall(nn.gModule, {prevState}, {h}) + assert(not ok, "the unused cell should be detected") end function test.test_nilInput() - local ok, result = pcall(function() nn.Sigmoid()(nil) end) - assert(not ok, "the nil input should be detected") + local ok, result = pcall(function() nn.Sigmoid()(nil) end) + assert(not ok, "the nil input should be detected") end function test.test_unusedNode() - local in1 = nn.Identity()() - local in2 = nn.Identity()() - local middleResult = nn.Sigmoid()(in2) - local out1 = nn.Sigmoid()(in1) + local in1 = nn.Identity()() + local in2 = nn.Identity()() + local middleResult = nn.Sigmoid()(in2) + local out1 = nn.Sigmoid()(in1) - local ok, result = pcall(nn.gModule, {in1, in2}, {out1}) - assert(not ok, "the unused middleResult should be detected") + local ok, result = pcall(nn.gModule, {in1, in2}, {out1}) + assert(not ok, "the unused middleResult should be detected") end function test.test_usageAfterSplit() - local prevState = nn.Identity()() - local h, cell = prevState:split(2) - local nextState = nn.Identity()(prevState) - local transformed = nn.Sigmoid()(cell) - - local model = nn.gModule({prevState}, {h, nextState, transformed}) - local nHidden = 10 - local input = {torch.randn(nHidden), torch.randn(nHidden)} - checkGradients(model, input) + local prevState = nn.Identity()() + local h, cell = prevState:split(2) + local nextState = nn.Identity()(prevState) + local transformed = nn.Sigmoid()(cell) + + local model = nn.gModule({prevState}, {h, nextState, transformed}) + local nHidden = 10 + local input = {torch.randn(nHidden), torch.randn(nHidden)} + checkGradients(model, input) end function test.test_resizeNestedAs() - local in1 = nn.Identity()() - local out1 = nn.Identity()(in1) - local out2 = nn.Identity()(in1) - - local net = 
nn.gModule({in1}, {out1, out2}) - local input = {torch.randn(10), {torch.randn(3), torch.randn(4)}} - net:forward(input) - net:backward(input, net.output) - checkGradients(net, input) - - input = {torch.randn(10), {torch.randn(3), torch.randn(4), torch.randn(5)}} - net:forward(input) - net:backward(input, net.output) - checkGradients(net, input) - - input = {torch.randn(10), {torch.randn(3), torch.randn(4)}} - net:forward(input) - local gradInput = net:backward(input, net.output) - tester:eq(#(gradInput[2]), 2, "gradInput[2] size") - checkGradients(net, input) + local in1 = nn.Identity()() + local out1 = nn.Identity()(in1) + local out2 = nn.Identity()(in1) + + local net = nn.gModule({in1}, {out1, out2}) + local input = {torch.randn(10), {torch.randn(3), torch.randn(4)}} + net:forward(input) + net:backward(input, net.output) + checkGradients(net, input) + + input = {torch.randn(10), {torch.randn(3), torch.randn(4), torch.randn(5)}} + net:forward(input) + net:backward(input, net.output) + checkGradients(net, input) + + input = {torch.randn(10), {torch.randn(3), torch.randn(4)}} + net:forward(input) + local gradInput = net:backward(input, net.output) + tester:eq(#(gradInput[2]), 2, "gradInput[2] size") + checkGradients(net, input) end - -function test.test_annotateGraph() - local input = nn.Identity()():annotate( - {name = 'Input', description = 'DescA', - graphAttributes = {color = 'red'}}) - - local hidden_a = nn.Linear(10, 10)(input):annotate( - {name = 'Hidden A', description = 'DescB', - graphAttributes = {color = 'blue', fontcolor='green', tooltip = 'I am green'}}) - local hidden_b = nn.Sigmoid()(hidden_a) - local output = nn.Linear(10, 10)(hidden_b) - local net = nn.gModule({input}, {output}) - - tester:assert(hidden_a:label():match('DescB')) - local fg_tmpfile = os.tmpname() - local bg_tmpfile = os.tmpname() - graph.dot(net.fg, 'Test', fg_tmpfile) - graph.dot(net.fg, 'Test BG', bg_tmpfile) - - local function checkDotFile(tmpfile) - local dotcontent = io.open(tmpfile .. '.dot', 'r'):read("*all") - tester:assert( - dotcontent:match('%[label=%"Input.*DescA.*%" color=red%]')) - tester:assert( - dotcontent:match( - '%[label=%"Hidden A.*DescB.*%".*fontcolor=green.*%]')) - tester:assert( - dotcontent:match('%[label=%".*DescB.*%".*color=blue.*%]')) - tester:assert( - dotcontent:match( - '%[label=%".*DescB.*%".*tooltip=%".*test_nngraph.lua.*%".*%]')) - end - - checkDotFile(fg_tmpfile) - checkDotFile(bg_tmpfile) -end - - tester:add(test):run() |
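For reference, the reworked test file stays self-running: it builds a totem.Tester, registers the `test` table, and calls :run() at the end. Loading it from a Torch session has the same effect (a minimal sketch, assuming totem and nngraph are installed and the path is relative to the repository root):

require 'nngraph'
dofile 'test/test_nngraph.lua'   -- ends with tester:add(test):run()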