Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/trepl.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2014-07-30 04:22:57 +0400
committerClement Farabet <clement.farabet@gmail.com>2014-07-30 04:22:57 +0400
commit782fb01a49344ae765098f21e07010876d080e9e (patch)
treed88ffddfbcc68eabb6875da3a7a69063e80e4e9b /init.lua
Synced trepl
Diffstat (limited to 'init.lua')
-rw-r--r--init.lua769
1 files changed, 769 insertions, 0 deletions
diff --git a/init.lua b/init.lua
new file mode 100644
index 0000000..55d6ef2
--- /dev/null
+++ b/init.lua
@@ -0,0 +1,769 @@
+--[============================================================================[
+ REPL: A REPL for Lua (with support for Torch objects).
+
+ This REPL is embeddable, and doesn't depend on C libraries.
+ It's usable with Torch, and with MOAI.
+
+ For full completion support, and history, install lua-linenoise:
+ $ luarocks install linenoise
+
+ Support for SHELL commands:
+ > $ ls
+ > $ ll
+ > $ ls -l
+ (prepend any command by $, from the Lua interpreter)
+
+ Copyright: MIT / BSD / Do whatever you want with it.
+ Clement Farabet, 2013
+--]============================================================================]
+
+-- Require Torch
+pcall(require,'torch')
+pcall(require,'paths')
+
+-- Colors:
+local colors = require 'trepl.colors'
+local col = require 'trepl.colorize'
+
+-- Help string:
+local selfhelp = [[
+ ______ __
+ /_ __/__ ________/ /
+ / / / _ \/ __/ __/ _ \
+ /_/ \___/_/ \__/_//_/
+
+]]..col.red('th')..[[ is an enhanced interpreter (repl) for Torch7/LuaJIT.
+
+]]..col.blue('Features:')..[[
+
+ Tab-completion on nested namespaces
+ Tab-completion on disk files (when opening a string)
+ History
+ Pretty print (table introspection and coloring)
+ Auto-print after eval (can be stopped with ;)
+ Each command is profiled, timing is reported
+ No need for '=' to print
+ Easy help on functions/packages:
+ ]]..col.magenta("? torch.randn")..[[
+ Shell commands with:
+ ]]..col.magenta("$ ls -l")..[[
+ Print all user globals with:
+ ]]..col.magenta("who()")..[[
+ Import a package's symbols globally with:
+ ]]..col.magenta("import 'torch' ")..[[
+ Require is overloaded to provide relative (form within a file) search paths:
+ ]]..col.magenta("require './local/lib' ")..[[
+ Optional strict global namespace monitoring:
+ ]]..col.magenta('th -g')..[[
+ Optional async repl (based on https://github.com/clementfarabet/async):
+ ]]..col.magenta('th -a')..[[
+]]
+
+-- If no Torch:
+if not torch then
+ torch = {
+ typename = function() return '' end
+ }
+end
+
+-- helper
+local function sizestr(x)
+ local strt = {}
+ if _G.torch.typename(x):find('torch.*Storage') then
+ return _G.torch.typename(x):match('torch%.(.+)') .. ' - size: ' .. x:size()
+ end
+ if x:nDimension() == 0 then
+ table.insert(strt, _G.torch.typename(x):match('torch%.(.+)') .. ' - empty')
+ else
+ table.insert(strt, _G.torch.typename(x):match('torch%.(.+)') .. ' - size: ')
+ for i=1,x:nDimension() do
+ table.insert(strt, x:size(i))
+ if i ~= x:nDimension() then
+ table.insert(strt, 'x')
+ end
+ end
+ end
+ return table.concat(strt)
+end
+
+-- k : name of variable
+-- m : max length
+local function printvar(key,val,m)
+ local name = '[' .. tostring(key) .. ']'
+ --io.write(name)
+ name = name .. string.rep(' ',m-name:len()+2)
+ local tp = type(val)
+ if tp == 'userdata' then
+ tp = torch.typename(val) or ''
+ if tp:find('torch.*Tensor') then
+ tp = sizestr(val)
+ elseif tp:find('torch.*Storage') then
+ tp = sizestr(val)
+ else
+ tp = tostring(val)
+ end
+ elseif tp == 'table' then
+ tp = tp .. ' - size: ' .. #val
+ elseif tp == 'string' then
+ local tostr = val:gsub('\n','\\n')
+ if #tostr>40 then
+ tostr = tostr:sub(1,40) .. '...'
+ end
+ tp = tp .. ' : "' .. tostr .. '"'
+ else
+ tp = tostring(val)
+ end
+ return name .. ' = ' .. tp
+end
+
+-- helper
+local function getmaxlen(vars)
+ local m = 0
+ if type(vars) ~= 'table' then return tostring(vars):len() end
+ for k,v in pairs(vars) do
+ local s = tostring(k)
+ if s:len() > m then
+ m = s:len()
+ end
+ end
+ return m
+end
+
+-- overload print:
+if not print_old then
+ print_old=print
+end
+
+-- a function to colorize output:
+local function colorize(object,nested)
+ -- Apply:
+ local apply = col
+
+ -- Type?
+ if object == nil then
+ return apply['Black']('nil')
+ elseif type(object) == 'number' then
+ return apply['cyan'](tostring(object))
+ elseif type(object) == 'boolean' then
+ return apply['blue'](tostring(object))
+ elseif type(object) == 'string' then
+ if nested then
+ return apply['Black']('"')..apply['green'](object)..apply['Black']('"')
+ else
+ return apply['none'](object)
+ end
+ elseif type(object) == 'function' then
+ return apply['magenta'](tostring(object))
+ elseif type(object) == 'userdata' or type(object) == 'cdata' then
+ local tp = torch.typename(object) or ''
+ if tp:find('torch.*Tensor') then
+ tp = sizestr(object)
+ elseif tp:find('torch.*Storage') then
+ tp = sizestr(object)
+ else
+ tp = tostring(object)
+ end
+ if tp ~= '' then
+ return apply['red'](tp)
+ else
+ return apply['red'](tostring(object))
+ end
+ elseif type(object) == 'table' then
+ return apply['green'](tostring(object))
+ else
+ return apply['_black'](tostring(object))
+ end
+end
+
+-- This is a new recursive, colored print.
+local ndepth = 4
+function print_new(...)
+ local function rawprint(o)
+ io.write(tostring(o or '') .. '\n')
+ io.flush()
+ end
+ local objs = {...}
+ local function printrecursive(obj,depth)
+ local depth = depth or 0
+ local tab = depth*4
+ local line = function(s) for i=1,tab do io.write(' ') end rawprint(s) end
+ if next(obj) then
+ line('{')
+ tab = tab+2
+ for k,v in pairs(obj) do
+ if type(v) == 'table' then
+ if depth >= (ndepth-1) or next(v) == nil then
+ line(tostring(k) .. ' : {}')
+ else
+ line(tostring(k) .. ' : ') printrecursive(v,depth+1)
+ end
+ else
+ line(tostring(k) .. ' : ' .. colorize(v,true))
+ end
+ end
+ tab = tab-2
+ line('}')
+ else
+ line('{}')
+ end
+ end
+ for i = 1,select('#',...) do
+ local obj = select(i,...)
+ if type(obj) ~= 'table' then
+ if type(obj) == 'userdata' or type(obj) == 'cdata' then
+ rawprint(obj)
+ else
+ io.write(colorize(obj) .. '\t')
+ if i == select('#',...) then
+ rawprint()
+ end
+ end
+ elseif getmetatable(obj) and getmetatable(obj).__tostring then
+ rawprint(obj)
+ else
+ printrecursive(obj)
+ end
+ end
+end
+
+
+function setprintlevel(n)
+ if n == nil or n < 0 then
+ error('expected number [0,+)')
+ end
+ n = math.floor(n)
+ ndepth = n
+ if ndepth == 0 then
+ print = print_old
+ else
+ print = print_new
+ end
+end
+setprintlevel(5)
+
+-- Import, ala Python
+function import(package, forced)
+ local ret = require(package)
+ local symbols = {}
+ if _G[package] then
+ _G._torchimport = _G._torchimport or {}
+ _G._torchimport[package] = _G[package]
+ symbols = _G[package]
+ elseif ret and type(ret) == 'table' then
+ _G._torchimport = _G._torchimport or {}
+ _G._torchimport[package] = ret
+ symbols = ret
+ end
+ for k,v in pairs(symbols) do
+ if not _G[k] or forced then
+ _G[k] = v
+ end
+ end
+end
+
+-- Smarter require (ala Node.js)
+local drequire = require
+function require(name)
+ if name:find('^%.') then
+ local file = debug.getinfo(2).source:gsub('^@','')
+ local dir = '.'
+ if path.exists(file) then
+ dir = path.dirname(file)
+ end
+ local pkgpath = path.join(dir,name)
+ if path.isfile(pkgpath..'.lua') then
+ return dofile(pkgpath..'.lua')
+ elseif path.isfile(pkgpath) then
+ return dofile(pkgpath)
+ elseif path.isfile(pkgpath..'.so') then
+ return package.loadlib(pkgpath..'.so', 'luaopen_'..path.basename(name))()
+ elseif path.isfile(pkgpath..'.dylib') then
+ return package.loadlib(pkgpath..'.dylib', 'luaopen_'..path.basename(name))()
+ else
+ local initpath = path.join(pkgpath,'init.lua')
+ return dofile(initpath)
+ end
+ else
+ return drequire(name)
+ end
+end
+
+-- Who
+-- a simple function that prints all the symbols defined by the user
+-- very much like Matlab's who function
+function who(system)
+ local m = getmaxlen(_G)
+ local p = _G._preloaded_
+ local function printsymb(sys)
+ for k,v in pairs(_G) do
+ if (sys and p[k]) or (not sys and not p[k]) then
+ print(printvar(k,_G[k],m))
+ end
+ end
+ end
+ if system then
+ print('== System Variables ==')
+ printsymb(true)
+ end
+ print('== User Variables ==')
+ printsymb(false)
+ print('==')
+end
+
+-- Monitor Globals
+function monitor_G(cb)
+ -- user CB or strict mode
+ local strict
+ if type(cb) == 'boolean' then
+ strict = true
+ cb = nil
+ end
+
+ -- Force load of penlight packages:
+ stringx = require 'pl.stringx'
+ tablex = require 'pl.tablex'
+ path = require 'pl.path'
+ dir = require 'pl.dir'
+ text = require 'pl.text'
+
+ -- Store current globals:
+ local evercreated = {}
+ for k in pairs(_G) do
+ evercreated[k] = true
+ end
+
+ -- Overwrite global namespace meta tables to monitor it:
+ setmetatable(_G, {
+ __newindex = function(G,key,val)
+ if not evercreated[key] then
+ if cb then
+ cb(key)
+ else
+ local file = debug.getinfo(2).source:gsub('^@','')
+ local line = debug.getinfo(2).currentline
+ local report = print
+ if strict then
+ report = error
+ end
+ if line > 0 then
+ report(colors.red .. 'created global variable: '
+ .. colors.blue .. key .. colors.none
+ .. ' @ ' .. colors.magenta .. file .. colors.none
+ .. ':' .. colors.green .. line .. colors.none
+ )
+ else
+ report(colors.red .. 'created global variable: '
+ .. colors.blue .. key .. colors.none
+ .. ' @ ' .. colors.yellow .. '[C-module]' .. colors.none
+ )
+ end
+ end
+ end
+ evercreated[key] = true
+ rawset(G,key,val)
+ end,
+ __index = function (table, key)
+ error(colors.red .. "attempt to read undeclared variable " .. colors.blue .. key .. colors.none, 2)
+ end,
+ })
+end
+
+-- Tracekback (error printout)
+local function traceback(message)
+ local tp = type(message)
+ if tp ~= "string" and tp ~= "number" then return message end
+ local debug = _G.debug
+ if type(debug) ~= "table" then return message end
+ local tb = debug.traceback
+ if type(tb) ~= "function" then return message end
+ return tb(message)
+end
+
+-- Prompt:
+local function prompt(aux)
+ local s
+ if not aux then
+ s = 'th> '
+ else
+ s = '..> '
+ end
+ return s
+end
+
+-- Aliases:
+local aliases = [[
+ alias ls='ls -GF';
+ alias ll='ls -lhF';
+ alias la='ls -ahF';
+ alias lla='ls -lahF';
+]]
+
+-- Penlight
+pcall(require,'pl')
+
+-- Reults:
+_RESULTS = {}
+_LAST = ''
+
+-- Readline:
+local readline_ok,readline = pcall(require,"trepl.readline")
+if not readline_ok then
+ print(col.red('WARNING: ') .. 'could not find/load readline, defaulting to linenoise')
+end
+
+-- REPL:
+function repl_readline()
+ -- Completer:
+ local completer = require 'trepl.completer'
+ completer.final_char_setter = readline.completion_append_character
+
+ local inputrc = paths.concat(os.getenv('HOME'),'.inputrc')
+ if not paths.filep(inputrc) then
+ local finputrc = io.open(inputrc,'w')
+ local trepl =
+[[
+$if lua
+ # filter up and down arrows using characters typed so far
+ "\e[A":history-search-backward
+ "\e[B":history-search-forward
+$endif
+]]
+ finputrc:write(trepl)
+ finputrc:close()
+ end
+
+ -- Timer
+ local timer_start, timer_stop
+ if torch and torch.Timer then
+ local t = torch.Timer()
+ local start = 0
+ timer_start = function()
+ start = t:time().real
+ end
+ timer_stop = function()
+ local step = t:time().real - start
+ for i = 1,70 do io.write(' ') end
+ print(col.Black(string.format('[%0.04fs]', step)))
+ end
+ else
+ timer_start = function() end
+ timer_stop = function() end
+ end
+
+ -- History:
+ local history = os.getenv('HOME') .. '/.luahistory'
+
+ -- Readline callback:
+ readline.shell{
+ -- History:
+ history = history,
+
+ -- Completer:
+ complete = completer.complete,
+
+ -- Chars:
+ word_break_characters = " \t\n\"\\'><=;:+-*/%^~#{}()[].,",
+
+ -- Get command:
+ getcommand = function()
+ -- get the first line
+ local line = coroutine.yield(prompt())
+ local cmd = line .. '\n'
+
+ -- = (lua supports that)
+ if cmd:sub(1,1) == "=" then
+ cmd = "return "..cmd:sub(2)
+ end
+
+ -- Interupt?
+ if line == 'exit' then
+ io.stdout:write('Do you really want to exit ([y]/n)? ') io.flush()
+ local line = io.read('*l')
+ if line == '' or line:lower() == 'y' then
+ os.exit()
+ end
+ end
+
+ -- OS Commands:
+ if line and line:find('^%s-%$') then
+ local cline = line:gsub('^%s-%$','')
+ if io.popen then
+ local f = io.popen(aliases .. ' ' .. cline)
+ local res = f:read('*a')
+ f:close()
+ io.write(col.none(res)) io.flush()
+ table.insert(_RESULTS, res)
+ _LAST = _RESULTS[#_RESULTS]
+ else
+ os.execute(aliases .. ' ' .. cline)
+ end
+ timer_stop()
+ return line
+ end
+
+ -- Shortcut to get help:
+ if line and line:find('^%s-?') then
+ local ok = pcall(require,'dok')
+ if ok then
+ local pkg = line:gsub('^%s-?','')
+ if pkg:gsub('%s*','') == '' then
+ print(selfhelp)
+ return
+ else
+ line = 'help(' .. pkg .. ')'
+ end
+ else
+ print('error: could not load help backend')
+ return line
+ end
+ end
+
+ -- try to return first:
+ timer_start()
+ local pok,ok,err
+ if line:find(';%s-$') or line:find('^%s-print') then
+ ok = false
+ elseif line:match('^%s*$') then
+ return nil
+ else
+ local func, perr = loadstring('local f = function() return '..line..' end local res = {f()} print(unpack(res)) table.insert(_RESULTS,res[1])')
+ if func then
+ pok = true
+ ok,err = xpcall(func, traceback)
+ end
+ end
+
+ -- run ok:
+ if ok then
+ _LAST = _RESULTS[#_RESULTS]
+ timer_stop()
+ return line
+ end
+
+ -- parsed ok, but failed to run (code error):
+ if pok then
+ print(err)
+ return cmd:sub(1, -2)
+ end
+
+ -- continue to get lines until get a complete chunk
+ local func, err
+ while true do
+ -- if not go ahead:
+ func, err = loadstring(cmd)
+ if func or err:sub(-7) ~= "'<eof>'" then break end
+
+ -- concat:
+ cmd = cmd .. coroutine.yield(prompt(true)) .. '\n'
+ end
+
+ -- exec chunk:
+ if not cmd:match("^%s*$") then
+ local ff,err=loadstring(cmd)
+ if not ff then
+ print(err)
+ return cmd:sub(1, -2)
+ end
+ local res = {xpcall(ff, traceback)}
+ local ok,err = res[1], res[2]
+ if not ok then
+ print(err)
+ else
+ if err ~= nil then
+ table.remove(res,1)
+ print(unpack(res))
+ end
+ end
+ timer_stop()
+ return cmd:sub(1, -2) -- remove last \n for history
+ end
+ end,
+ }
+ io.stderr:write"\n"
+end
+
+-- No readline -> LineNoise?
+local nextline
+if not readline_ok then
+ print('ici')
+ -- Load linenoise:
+ local ok,L = pcall(require,'linenoise')
+ ok = false
+ if not ok then
+ -- No readline, no linenoise... default to plain io:
+ nextline = function()
+ io.write(prompt()) io.flush()
+ return io.read('*line')
+ end
+
+ -- Really poor:
+ print(col.red('WARNING: ') .. 'could not find/load linenoise, defaulting to raw repl')
+ else
+ -- History:
+ local history = os.getenv('HOME') .. '/.luahistory'
+ L.historyload(history)
+
+ -- Completion:
+ L.setcompletion(function(c,s)
+ -- Check if we're in a string
+ local ignore,str = s:gfind('(.-)"([a-zA-Z%._]*)$')()
+ local quote = '"'
+ if not str then
+ ignore,str = s:gfind('(.-)\'([a-zA-Z%._]*)$')()
+ quote = "'"
+ end
+
+ -- String?
+ if str then
+ -- Complete from disk:
+ local f = io.popen('ls ' .. str..'* 2> /dev/null')
+ local res = f:read('*all')
+ f:close()
+ res = res:gsub('(%s*)$','')
+ local elts = stringx.split(res,'\n')
+ for _,elt in ipairs(elts) do
+ L.addcompletion(c,ignore .. quote .. elt)
+ end
+ return
+ end
+
+ -- Get symbol of interest
+ local ignore,str = s:gfind('(.-)([a-zA-Z%._]*)$')()
+
+ -- Lookup globals:
+ if not str:find('%.') then
+ for k,v in pairs(_G) do
+ if k:find('^'..str) then
+ L.addcompletion(c,ignore .. k)
+ end
+ end
+ end
+
+ -- Lookup packages:
+ local base,sub = str:gfind('(.*)%.(.*)')()
+ if base then
+ local ok,res = pcall(loadstring('return ' .. base))
+ for k,v in pairs(res) do
+ if k:find('^'..sub) then
+ L.addcompletion(c,ignore .. base .. '.' .. k)
+ end
+ end
+ end
+ end)
+
+ -- read line:
+ nextline = function()
+ -- Get line:
+ local line = L.linenoise(prompt())
+
+ -- Save:
+ if line and not line:find('^%s-$') then
+ L.historyadd(line)
+ L.historysave(history)
+ end
+
+ -- Return line:
+ return line
+ end
+ end
+end
+
+-- The default repl
+function repl_linenoise()
+ -- Timer
+ local timer_start, timer_stop
+ if torch and torch.Timer then
+ local t = torch.Timer()
+ local start = 0
+ timer_start = function()
+ start = t:time().real
+ end
+ timer_stop = function()
+ local step = t:time().real - start
+ for i = 1,70 do io.write(' ') end
+ print(col.Black(string.format('[%0.04fs]', step)))
+ end
+ else
+ timer_start = function() end
+ timer_stop = function() end
+ end
+
+ -- REPL:
+ while true do
+ -- READ:
+ local line = nextline()
+
+ -- Interupt?
+ if not line or line == 'exit' then
+ io.write('Do you really want to exit ([y]/n)? ') io.flush()
+ local line = io.read('*l')
+ if line == '' or line:lower() == 'y' then
+ os.exit()
+ end
+ end
+ if line == 'break' then
+ break
+ end
+
+ -- OS Commands:
+ if line and line:find('^%s-%$') then
+ line = line:gsub('^%s-%$','')
+ if io.popen then
+ local f = io.popen(aliases .. ' ' .. line)
+ local res = f:read('*a')
+ f:close()
+ io.write(col._black(res)) io.flush()
+ table.insert(_RESULTS, res)
+ else
+ os.execute(aliases .. ' ' .. line)
+ end
+ line = nil
+ end
+
+ -- Support the crappy '=', as Lua does:
+ if line and line:find('^%s-=') then
+ line = line:gsub('^%s-=','')
+ end
+
+ -- Shortcut to get help:
+ if line and line:find('^%s-?') then
+ local ok = pcall(require,'dok')
+ if ok then
+ line = 'help(' .. line:gsub('^%s-?','') .. ')'
+ else
+ print('error: could not load help backend')
+ line = nil
+ end
+ end
+
+ -- EVAL:
+ if line then
+ timer_start()
+ local ok,err
+ if line:find(';%s-$') or line:find('^%s-print') then
+ ok = false
+ else
+ ok,err = xpcall(loadstring('local f = function() return '..line..' end local res = {f()} print(unpack(res)) table.insert(_RESULTS,res[1])'), traceback)
+ end
+ if not ok then
+ local ok,err = xpcall(loadstring(line), traceback)
+ if not ok then
+ print(err)
+ end
+ end
+ timer_stop()
+ end
+
+ -- Last result:
+ _LAST = _RESULTS[#_RESULTS]
+ end
+end
+
+-- Store preloaded symbols, for who()
+_G._preloaded_ = {}
+for k,v in pairs(_G) do
+ _G._preloaded_[k] = true
+end
+
+-- return repl, just call it to start it!
+return (readline_ok and repl_readline) or repl_linenoise