diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-07-12 00:47:09 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-07-12 00:47:09 +0400 |
commit | 6d4bd46a0473b229492ac9602ab72e010d0ac41b (patch) | |
tree | f360e47e0ba7aa23e365bb4e19417c589f731904 | |
parent | 73133afa2bc2a55256c4a93b113db74b3373410d (diff) |
Better plotting in Logger.
-rw-r--r-- | Logger.lua | 33 |
1 files changed, 28 insertions, 5 deletions
@@ -3,16 +3,19 @@ local Logger = torch.class('nn.Logger') function Logger:__init(filename) if filename then + self.name = filename os.execute('mkdir -p ' .. sys.dirname(filename)) filename = sys.concat(filename .. '-'..os.date("%Y_%m_%d_%X")) self.file = io.open(filename,'w') else self.file = io.stdout + self.name = 'stdout' print('<Logger> warning: no path provided, logging to std out') end self.empty = true self.symbols = {} - self.figures = {} + self.styles = {} + self.figure = nil end function Logger:add(symbols) @@ -23,6 +26,7 @@ function Logger:add(symbols) for k,val in pairs(symbols) do self.file:write(k .. '\t') self.symbols[k] = {} + self.styles[k] = {'+'} end self.file:write('\n') end @@ -44,6 +48,18 @@ function Logger:add(symbols) end end +function Logger:style(symbols) + for name,style in pairs(symbols) do + if type(style) == 'string' then + self.styles[name] = {style} + elseif type(style) == 'table' then + self.styles[name] = style + else + xlua.error('style should be a string or a table of strings','Logger') + end + end +end + function Logger:plot(...) if not lab.plot then if not self.warned then @@ -51,18 +67,20 @@ function Logger:plot(...) end return end + local plot = false + local plots = {} local plotsymbol = function(name,list) if #list > 1 then local nelts = #list - local plot_x = lab.range(1,nelts) local plot_y = torch.Tensor(nelts) for i = 1,nelts do plot_y[i] = list[i] end - self.figures[name] = lab.figure(self.figures[name]) - lab.plot(name, plot_x, plot_y, '-') - lab.title(name) + for _,style in ipairs(self.styles[name]) do + table.insert(plots, {name, plot_y, style}) + end + plot = true end end local args = {...} @@ -75,4 +93,9 @@ function Logger:plot(...) plotsymbol(name,self.symbols[name]) end end + if plot then + self.figure = lab.figure(self.figure) + lab.plot(plots) + lab.title('<Logger::' .. self.name .. '>') + end end |