Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/xlua.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-05-05 00:08:10 +0300
committerSoumith Chintala <soumith@gmail.com>2015-05-05 00:08:10 +0300
commit775ed6c39195470da876ab111bf02cc6b790e04e (patch)
tree1b8b14ae8d313a2a938285efec8d9d1762dafa0a
parent24397ac3506e57dd05f48fedc9bbabe79c39f9d2 (diff)
parentab0a21a961f70476d0f9910428600ed283e24f05 (diff)
Merge pull request #7 from colesbury/lua52
Support Lua 5.2
-rw-r--r--init.lua336
1 files changed, 164 insertions, 172 deletions
diff --git a/init.lua b/init.lua
index 4357fd8..e6036b1 100644
--- a/init.lua
+++ b/init.lua
@@ -33,174 +33,163 @@
-- June 30, 2011, 4:54PM - creation - Clement Farabet
----------------------------------------------------------------------
-require 'os'
-require 'sys'
-require 'io'
-require 'math'
-require 'torch'
-
--- remember startup variables (to protect them)
-rawset(_G, '_protect_',{'_protect_','xlua'})
+local os = require 'os'
+local sys = require 'sys'
+local io = require 'io'
+local math = require 'math'
+local torch = require 'torch'
+
+xlua = {}
+local _protect_ = {}
for k,v in pairs(_G) do
- table.insert(_G._protect_, k)
+ table.insert(_protect_, k)
end
-local glob = _G
-local torch = torch
-local pairs = pairs
-local ipairs = ipairs
-local table = table
-local string = string
-local pcall = pcall
-local loadstring = loadstring
-local _protect_ = _protect_
-
-module 'xlua'
-
-- extra files
-glob.require 'xlua.OptionParser'
-glob.require 'xlua.Profiler'
+require 'xlua.OptionParser'
+require 'xlua.Profiler'
+
----------------------------------------------------------------------
-- better print function
----------------------------------------------------------------------
-print = function(obj,...)
- if glob.type(obj) == 'table' then
- local mt = glob.getmetatable(obj)
- if mt and mt.__tostring__ then
- glob.io.write(mt.__tostring__(obj))
- else
- local tos = glob.tostring(obj)
- local obj_w_usage = false
- if tos and not glob.string.find(tos,'table: ') then
- if obj.usage and glob.type(obj.usage) == 'string' then
- glob.io.write(obj.usage)
- glob.io.write('\n\nFIELDS:\n')
- obj_w_usage = true
- else
- glob.io.write(tos .. ':\n')
- end
- end
- glob.io.write('{')
- local tab = ''
- local idx = 1
- for k,v in pairs(obj) do
- if idx > 1 then glob.io.write(',\n') end
- if glob.type(v) == 'userdata' then
- glob.io.write(tab .. '[' .. k .. ']' .. ' = <userdata>')
- else
- local tostr = glob.tostring(v):gsub('\n','\\n')
- if #tostr>40 then
- local tostrshort = tostr:sub(1,40) .. glob.sys.COLORS.none
- glob.io.write(tab .. '[' .. glob.tostring(k) .. ']' .. ' = ' .. tostrshort .. ' ... ')
- else
- glob.io.write(tab .. '[' .. glob.tostring(k) .. ']' .. ' = ' .. tostr)
- end
- end
- tab = ' '
- idx = idx + 1
- end
- glob.io.write('}')
- if obj_w_usage then
- glob.io.write('')
- end
- end
- else
- glob.io.write(glob.tostring(obj))
- end
- if glob.select('#',...) > 0 then
- glob.io.write(' ')
- print(...)
- else
- glob.io.write('\n')
- end
- end
-glob.rawset(glob, 'xprint', print)
+function xlua.print(obj,...)
+ if type(obj) == 'table' then
+ local mt = getmetatable(obj)
+ if mt and mt.__tostring__ then
+ io.write(mt.__tostring__(obj))
+ else
+ local tos = tostring(obj)
+ local obj_w_usage = false
+ if tos and not string.find(tos,'table: ') then
+ if obj.usage and type(obj.usage) == 'string' then
+ io.write(obj.usage)
+ io.write('\n\nFIELDS:\n')
+ obj_w_usage = true
+ else
+ io.write(tos .. ':\n')
+ end
+ end
+ io.write('{')
+ local tab = ''
+ local idx = 1
+ for k,v in pairs(obj) do
+ if idx > 1 then io.write(',\n') end
+ if type(v) == 'userdata' then
+ io.write(tab .. '[' .. k .. ']' .. ' = <userdata>')
+ else
+ local tostr = tostring(v):gsub('\n','\\n')
+ if #tostr>40 then
+ local tostrshort = tostr:sub(1,40) .. sys.COLORS.none
+ io.write(tab .. '[' .. tostring(k) .. ']' .. ' = ' .. tostrshort .. ' ... ')
+ else
+ io.write(tab .. '[' .. tostring(k) .. ']' .. ' = ' .. tostr)
+ end
+ end
+ tab = ' '
+ idx = idx + 1
+ end
+ io.write('}')
+ if obj_w_usage then
+ io.write('')
+ end
+ end
+ else
+ io.write(tostring(obj))
+ end
+ if select('#',...) > 0 then
+ io.write(' ')
+ print(...)
+ else
+ io.write('\n')
+ end
+end
+rawset(_G, 'xprint', xlua.print)
----------------------------------------------------------------------
-- log all session, by replicating stdout to a file
----------------------------------------------------------------------
-log = function(file)
- glob.os.execute('mkdir -p "' .. glob.sys.dirname(file) .. '"')
- local f = glob.assert(glob.io.open(file,'w'))
- glob.io._write = glob.io.write
- glob._print = glob.print
- glob.print = glob.xprint
- glob.io.write = function(...)
- glob.io._write(...)
- local arg = {...}
- for i = 1,glob.select('#',...) do
- f:write(arg[i])
- end
- f:flush()
- end
- end
+function xlua.log(file)
+ os.execute('mkdir -p "' .. sys.dirname(file) .. '"')
+ local f = assert(io.open(file,'w'))
+ io._write = io.write
+ _G._print = _G.print
+ _G.print = xlua.print
+ io.write = function(...)
+ io._write(...)
+ local arg = {...}
+ for i = 1,select('#',...) do
+ f:write(arg[i])
+ end
+ f:flush()
+ end
+end
----------------------------------------------------------------------
-- clear all globals
----------------------------------------------------------------------
-clearall = function()
- for k,v in pairs(glob) do
+function xlua.clearall()
+ for k,v in pairs(_G) do
local protected = false
local lib = false
for i,p in ipairs(_protect_) do
- if k == p then protected = true end
+ if k == p then protected = true end
end
- for p in pairs(glob.package.loaded) do
- if k == p then lib = true end
+ for p in pairs(package.loaded) do
+ if k == p then lib = true end
end
if not protected then
- glob[k] = nil
- if lib then glob.package.loaded[k] = nil end
+ _G[k] = nil
+ if lib then package.loaded[k] = nil end
end
end
- glob.collectgarbage()
+ collectgarbage()
end
----------------------------------------------------------------------
-- clear one variable
----------------------------------------------------------------------
-clear = function(var)
- glob[var] = nil
- glob.collectgarbage()
+function xlua.clear(var)
+ _G[var] = nil
+ collectgarbage()
end
----------------------------------------------------------------------
-- prints globals
----------------------------------------------------------------------
-who = function()
+function xlua.who()
local user = {}
local libs = {}
- for k,v in pairs(glob) do
+ for k,v in pairs(_G) do
local protected = false
local lib = false
for i,p in ipairs(_protect_) do
- if k == p then protected = true end
+ if k == p then protected = true end
end
- for p in pairs(glob.package.loaded) do
- if k == p and p ~= '_G' then lib = true end
+ for p in pairs(package.loaded) do
+ if k == p and p ~= '_G' then lib = true end
end
if lib then
- glob.table.insert(libs, k)
+ table.insert(libs, k)
elseif not protected then
- user[k] = glob[k]
+ user[k] = _G[k]
end
end
- print('')
- print('Global Libs:')
- print(libs)
- print('')
- print('Global Vars:')
- print(user)
- print('')
+ xlua.print('')
+ xlua.print('Global Libs:')
+ xlua.print(libs)
+ xlua.print('')
+ xlua.print('Global Vars:')
+ xlua.print(user)
+ xlua.print('')
end
----------------------------------------------------------------------
-- time
----------------------------------------------------------------------
-function formatTime(seconds)
+function xlua.formatTime(seconds)
-- decompose:
- local floor = glob.math.floor
+ local floor = math.floor
local days = floor(seconds / 3600/24)
seconds = seconds - days*3600*24
local hours = floor(seconds / 3600)
@@ -224,14 +213,15 @@ function formatTime(seconds)
-- return formatted time
return f
end
+local formatTime = xlua.formatTime
----------------------------------------------------------------------
-- progress bar
----------------------------------------------------------------------
do
local function getTermLength()
- local tputf = glob.io.popen('tput cols', 'r')
- local w = glob.tonumber(tputf:read('*a'))
+ local tputf = io.popen('tput cols', 'r')
+ local w = tonumber(tputf:read('*a'))
local rc = {tputf:close()}
if rc[3] == 0 then return w
else return 80 end
@@ -244,14 +234,14 @@ do
local times
local indices
local termLength = getTermLength()
- function progress(current, goal)
+ function xlua.progress(current, goal)
-- defaults:
local barLength = termLength - 34
local smoothing = 100
local maxfps = 10
-- Compute percentage
- local percent = glob.math.floor(((current) * barLength) / goal)
+ local percent = math.floor(((current) * barLength) / goal)
-- start new bar
if (barDone and ((previous == -1) or (percent < previous))) then
@@ -262,27 +252,27 @@ do
times = {timer:time().real}
indices = {current}
else
- glob.io.write('\r')
+ io.write('\r')
end
--if (percent ~= previous and not barDone) then
if (not barDone) then
previous = percent
-- print bar
- glob.io.write(' [')
+ io.write(' [')
for i=1,barLength do
- if (i < percent) then glob.io.write('=')
- elseif (i == percent) then glob.io.write('>')
- else glob.io.write('.') end
+ if (i < percent) then io.write('=')
+ elseif (i == percent) then io.write('>')
+ else io.write('.') end
end
- glob.io.write('] ')
- for i=1,termLength-barLength-4 do glob.io.write(' ') end
- for i=1,termLength-barLength-4 do glob.io.write('\b') end
+ io.write('] ')
+ for i=1,termLength-barLength-4 do io.write(' ') end
+ for i=1,termLength-barLength-4 do io.write('\b') end
-- time stats
local elapsed = timer:time().real
local step = (elapsed-times[1]) / (current-indices[1])
if current==indices[1] then step = 0 end
- local remaining = glob.math.max(0,(goal - current)*step)
+ local remaining = math.max(0,(goal - current)*step)
table.insert(indices, current)
table.insert(times, elapsed)
if #indices > smoothing then
@@ -290,18 +280,18 @@ do
times = table.splice(times)
end
tm = 'ETA: ' .. formatTime(remaining) .. ' | Step: ' .. formatTime(step)
- glob.io.write(tm)
+ io.write(tm)
-- go back to center of bar, and print progress
- for i=1,6+#tm+barLength/2 do glob.io.write('\b') end
- glob.io.write(' ', current, '/', goal, ' ')
+ for i=1,6+#tm+barLength/2 do io.write('\b') end
+ io.write(' ', current, '/', goal, ' ')
-- reset for next bar
if (percent == barLength) then
barDone = true
- glob.io.write('\n')
+ io.write('\n')
end
-- flush
- glob.io.write('\r')
- glob.io.flush()
+ io.write('\r')
+ io.flush()
end
end
end
@@ -310,36 +300,36 @@ end
-- prints an error with nice formatting. If domain is provided, it is used as
-- following: <domain> msg
--------------------------------------------------------------------------------
-function error(message, domain, usage)
- local c = glob.sys.COLORS
+function xlua.error(message, domain, usage)
+ local c = sys.COLORS
if domain then
message = '<' .. domain .. '> ' .. message
end
- local col_msg = c.Red .. glob.tostring(message) .. c.none
+ local col_msg = c.Red .. tostring(message) .. c.none
if usage then
col_msg = col_msg .. '\n' .. usage
end
- glob.error(col_msg)
+ error(col_msg)
end
-glob.rawset(glob, 'xerror', error)
+rawset(_G, 'xerror', xlua.error)
--------------------------------------------------------------------------------
-- provides standard try/catch functions
--------------------------------------------------------------------------------
-function trycatch(try,catch)
- local ok,err = glob.pcall(func)
+function xlua.trycatch(try,catch)
+ local ok,err = pcall(func)
if not ok then catch(err) end
end
--------------------------------------------------------------------------------
-- returns true if package is installed, rather than crashing stupidly :-)
--------------------------------------------------------------------------------
-function installed(package)
+function xlua.installed(package)
local found = false
- local p = glob.package.path .. ';' .. glob.package.cpath
+ local p = package.path .. ';' .. package.cpath
for path in p:gfind('.-;') do
path = path:gsub(';',''):gsub('?',package)
- if glob.sys.filep(path) then
+ if sys.filep(path) then
found = true
p = path
break
@@ -356,10 +346,10 @@ end
-- @param luarocks if true, then try to install missing package with luarocks
-- @param server specify a luarocks server
--------------------------------------------------------------------------------
-function require(package,luarocks,server)
+function xlua.require(package,luarocks,server)
local loaded
- local load = function() loaded = glob.require(package) end
- local ok,err = glob.pcall(load)
+ local load = function() loaded = require(package) end
+ local ok,err = pcall(load)
if not ok then
print(err)
print('warning: <' .. package .. '> could not be loaded (is it installed?)')
@@ -367,7 +357,7 @@ function require(package,luarocks,server)
end
return loaded
end
-glob.rawset(glob, 'xrequire', require)
+rawset(_G, 'xrequire', xlua.require)
--------------------------------------------------------------------------------
-- standard usage function: used to display automated help for functions
@@ -377,8 +367,8 @@ glob.rawset(glob, 'xrequire', require)
-- @param example usage example
-- @param ... [optional] arguments
--------------------------------------------------------------------------------
-function usage(funcname, description, example, ...)
- local c = glob.sys.COLORS
+function xlua.usage(funcname, description, example, ...)
+ local c = sys.COLORS
local style = {
banner = '+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++',
@@ -421,7 +411,7 @@ function usage(funcname, description, example, ...)
end
str = str .. key .. '-- ' .. param.help
if param.default or param.default == false then
- str = str .. ' [default = ' .. glob.tostring(param.default) .. ']'
+ str = str .. ' [default = ' .. tostring(param.default) .. ']'
elseif param.defaulta then
str = str .. ' [default == ' .. param.defaulta .. ']'
end
@@ -469,7 +459,7 @@ end
-- standard argument function: used to handle named arguments, and
-- display automated help for functions
--------------------------------------------------------------------------------
-function unpack(args, funcname, description, ...)
+function xlua.unpack(args, funcname, description, ...)
-- put args in table
local defs = {...}
@@ -482,23 +472,23 @@ function unpack(args, funcname, description, ...)
.. defs[1].arg .. '=' .. defs[1].type .. '}\n'
example = example .. funcname .. '(' .. defs[1].type .. ',' .. ' ...)'
end
- return usage(funcname, description, example, {tabled=defs})
+ return xlua.usage(funcname, description, example, {tabled=defs})
end
local usage = {}
- glob.setmetatable(usage, {__tostring=fusage})
+ setmetatable(usage, {__tostring=fusage})
-- get args
local iargs = {}
if #args == 0 then
print(usage)
- error('error')
- elseif #args == 1 and glob.type(args[1]) == 'table' and #args[1] == 0
- and not (glob.torch and glob.torch.typename(args[1]) ~= nil) then
+ xlua.error('error')
+ elseif #args == 1 and type(args[1]) == 'table' and #args[1] == 0
+ and not (torch and torch.typename(args[1]) ~= nil) then
-- named args
iargs = args[1]
else
-- ordered args
- for i = 1,glob.select('#',...) do
+ for i = 1,select('#',...) do
iargs[defs[i].arg] = args[i]
end
end
@@ -509,10 +499,10 @@ function unpack(args, funcname, description, ...)
local def = defs[i]
-- is value requested ?
if def.req and iargs[def.arg] == nil then
- local c = glob.sys.COLORS
+ local c = sys.COLORS
print(c.Red .. 'missing argument: ' .. def.arg .. c.none)
print(usage)
- error('error')
+ xlua.error('error')
end
-- get value or default
dargs[def.arg] = iargs[def.arg]
@@ -530,7 +520,7 @@ function unpack(args, funcname, description, ...)
-- stupid lua bug: we return all args by hand
if dargs[65] then
- error('<xlua.unpack> oups, cant deal with more than 64 arguments :-)')
+ xlua.error('<xlua.unpack> oups, cant deal with more than 64 arguments :-)')
end
-- return modified args
@@ -550,10 +540,10 @@ end
-- display automated help for functions
-- auto inits the self with usage
--------------------------------------------------------------------------------
-function unpack_class(object, args, funcname, description, ...)
- local dargs = unpack(args, funcname, description, ...)
+function xlua.unpack_class(object, args, funcname, description, ...)
+ local dargs = xlua.unpack(args, funcname, description, ...)
for k,v in pairs(dargs) do
- if glob.type(k) ~= 'number' then
+ if type(k) ~= 'number' then
object[k] = v
end
end
@@ -566,8 +556,8 @@ end
-- @param name module name
-- @param description description of the module
--------------------------------------------------------------------------------
-function usage_module(module, name, description)
- local c = glob.sys.COLORS
+function xlua.usage_module(module, name, description)
+ local c = sys.COLORS
local hasglobals = false
local str = c.magenta
local str = str .. '+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n'
@@ -578,8 +568,8 @@ function usage_module(module, name, description)
str = str .. '++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++'
str = str .. c.none
-- register help
- local mt = glob.getmetatable(module) or {}
- glob.setmetatable(module,mt)
+ local mt = getmetatable(module) or {}
+ setmetatable(module,mt)
mt.__tostring = function() return str end
return str
end
@@ -663,10 +653,12 @@ end
function string.tosymbol(str)
local ok,result = pcall(loadstring('return ' .. str))
if not ok then
- glob.error(result)
+ error(result)
elseif not result then
- glob.error('symbol "' .. str .. '" does not exist')
+ error('symbol "' .. str .. '" does not exist')
else
return result
end
end
+
+return xlua \ No newline at end of file