Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-07-12 00:47:09 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-07-12 00:47:09 +0400
commit6d4bd46a0473b229492ac9602ab72e010d0ac41b (patch)
treef360e47e0ba7aa23e365bb4e19417c589f731904
parent73133afa2bc2a55256c4a93b113db74b3373410d (diff)
Better plotting in Logger.
-rw-r--r--Logger.lua33
1 files changed, 28 insertions, 5 deletions
diff --git a/Logger.lua b/Logger.lua
index 334b42f..76ff1ab 100644
--- a/Logger.lua
+++ b/Logger.lua
@@ -3,16 +3,19 @@ local Logger = torch.class('nn.Logger')
function Logger:__init(filename)
if filename then
+ self.name = filename
os.execute('mkdir -p ' .. sys.dirname(filename))
filename = sys.concat(filename .. '-'..os.date("%Y_%m_%d_%X"))
self.file = io.open(filename,'w')
else
self.file = io.stdout
+ self.name = 'stdout'
print('<Logger> warning: no path provided, logging to std out')
end
self.empty = true
self.symbols = {}
- self.figures = {}
+ self.styles = {}
+ self.figure = nil
end
function Logger:add(symbols)
@@ -23,6 +26,7 @@ function Logger:add(symbols)
for k,val in pairs(symbols) do
self.file:write(k .. '\t')
self.symbols[k] = {}
+ self.styles[k] = {'+'}
end
self.file:write('\n')
end
@@ -44,6 +48,18 @@ function Logger:add(symbols)
end
end
+function Logger:style(symbols)
+ for name,style in pairs(symbols) do
+ if type(style) == 'string' then
+ self.styles[name] = {style}
+ elseif type(style) == 'table' then
+ self.styles[name] = style
+ else
+ xlua.error('style should be a string or a table of strings','Logger')
+ end
+ end
+end
+
function Logger:plot(...)
if not lab.plot then
if not self.warned then
@@ -51,18 +67,20 @@ function Logger:plot(...)
end
return
end
+ local plot = false
+ local plots = {}
local plotsymbol =
function(name,list)
if #list > 1 then
local nelts = #list
- local plot_x = lab.range(1,nelts)
local plot_y = torch.Tensor(nelts)
for i = 1,nelts do
plot_y[i] = list[i]
end
- self.figures[name] = lab.figure(self.figures[name])
- lab.plot(name, plot_x, plot_y, '-')
- lab.title(name)
+ for _,style in ipairs(self.styles[name]) do
+ table.insert(plots, {name, plot_y, style})
+ end
+ plot = true
end
end
local args = {...}
@@ -75,4 +93,9 @@ function Logger:plot(...)
plotsymbol(name,self.symbols[name])
end
end
+ if plot then
+ self.figure = lab.figure(self.figure)
+ lab.plot(plots)
+ lab.title('<Logger::' .. self.name .. '>')
+ end
end