Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@fb.com>2016-04-28 07:16:47 +0300
committerSoumith Chintala <soumith@fb.com>2016-04-28 07:16:47 +0300
commit7be68b72cccaf69433af5efc7061b288fd22a4e8 (patch)
tree9b88c2aba9db0c6f0771ce0572517762f5b87ccd
parent9c2079c7c9393a301f3a6286663dda9ef2f82b4e (diff)
fixes for lua 5.2
-rw-r--r--test/test_nngraph.lua15
1 files changed, 10 insertions, 5 deletions
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index b919c84..49b5e93 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -335,7 +335,7 @@ function test.test_gradInputType()
local output = nn.Linear(10, 10)(hidden_b)
local net = nn.gModule({input}, {output})
- tester:assert(hidden_a:label():match('DescB'))
+ tester:assert(hidden_a:label():match('DescB') ~= nil)
local fg_tmpfile = os.tmpname()
local bg_tmpfile = os.tmpname()
graph.dot(net.fg, 'Test', fg_tmpfile)
@@ -344,15 +344,15 @@ function test.test_gradInputType()
local function checkDotFile(tmpfile)
local dotcontent = io.open(tmpfile .. '.dot', 'r'):read("*all")
tester:assert(
- dotcontent:match('%[color=red.*label=%"Input.*DescA.*%".*%]'))
+ dotcontent:match('%[color=red.*label=%"Input.*DescA.*%".*%]') ~= nil)
tester:assert(
dotcontent:match(
- '%[.*fontcolor=green.*label=%"Hidden A.*DescB.*%".*%]'))
+ '%[.*fontcolor=green.*label=%"Hidden A.*DescB.*%".*%]') ~= nil)
tester:assert(
- dotcontent:match('%[color=blue.*label=%".*DescB.*%".*%]'))
+ dotcontent:match('%[color=blue.*label=%".*DescB.*%".*%]') ~= nil)
tester:assert(
dotcontent:match(
- '%[.*label=%".*DescB.*%".*tooltip=%"I am green%".*%]'))
+ '%[.*label=%".*DescB.*%".*tooltip=%"I am green%".*%]') ~= nil)
end
checkDotFile(fg_tmpfile)
@@ -382,6 +382,10 @@ function test.test_gradInputType()
end
function test.test_gradOutputZeroOptim()
+ local unpack = function(...)
+ if _G[unpack] then return _G[unpack](...)
+ else return table.unpack(...) end
+ end
-- Make module that produces an expanded zero gradInput tensor
local dummyModule = nn.Module()
dummyModule.updateOutput = function(self, input)
@@ -410,6 +414,7 @@ function test.test_gradInputType()
local ok, result = pcall(nn.Module.forward, mod, input)
assert(ok, "forward should succeed")
+ nn.Module.backward( mod, input, gradOutput)
ok, result = pcall(nn.Module.backward, mod, input, gradOutput)
assert(ok, "backward should succeed")