Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2015-10-17 01:47:53 +0300
committerRonan Collobert <ronan@collobert.com>2015-10-17 01:47:53 +0300
commiteecd6fc61b6ed02737efaac0adc5435b986e451e (patch)
tree3249d641d30fd5e91c5bd8c1ffc9a403d2ff5255 /init.lua
parentd962e0f4e8a0ea04fcc1c03c88debddfeeac9a3a (diff)
introduced luaT_newlocalmetatable.newlocalmetatable
allows class creation in a local namespace, while being backward-compatible with luaT_newmetatable.
Diffstat (limited to 'init.lua')
-rw-r--r--init.lua71
1 files changed, 49 insertions, 22 deletions
diff --git a/init.lua b/init.lua
index 29d54e1..a785f89 100644
--- a/init.lua
+++ b/init.lua
@@ -1,13 +1,12 @@
-
-- We are using paths.require to appease mkl
-- Make this work with LuaJIT in Lua 5.2 compatibility mode, which
-- renames string.gfind (already deprecated in 5.1)
if not string.gfind then
- string.gfind = string.gmatch
+ string.gfind = string.gmatch
end
if not table.unpack then
- table.unpack = unpack
+ table.unpack = unpack
end
require "paths"
@@ -17,20 +16,20 @@ paths.require "libtorch"
-- if a Lua VM is passed to another thread thread local
-- variables need to be updated.
function torch.updatethreadlocals()
- torch.updateerrorhandlers()
- local tracking = torch._heaptracking
- if tracking == nil then tracking = false end
- torch.setheaptracking(tracking)
+ torch.updateerrorhandlers()
+ local tracking = torch._heaptracking
+ if tracking == nil then tracking = false end
+ torch.setheaptracking(tracking)
end
--- package stuff
function torch.packageLuaPath(name)
if not name then
local ret = string.match(torch.packageLuaPath('torch'), '(.*)/')
- if not ret then --windows?
- ret = string.match(torch.packageLuaPath('torch'), '(.*)\\')
- end
- return ret
+ if not ret then --windows?
+ ret = string.match(torch.packageLuaPath('torch'), '(.*)\\')
+ end
+ return ret
end
for path in string.gmatch(package.path, "[^;]+") do
path = string.gsub(path, "%?", name)
@@ -39,7 +38,7 @@ function torch.packageLuaPath(name)
f:close()
local ret = string.match(path, "(.*)/")
if not ret then --windows?
- ret = string.match(path, "(.*)\\")
+ ret = string.match(path, "(.*)\\")
end
return ret
end
@@ -55,7 +54,35 @@ function torch.include(package, file)
dofile(torch.packageLuaPath(package) .. '/' .. file)
end
-function torch.class(tname, parenttname)
+function torch.class(...)
+ local tname, parenttname, module
+ if select('#', ...) == 3
+ and type(select(1, ...)) == 'string'
+ and type(select(2, ...)) == 'string'
+ and type(select(3, ...)) == 'table'
+ then
+ tname = select(1, ...)
+ parenttname = select(2, ...)
+ module = select(3, ...)
+ elseif select('#', ...) == 2
+ and type(select(1, ...)) == 'string'
+ and type(select(2, ...)) == 'string'
+ then
+ tname = select(1, ...)
+ parenttname = select(2, ...)
+ elseif select('#', ...) == 2
+ and type(select(1, ...)) == 'string'
+ and type(select(2, ...)) == 'table'
+ then
+ tname = select(1, ...)
+ module = select(2, ...)
+ elseif select('#', ...) == 1
+ and type(select(1, ...)) == 'string'
+ then
+ tname = select(1, ...)
+ else
+ error('<class name> [<parent class name>] [<module table>] expected')
+ end
local function constructor(...)
local self = {}
@@ -72,7 +99,7 @@ function torch.class(tname, parenttname)
return self
end
- local mt = torch.newmetatable(tname, parenttname, constructor, nil, factory)
+ local mt = torch.newmetatable(tname, parenttname, constructor, nil, factory, module)
local mpt
if parenttname then
mpt = torch.getmetatable(parenttname)
@@ -104,8 +131,8 @@ local function exactTypeMatch(typeName, typeSpec)
end
--[[ Returns true if the type given by the passed-in typeName either equals
-typeSpec, or ends with ".<typeSpec>". For example, "ab.cd.ef" matches type specs
-"ef", "cd.ef", and "ab.cd.ef", but not "f" or "d.ef". ]]
+ typeSpec, or ends with ".<typeSpec>". For example, "ab.cd.ef" matches type specs
+ "ef", "cd.ef", and "ab.cd.ef", but not "f" or "d.ef". ]]
local function partialTypeMatch(typeName, typeSpec)
local diffLen = #typeName - #typeSpec
@@ -131,7 +158,7 @@ function torch.isTypeOf(obj, typeSpec)
elseif type(typeSpec) == 'string' then
matchFunc = partialTypeMatch
else
- error("type must be provided as [regexp] string, or factory")
+ error("type must be provided as [regexp] string, or factory")
end
local mt = getmetatable(obj)
@@ -154,11 +181,11 @@ include('Tester.lua')
include('test.lua')
function torch.totable(obj)
- if torch.isTensor(obj) or torch.isStorage(obj) then
- return obj:totable()
- else
- error("obj must be a Storage or a Tensor")
- end
+ if torch.isTensor(obj) or torch.isStorage(obj) then
+ return obj:totable()
+ else
+ error("obj must be a Storage or a Tensor")
+ end
end
function torch.isTensor(obj)