diff options
author | Soumith Chintala <soumith@fb.com> | 2016-04-28 07:16:47 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@fb.com> | 2016-04-28 07:16:47 +0300 |
commit | 7be68b72cccaf69433af5efc7061b288fd22a4e8 (patch) | |
tree | 9b88c2aba9db0c6f0771ce0572517762f5b87ccd | |
parent | 9c2079c7c9393a301f3a6286663dda9ef2f82b4e (diff) |
fixes for lua 5.2
-rw-r--r-- | test/test_nngraph.lua | 15 |
1 files changed, 10 insertions, 5 deletions
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua index b919c84..49b5e93 100644 --- a/test/test_nngraph.lua +++ b/test/test_nngraph.lua @@ -335,7 +335,7 @@ function test.test_gradInputType() local output = nn.Linear(10, 10)(hidden_b) local net = nn.gModule({input}, {output}) - tester:assert(hidden_a:label():match('DescB')) + tester:assert(hidden_a:label():match('DescB') ~= nil) local fg_tmpfile = os.tmpname() local bg_tmpfile = os.tmpname() graph.dot(net.fg, 'Test', fg_tmpfile) @@ -344,15 +344,15 @@ function test.test_gradInputType() local function checkDotFile(tmpfile) local dotcontent = io.open(tmpfile .. '.dot', 'r'):read("*all") tester:assert( - dotcontent:match('%[color=red.*label=%"Input.*DescA.*%".*%]')) + dotcontent:match('%[color=red.*label=%"Input.*DescA.*%".*%]') ~= nil) tester:assert( dotcontent:match( - '%[.*fontcolor=green.*label=%"Hidden A.*DescB.*%".*%]')) + '%[.*fontcolor=green.*label=%"Hidden A.*DescB.*%".*%]') ~= nil) tester:assert( - dotcontent:match('%[color=blue.*label=%".*DescB.*%".*%]')) + dotcontent:match('%[color=blue.*label=%".*DescB.*%".*%]') ~= nil) tester:assert( dotcontent:match( - '%[.*label=%".*DescB.*%".*tooltip=%"I am green%".*%]')) + '%[.*label=%".*DescB.*%".*tooltip=%"I am green%".*%]') ~= nil) end checkDotFile(fg_tmpfile) @@ -382,6 +382,10 @@ function test.test_gradInputType() end function test.test_gradOutputZeroOptim() + local unpack = function(...) + if _G[unpack] then return _G[unpack](...) + else return table.unpack(...) end + end -- Make module that produces an expanded zero gradInput tensor local dummyModule = nn.Module() dummyModule.updateOutput = function(self, input) @@ -410,6 +414,7 @@ function test.test_gradInputType() local ok, result = pcall(nn.Module.forward, mod, input) assert(ok, "forward should succeed") + nn.Module.backward( mod, input, gradOutput) ok, result = pcall(nn.Module.backward, mod, input, gradOutput) assert(ok, "backward should succeed") |