Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nngraph.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-11-04 17:58:10 +0300
committerSoumith Chintala <soumith@gmail.com>2015-11-04 17:58:10 +0300
commit0d52b2ce85f2c9d36d5d9d8fda4147bc75c17159 (patch)
tree67e7c8276548f81cb37ad102fbfd342892e756c0
parentd761deb7b61f3ef8072206be4e04f82ec7dec750 (diff)
adding Lua52 compatibility fixes
-rw-r--r--ModuleFromCriterion.lua2
-rw-r--r--graphinspecting.lua1
-rw-r--r--node.lua2
-rw-r--r--test/test_nngraph.lua1
4 files changed, 6 insertions, 0 deletions
diff --git a/ModuleFromCriterion.lua b/ModuleFromCriterion.lua
index 8775c3d..faca485 100644
--- a/ModuleFromCriterion.lua
+++ b/ModuleFromCriterion.lua
@@ -10,6 +10,8 @@ function ModuleFromCriterion:__init(criterion)
self.gradInput = {torch.Tensor(), torch.Tensor()}
end
+local unpack = unpack or table.unpack -- lua52 compat
+
--[[ The input is a {prediction, target} pair.
The output is a tensor with one number: the criterion output.
--]]
diff --git a/graphinspecting.lua b/graphinspecting.lua
index e0676c4..8e858cf 100644
--- a/graphinspecting.lua
+++ b/graphinspecting.lua
@@ -114,6 +114,7 @@ function nngraph.setDebug(enable)
nn.gModule[funcName] = function(...)
local args = {...}
local gmodule = args[1]
+ local unpack = unpack or table.unpack
return runChecked(function()
return origFunc(unpack(args))
end, onError, gmodule)
diff --git a/node.lua b/node.lua
index f9e893e..0f2261f 100644
--- a/node.lua
+++ b/node.lua
@@ -55,6 +55,8 @@ function nnNode:split(noutput)
node:add(mnode,true)
table.insert(selectnodes,node)
end
+
+ local unpack = unpack or table.unpack -- Lua52 compat
return unpack(selectnodes)
end
diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua
index 95a1658..e3d8982 100644
--- a/test/test_nngraph.lua
+++ b/test/test_nngraph.lua
@@ -121,6 +121,7 @@ function test.test_gradInputType()
local gradOutput = torch.randn(h:size())
local gradInput = module:backward(input, gradOutput)
+ local unpack = unpack or table.unpack
local gradX, gradPrevState = unpack(gradInput)
local gradPrevH, gradPrevCell = unpack(gradPrevState)
assert(type(gradPrevH) == type(h), "wrong gradPrevH type")