From ab0a21a961f70476d0f9910428600ed283e24f05 Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Mon, 4 May 2015 16:55:44 -0400 Subject: Support Lua 5.2 --- init.lua | 336 +++++++++++++++++++++++++++++++-------------------------------- 1 file changed, 164 insertions(+), 172 deletions(-) diff --git a/init.lua b/init.lua index 4357fd8..e6036b1 100644 --- a/init.lua +++ b/init.lua @@ -33,174 +33,163 @@ -- June 30, 2011, 4:54PM - creation - Clement Farabet ---------------------------------------------------------------------- -require 'os' -require 'sys' -require 'io' -require 'math' -require 'torch' - --- remember startup variables (to protect them) -rawset(_G, '_protect_',{'_protect_','xlua'}) +local os = require 'os' +local sys = require 'sys' +local io = require 'io' +local math = require 'math' +local torch = require 'torch' + +xlua = {} +local _protect_ = {} for k,v in pairs(_G) do - table.insert(_G._protect_, k) + table.insert(_protect_, k) end -local glob = _G -local torch = torch -local pairs = pairs -local ipairs = ipairs -local table = table -local string = string -local pcall = pcall -local loadstring = loadstring -local _protect_ = _protect_ - -module 'xlua' - -- extra files -glob.require 'xlua.OptionParser' -glob.require 'xlua.Profiler' +require 'xlua.OptionParser' +require 'xlua.Profiler' + ---------------------------------------------------------------------- -- better print function ---------------------------------------------------------------------- -print = function(obj,...) - if glob.type(obj) == 'table' then - local mt = glob.getmetatable(obj) - if mt and mt.__tostring__ then - glob.io.write(mt.__tostring__(obj)) - else - local tos = glob.tostring(obj) - local obj_w_usage = false - if tos and not glob.string.find(tos,'table: ') then - if obj.usage and glob.type(obj.usage) == 'string' then - glob.io.write(obj.usage) - glob.io.write('\n\nFIELDS:\n') - obj_w_usage = true - else - glob.io.write(tos .. ':\n') - end - end - glob.io.write('{') - local tab = '' - local idx = 1 - for k,v in pairs(obj) do - if idx > 1 then glob.io.write(',\n') end - if glob.type(v) == 'userdata' then - glob.io.write(tab .. '[' .. k .. ']' .. ' = ') - else - local tostr = glob.tostring(v):gsub('\n','\\n') - if #tostr>40 then - local tostrshort = tostr:sub(1,40) .. glob.sys.COLORS.none - glob.io.write(tab .. '[' .. glob.tostring(k) .. ']' .. ' = ' .. tostrshort .. ' ... ') - else - glob.io.write(tab .. '[' .. glob.tostring(k) .. ']' .. ' = ' .. tostr) - end - end - tab = ' ' - idx = idx + 1 - end - glob.io.write('}') - if obj_w_usage then - glob.io.write('') - end - end - else - glob.io.write(glob.tostring(obj)) - end - if glob.select('#',...) > 0 then - glob.io.write(' ') - print(...) - else - glob.io.write('\n') - end - end -glob.rawset(glob, 'xprint', print) +function xlua.print(obj,...) + if type(obj) == 'table' then + local mt = getmetatable(obj) + if mt and mt.__tostring__ then + io.write(mt.__tostring__(obj)) + else + local tos = tostring(obj) + local obj_w_usage = false + if tos and not string.find(tos,'table: ') then + if obj.usage and type(obj.usage) == 'string' then + io.write(obj.usage) + io.write('\n\nFIELDS:\n') + obj_w_usage = true + else + io.write(tos .. ':\n') + end + end + io.write('{') + local tab = '' + local idx = 1 + for k,v in pairs(obj) do + if idx > 1 then io.write(',\n') end + if type(v) == 'userdata' then + io.write(tab .. '[' .. k .. ']' .. ' = ') + else + local tostr = tostring(v):gsub('\n','\\n') + if #tostr>40 then + local tostrshort = tostr:sub(1,40) .. sys.COLORS.none + io.write(tab .. '[' .. tostring(k) .. ']' .. ' = ' .. tostrshort .. ' ... ') + else + io.write(tab .. '[' .. tostring(k) .. ']' .. ' = ' .. tostr) + end + end + tab = ' ' + idx = idx + 1 + end + io.write('}') + if obj_w_usage then + io.write('') + end + end + else + io.write(tostring(obj)) + end + if select('#',...) > 0 then + io.write(' ') + print(...) + else + io.write('\n') + end +end +rawset(_G, 'xprint', xlua.print) ---------------------------------------------------------------------- -- log all session, by replicating stdout to a file ---------------------------------------------------------------------- -log = function(file) - glob.os.execute('mkdir -p "' .. glob.sys.dirname(file) .. '"') - local f = glob.assert(glob.io.open(file,'w')) - glob.io._write = glob.io.write - glob._print = glob.print - glob.print = glob.xprint - glob.io.write = function(...) - glob.io._write(...) - local arg = {...} - for i = 1,glob.select('#',...) do - f:write(arg[i]) - end - f:flush() - end - end +function xlua.log(file) + os.execute('mkdir -p "' .. sys.dirname(file) .. '"') + local f = assert(io.open(file,'w')) + io._write = io.write + _G._print = _G.print + _G.print = xlua.print + io.write = function(...) + io._write(...) + local arg = {...} + for i = 1,select('#',...) do + f:write(arg[i]) + end + f:flush() + end +end ---------------------------------------------------------------------- -- clear all globals ---------------------------------------------------------------------- -clearall = function() - for k,v in pairs(glob) do +function xlua.clearall() + for k,v in pairs(_G) do local protected = false local lib = false for i,p in ipairs(_protect_) do - if k == p then protected = true end + if k == p then protected = true end end - for p in pairs(glob.package.loaded) do - if k == p then lib = true end + for p in pairs(package.loaded) do + if k == p then lib = true end end if not protected then - glob[k] = nil - if lib then glob.package.loaded[k] = nil end + _G[k] = nil + if lib then package.loaded[k] = nil end end end - glob.collectgarbage() + collectgarbage() end ---------------------------------------------------------------------- -- clear one variable ---------------------------------------------------------------------- -clear = function(var) - glob[var] = nil - glob.collectgarbage() +function xlua.clear(var) + _G[var] = nil + collectgarbage() end ---------------------------------------------------------------------- -- prints globals ---------------------------------------------------------------------- -who = function() +function xlua.who() local user = {} local libs = {} - for k,v in pairs(glob) do + for k,v in pairs(_G) do local protected = false local lib = false for i,p in ipairs(_protect_) do - if k == p then protected = true end + if k == p then protected = true end end - for p in pairs(glob.package.loaded) do - if k == p and p ~= '_G' then lib = true end + for p in pairs(package.loaded) do + if k == p and p ~= '_G' then lib = true end end if lib then - glob.table.insert(libs, k) + table.insert(libs, k) elseif not protected then - user[k] = glob[k] + user[k] = _G[k] end end - print('') - print('Global Libs:') - print(libs) - print('') - print('Global Vars:') - print(user) - print('') + xlua.print('') + xlua.print('Global Libs:') + xlua.print(libs) + xlua.print('') + xlua.print('Global Vars:') + xlua.print(user) + xlua.print('') end ---------------------------------------------------------------------- -- time ---------------------------------------------------------------------- -function formatTime(seconds) +function xlua.formatTime(seconds) -- decompose: - local floor = glob.math.floor + local floor = math.floor local days = floor(seconds / 3600/24) seconds = seconds - days*3600*24 local hours = floor(seconds / 3600) @@ -224,14 +213,15 @@ function formatTime(seconds) -- return formatted time return f end +local formatTime = xlua.formatTime ---------------------------------------------------------------------- -- progress bar ---------------------------------------------------------------------- do local function getTermLength() - local tputf = glob.io.popen('tput cols', 'r') - local w = glob.tonumber(tputf:read('*a')) + local tputf = io.popen('tput cols', 'r') + local w = tonumber(tputf:read('*a')) local rc = {tputf:close()} if rc[3] == 0 then return w else return 80 end @@ -244,14 +234,14 @@ do local times local indices local termLength = getTermLength() - function progress(current, goal) + function xlua.progress(current, goal) -- defaults: local barLength = termLength - 34 local smoothing = 100 local maxfps = 10 -- Compute percentage - local percent = glob.math.floor(((current) * barLength) / goal) + local percent = math.floor(((current) * barLength) / goal) -- start new bar if (barDone and ((previous == -1) or (percent < previous))) then @@ -262,27 +252,27 @@ do times = {timer:time().real} indices = {current} else - glob.io.write('\r') + io.write('\r') end --if (percent ~= previous and not barDone) then if (not barDone) then previous = percent -- print bar - glob.io.write(' [') + io.write(' [') for i=1,barLength do - if (i < percent) then glob.io.write('=') - elseif (i == percent) then glob.io.write('>') - else glob.io.write('.') end + if (i < percent) then io.write('=') + elseif (i == percent) then io.write('>') + else io.write('.') end end - glob.io.write('] ') - for i=1,termLength-barLength-4 do glob.io.write(' ') end - for i=1,termLength-barLength-4 do glob.io.write('\b') end + io.write('] ') + for i=1,termLength-barLength-4 do io.write(' ') end + for i=1,termLength-barLength-4 do io.write('\b') end -- time stats local elapsed = timer:time().real local step = (elapsed-times[1]) / (current-indices[1]) if current==indices[1] then step = 0 end - local remaining = glob.math.max(0,(goal - current)*step) + local remaining = math.max(0,(goal - current)*step) table.insert(indices, current) table.insert(times, elapsed) if #indices > smoothing then @@ -290,18 +280,18 @@ do times = table.splice(times) end tm = 'ETA: ' .. formatTime(remaining) .. ' | Step: ' .. formatTime(step) - glob.io.write(tm) + io.write(tm) -- go back to center of bar, and print progress - for i=1,6+#tm+barLength/2 do glob.io.write('\b') end - glob.io.write(' ', current, '/', goal, ' ') + for i=1,6+#tm+barLength/2 do io.write('\b') end + io.write(' ', current, '/', goal, ' ') -- reset for next bar if (percent == barLength) then barDone = true - glob.io.write('\n') + io.write('\n') end -- flush - glob.io.write('\r') - glob.io.flush() + io.write('\r') + io.flush() end end end @@ -310,36 +300,36 @@ end -- prints an error with nice formatting. If domain is provided, it is used as -- following: msg -------------------------------------------------------------------------------- -function error(message, domain, usage) - local c = glob.sys.COLORS +function xlua.error(message, domain, usage) + local c = sys.COLORS if domain then message = '<' .. domain .. '> ' .. message end - local col_msg = c.Red .. glob.tostring(message) .. c.none + local col_msg = c.Red .. tostring(message) .. c.none if usage then col_msg = col_msg .. '\n' .. usage end - glob.error(col_msg) + error(col_msg) end -glob.rawset(glob, 'xerror', error) +rawset(_G, 'xerror', xlua.error) -------------------------------------------------------------------------------- -- provides standard try/catch functions -------------------------------------------------------------------------------- -function trycatch(try,catch) - local ok,err = glob.pcall(func) +function xlua.trycatch(try,catch) + local ok,err = pcall(func) if not ok then catch(err) end end -------------------------------------------------------------------------------- -- returns true if package is installed, rather than crashing stupidly :-) -------------------------------------------------------------------------------- -function installed(package) +function xlua.installed(package) local found = false - local p = glob.package.path .. ';' .. glob.package.cpath + local p = package.path .. ';' .. package.cpath for path in p:gfind('.-;') do path = path:gsub(';',''):gsub('?',package) - if glob.sys.filep(path) then + if sys.filep(path) then found = true p = path break @@ -356,10 +346,10 @@ end -- @param luarocks if true, then try to install missing package with luarocks -- @param server specify a luarocks server -------------------------------------------------------------------------------- -function require(package,luarocks,server) +function xlua.require(package,luarocks,server) local loaded - local load = function() loaded = glob.require(package) end - local ok,err = glob.pcall(load) + local load = function() loaded = require(package) end + local ok,err = pcall(load) if not ok then print(err) print('warning: <' .. package .. '> could not be loaded (is it installed?)') @@ -367,7 +357,7 @@ function require(package,luarocks,server) end return loaded end -glob.rawset(glob, 'xrequire', require) +rawset(_G, 'xrequire', xlua.require) -------------------------------------------------------------------------------- -- standard usage function: used to display automated help for functions @@ -377,8 +367,8 @@ glob.rawset(glob, 'xrequire', require) -- @param example usage example -- @param ... [optional] arguments -------------------------------------------------------------------------------- -function usage(funcname, description, example, ...) - local c = glob.sys.COLORS +function xlua.usage(funcname, description, example, ...) + local c = sys.COLORS local style = { banner = '+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++', @@ -421,7 +411,7 @@ function usage(funcname, description, example, ...) end str = str .. key .. '-- ' .. param.help if param.default or param.default == false then - str = str .. ' [default = ' .. glob.tostring(param.default) .. ']' + str = str .. ' [default = ' .. tostring(param.default) .. ']' elseif param.defaulta then str = str .. ' [default == ' .. param.defaulta .. ']' end @@ -469,7 +459,7 @@ end -- standard argument function: used to handle named arguments, and -- display automated help for functions -------------------------------------------------------------------------------- -function unpack(args, funcname, description, ...) +function xlua.unpack(args, funcname, description, ...) -- put args in table local defs = {...} @@ -482,23 +472,23 @@ function unpack(args, funcname, description, ...) .. defs[1].arg .. '=' .. defs[1].type .. '}\n' example = example .. funcname .. '(' .. defs[1].type .. ',' .. ' ...)' end - return usage(funcname, description, example, {tabled=defs}) + return xlua.usage(funcname, description, example, {tabled=defs}) end local usage = {} - glob.setmetatable(usage, {__tostring=fusage}) + setmetatable(usage, {__tostring=fusage}) -- get args local iargs = {} if #args == 0 then print(usage) - error('error') - elseif #args == 1 and glob.type(args[1]) == 'table' and #args[1] == 0 - and not (glob.torch and glob.torch.typename(args[1]) ~= nil) then + xlua.error('error') + elseif #args == 1 and type(args[1]) == 'table' and #args[1] == 0 + and not (torch and torch.typename(args[1]) ~= nil) then -- named args iargs = args[1] else -- ordered args - for i = 1,glob.select('#',...) do + for i = 1,select('#',...) do iargs[defs[i].arg] = args[i] end end @@ -509,10 +499,10 @@ function unpack(args, funcname, description, ...) local def = defs[i] -- is value requested ? if def.req and iargs[def.arg] == nil then - local c = glob.sys.COLORS + local c = sys.COLORS print(c.Red .. 'missing argument: ' .. def.arg .. c.none) print(usage) - error('error') + xlua.error('error') end -- get value or default dargs[def.arg] = iargs[def.arg] @@ -530,7 +520,7 @@ function unpack(args, funcname, description, ...) -- stupid lua bug: we return all args by hand if dargs[65] then - error(' oups, cant deal with more than 64 arguments :-)') + xlua.error(' oups, cant deal with more than 64 arguments :-)') end -- return modified args @@ -550,10 +540,10 @@ end -- display automated help for functions -- auto inits the self with usage -------------------------------------------------------------------------------- -function unpack_class(object, args, funcname, description, ...) - local dargs = unpack(args, funcname, description, ...) +function xlua.unpack_class(object, args, funcname, description, ...) + local dargs = xlua.unpack(args, funcname, description, ...) for k,v in pairs(dargs) do - if glob.type(k) ~= 'number' then + if type(k) ~= 'number' then object[k] = v end end @@ -566,8 +556,8 @@ end -- @param name module name -- @param description description of the module -------------------------------------------------------------------------------- -function usage_module(module, name, description) - local c = glob.sys.COLORS +function xlua.usage_module(module, name, description) + local c = sys.COLORS local hasglobals = false local str = c.magenta local str = str .. '+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n' @@ -578,8 +568,8 @@ function usage_module(module, name, description) str = str .. '++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++' str = str .. c.none -- register help - local mt = glob.getmetatable(module) or {} - glob.setmetatable(module,mt) + local mt = getmetatable(module) or {} + setmetatable(module,mt) mt.__tostring = function() return str end return str end @@ -663,10 +653,12 @@ end function string.tosymbol(str) local ok,result = pcall(loadstring('return ' .. str)) if not ok then - glob.error(result) + error(result) elseif not result then - glob.error('symbol "' .. str .. '" does not exist') + error('symbol "' .. str .. '" does not exist') else return result end end + +return xlua \ No newline at end of file -- cgit v1.2.3