diff options
author | Ronan Collobert <ronan@collobert.com> | 2015-10-17 01:47:53 +0300 |
---|---|---|
committer | Ronan Collobert <ronan@collobert.com> | 2015-10-17 01:47:53 +0300 |
commit | eecd6fc61b6ed02737efaac0adc5435b986e451e (patch) | |
tree | 3249d641d30fd5e91c5bd8c1ffc9a403d2ff5255 /init.lua | |
parent | d962e0f4e8a0ea04fcc1c03c88debddfeeac9a3a (diff) |
introduced luaT_newlocalmetatable.newlocalmetatable
allows class creation in a local namespace, while being backward-compatible
with luaT_newmetatable.
Diffstat (limited to 'init.lua')
-rw-r--r-- | init.lua | 71 |
1 files changed, 49 insertions, 22 deletions
@@ -1,13 +1,12 @@ - -- We are using paths.require to appease mkl -- Make this work with LuaJIT in Lua 5.2 compatibility mode, which -- renames string.gfind (already deprecated in 5.1) if not string.gfind then - string.gfind = string.gmatch + string.gfind = string.gmatch end if not table.unpack then - table.unpack = unpack + table.unpack = unpack end require "paths" @@ -17,20 +16,20 @@ paths.require "libtorch" -- if a Lua VM is passed to another thread thread local -- variables need to be updated. function torch.updatethreadlocals() - torch.updateerrorhandlers() - local tracking = torch._heaptracking - if tracking == nil then tracking = false end - torch.setheaptracking(tracking) + torch.updateerrorhandlers() + local tracking = torch._heaptracking + if tracking == nil then tracking = false end + torch.setheaptracking(tracking) end --- package stuff function torch.packageLuaPath(name) if not name then local ret = string.match(torch.packageLuaPath('torch'), '(.*)/') - if not ret then --windows? - ret = string.match(torch.packageLuaPath('torch'), '(.*)\\') - end - return ret + if not ret then --windows? + ret = string.match(torch.packageLuaPath('torch'), '(.*)\\') + end + return ret end for path in string.gmatch(package.path, "[^;]+") do path = string.gsub(path, "%?", name) @@ -39,7 +38,7 @@ function torch.packageLuaPath(name) f:close() local ret = string.match(path, "(.*)/") if not ret then --windows? - ret = string.match(path, "(.*)\\") + ret = string.match(path, "(.*)\\") end return ret end @@ -55,7 +54,35 @@ function torch.include(package, file) dofile(torch.packageLuaPath(package) .. '/' .. file) end -function torch.class(tname, parenttname) +function torch.class(...) + local tname, parenttname, module + if select('#', ...) == 3 + and type(select(1, ...)) == 'string' + and type(select(2, ...)) == 'string' + and type(select(3, ...)) == 'table' + then + tname = select(1, ...) + parenttname = select(2, ...) + module = select(3, ...) + elseif select('#', ...) == 2 + and type(select(1, ...)) == 'string' + and type(select(2, ...)) == 'string' + then + tname = select(1, ...) + parenttname = select(2, ...) + elseif select('#', ...) == 2 + and type(select(1, ...)) == 'string' + and type(select(2, ...)) == 'table' + then + tname = select(1, ...) + module = select(2, ...) + elseif select('#', ...) == 1 + and type(select(1, ...)) == 'string' + then + tname = select(1, ...) + else + error('<class name> [<parent class name>] [<module table>] expected') + end local function constructor(...) local self = {} @@ -72,7 +99,7 @@ function torch.class(tname, parenttname) return self end - local mt = torch.newmetatable(tname, parenttname, constructor, nil, factory) + local mt = torch.newmetatable(tname, parenttname, constructor, nil, factory, module) local mpt if parenttname then mpt = torch.getmetatable(parenttname) @@ -104,8 +131,8 @@ local function exactTypeMatch(typeName, typeSpec) end --[[ Returns true if the type given by the passed-in typeName either equals -typeSpec, or ends with ".<typeSpec>". For example, "ab.cd.ef" matches type specs -"ef", "cd.ef", and "ab.cd.ef", but not "f" or "d.ef". ]] + typeSpec, or ends with ".<typeSpec>". For example, "ab.cd.ef" matches type specs + "ef", "cd.ef", and "ab.cd.ef", but not "f" or "d.ef". ]] local function partialTypeMatch(typeName, typeSpec) local diffLen = #typeName - #typeSpec @@ -131,7 +158,7 @@ function torch.isTypeOf(obj, typeSpec) elseif type(typeSpec) == 'string' then matchFunc = partialTypeMatch else - error("type must be provided as [regexp] string, or factory") + error("type must be provided as [regexp] string, or factory") end local mt = getmetatable(obj) @@ -154,11 +181,11 @@ include('Tester.lua') include('test.lua') function torch.totable(obj) - if torch.isTensor(obj) or torch.isStorage(obj) then - return obj:totable() - else - error("obj must be a Storage or a Tensor") - end + if torch.isTensor(obj) or torch.isStorage(obj) then + return obj:totable() + else + error("obj must be a Storage or a Tensor") + end end function torch.isTensor(obj) |