diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-11-04 17:58:10 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-11-04 17:58:10 +0300 |
commit | 0d52b2ce85f2c9d36d5d9d8fda4147bc75c17159 (patch) | |
tree | 67e7c8276548f81cb37ad102fbfd342892e756c0 | |
parent | d761deb7b61f3ef8072206be4e04f82ec7dec750 (diff) |
adding Lua52 compatibility fixes
-rw-r--r-- | ModuleFromCriterion.lua | 2 | ||||
-rw-r--r-- | graphinspecting.lua | 1 | ||||
-rw-r--r-- | node.lua | 2 | ||||
-rw-r--r-- | test/test_nngraph.lua | 1 |
4 files changed, 6 insertions, 0 deletions
diff --git a/ModuleFromCriterion.lua b/ModuleFromCriterion.lua index 8775c3d..faca485 100644 --- a/ModuleFromCriterion.lua +++ b/ModuleFromCriterion.lua @@ -10,6 +10,8 @@ function ModuleFromCriterion:__init(criterion) self.gradInput = {torch.Tensor(), torch.Tensor()} end +local unpack = unpack or table.unpack -- lua52 compat + --[[ The input is a {prediction, target} pair. The output is a tensor with one number: the criterion output. --]] diff --git a/graphinspecting.lua b/graphinspecting.lua index e0676c4..8e858cf 100644 --- a/graphinspecting.lua +++ b/graphinspecting.lua @@ -114,6 +114,7 @@ function nngraph.setDebug(enable) nn.gModule[funcName] = function(...) local args = {...} local gmodule = args[1] + local unpack = unpack or table.unpack return runChecked(function() return origFunc(unpack(args)) end, onError, gmodule) @@ -55,6 +55,8 @@ function nnNode:split(noutput) node:add(mnode,true) table.insert(selectnodes,node) end + + local unpack = unpack or table.unpack -- Lua52 compat return unpack(selectnodes) end diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua index 95a1658..e3d8982 100644 --- a/test/test_nngraph.lua +++ b/test/test_nngraph.lua @@ -121,6 +121,7 @@ function test.test_gradInputType() local gradOutput = torch.randn(h:size()) local gradInput = module:backward(input, gradOutput) + local unpack = unpack or table.unpack local gradX, gradPrevState = unpack(gradInput) local gradPrevH, gradPrevCell = unpack(gradPrevState) assert(type(gradPrevH) == type(h), "wrong gradPrevH type") |