diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2012-04-02 23:18:23 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2012-04-02 23:18:23 +0400 |
commit | 42cb83a21e09903fd7f3a8e1d754cedfbf5a7faf (patch) | |
tree | 5187b6e8143afe4d8a053e1b85bd14f405013a5f | |
parent | 572359a16248a7c2f931847911c0269b02634b77 (diff) |
Moved some stable classes from nnx to optim.
They make more sense here.
-rw-r--r-- | ConfusionMatrix.lua | 115 | ||||
-rw-r--r-- | Logger.lua | 130 | ||||
-rw-r--r-- | init.lua | 5 |
3 files changed, 250 insertions, 0 deletions
diff --git a/ConfusionMatrix.lua b/ConfusionMatrix.lua new file mode 100644 index 0000000..a4d44e5 --- /dev/null +++ b/ConfusionMatrix.lua @@ -0,0 +1,115 @@ +---------------------------------------------------------------------- +-- A Confusion Matrix class +-- +-- Example: +-- conf = optim.ConfusionMatrix( {'cat','dog','person'} ) -- new matrix +-- conf:zero() -- reset matrix +-- for i = 1,N do +-- conf:add( neuralnet:forward(sample), label ) -- accumulate errors +-- end +-- print(conf) -- print matrix +-- +local ConfusionMatrix = torch.class('optim.ConfusionMatrix') + +function ConfusionMatrix:__init(nclasses, classes) + if type(nclasses) == 'table' then + classes = nclasses + nclasses = #classes + end + self.mat = torch.FloatTensor(nclasses,nclasses):zero() + self.valids = torch.FloatTensor(nclasses):zero() + self.unionvalids = torch.FloatTensor(nclasses):zero() + self.nclasses = nclasses + self.totalValid = 0 + self.averageValid = 0 + self.classes = classes or {} +end + +function ConfusionMatrix:add(prediction, target) + if type(prediction) == 'number' then + -- comparing numbers + self.mat[target][prediction] = self.mat[target][prediction] + 1 + elseif type(target) == 'number' then + -- prediction is a vector, then target assumed to be an index + local prediction_1d = torch.FloatTensor(self.nclasses):copy(prediction) + local _,prediction = prediction_1d:max(1) + self.mat[target][prediction[1]] = self.mat[target][prediction[1]] + 1 + else + -- both prediction and target are vectors + local prediction_1d = torch.FloatTensor(self.nclasses):copy(prediction) + local target_1d = torch.FloatTensor(self.nclasses):copy(target) + local _,prediction = prediction_1d:max(1) + local _,target = target_1d:max(1) + self.mat[target[1]][prediction[1]] = self.mat[target[1]][prediction[1]] + 1 + end +end + +function ConfusionMatrix:zero() + self.mat:zero() + self.valids:zero() + self.unionvalids:zero() + self.totalValid = 0 + self.averageValid = 0 +end + +function ConfusionMatrix:updateValids() + local total = 0 + for t = 1,self.nclasses do + self.valids[t] = self.mat[t][t] / self.mat:select(1,t):sum() + self.unionvalids[t] = self.mat[t][t] / (self.mat:select(1,t):sum()+self.mat:select(2,t):sum()-self.mat[t][t]) + total = total + self.mat[t][t] + end + self.totalValid = total / self.mat:sum() + self.averageValid = 0 + self.averageUnionValid = 0 + local nvalids = 0 + local nunionvalids = 0 + for t = 1,self.nclasses do + if not sys.isNaN(self.valids[t]) then + self.averageValid = self.averageValid + self.valids[t] + nvalids = nvalids + 1 + end + if not sys.isNaN(self.valids[t]) and not sys.isNaN(self.unionvalids[t]) then + self.averageUnionValid = self.averageUnionValid + self.unionvalids[t] + nunionvalids = nunionvalids + 1 + end + end + self.averageValid = self.averageValid / nvalids + self.averageUnionValid = self.averageUnionValid / nunionvalids +end + +function ConfusionMatrix:__tostring__() + self:updateValids() + local str = 'ConfusionMatrix:\n' + local nclasses = self.nclasses + str = str .. '[' + for t = 1,nclasses do + local pclass = self.valids[t] * 100 + pclass = string.format('%2.3f', pclass) + if t == 1 then + str = str .. '[' + else + str = str .. ' [' + end + for p = 1,nclasses do + str = str .. '' .. string.format('%8d', self.mat[t][p]) + end + if self.classes and self.classes[1] then + if t == nclasses then + str = str .. ']] ' .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n' + else + str = str .. '] ' .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n' + end + else + if t == nclasses then + str = str .. ']] ' .. pclass .. '% \n' + else + str = str .. '] ' .. pclass .. '% \n' + end + end + end + str = str .. ' + average row correct: ' .. (self.averageValid*100) .. '% \n' + str = str .. ' + average rowUcol correct (VOC measure): ' .. (self.averageUnionValid*100) .. '% \n' + str = str .. ' + global correct: ' .. (self.totalValid*100) .. '%' + return str +end diff --git a/Logger.lua b/Logger.lua new file mode 100644 index 0000000..1e065bf --- /dev/null +++ b/Logger.lua @@ -0,0 +1,130 @@ +---------------------------------------------------------------------- +-- Logger: a simple class to log symbols during training, +-- and automate plot generation +-- +-- Example: +-- logger = optim.Logger('somefile.log') -- file to save stuff +-- +-- for i = 1,N do -- log some symbols during +-- train_error = ... -- training/testing +-- test_error = ... +-- logger:add{['training error'] = train_error, +-- ['test error'] = test_error} +-- end +-- +-- logger:style{['training error'] = '-', -- define styles for plots +-- ['test error'] = '-'} +-- logger:plot() -- and plot +-- +local Logger = torch.class('optim.Logger') + +function Logger:__init(filename, timestamp) + if filename then + self.name = filename + os.execute('mkdir -p "' .. sys.dirname(filename) .. '"') + if timestamp then + -- append timestamp to create unique log file + filename = filename .. '-'..os.date("%Y_%m_%d_%X") + end + self.file = io.open(filename,'w') + self.epsfile = self.name .. '.eps' + else + self.file = io.stdout + self.name = 'stdout' + print('<Logger> warning: no path provided, logging to std out') + end + self.empty = true + self.symbols = {} + self.styles = {} + self.figure = nil +end + +function Logger:add(symbols) + -- (1) first time ? print symbols' names on first row + if self.empty then + self.empty = false + self.nsymbols = #symbols + for k,val in pairs(symbols) do + self.file:write(k .. '\t') + self.symbols[k] = {} + self.styles[k] = {'+'} + end + self.file:write('\n') + end + -- (2) print all symbols on one row + for k,val in pairs(symbols) do + if type(val) == 'number' then + self.file:write(string.format('%11.4e',val) .. '\t') + elseif type(val) == 'string' then + self.file:write(val .. '\t') + else + xlua.error('can only log numbers and strings', 'Logger') + end + end + self.file:write('\n') + self.file:flush() + -- (3) save symbols in internal table + for k,val in pairs(symbols) do + table.insert(self.symbols[k], val) + end +end + +function Logger:style(symbols) + for name,style in pairs(symbols) do + if type(style) == 'string' then + self.styles[name] = {style} + elseif type(style) == 'table' then + self.styles[name] = style + else + xlua.error('style should be a string or a table of strings','Logger') + end + end +end + +function Logger:plot(...) + if not xlua.require('gnuplot') then + if not self.warned then + print('<Logger> warning: cannot plot with this version of Torch') + self.warned = true + end + return + end + local plotit = false + local plots = {} + local plotsymbol = + function(name,list) + if #list > 1 then + local nelts = #list + local plot_y = torch.Tensor(nelts) + for i = 1,nelts do + plot_y[i] = list[i] + end + for _,style in ipairs(self.styles[name]) do + table.insert(plots, {name, plot_y, style}) + end + plotit = true + end + end + local args = {...} + if not args[1] then -- plot all symbols + for name,list in pairs(self.symbols) do + plotsymbol(name,list) + end + else -- plot given symbols + for i,name in ipairs(args) do + plotsymbol(name,self.symbols[name]) + end + end + if plotit then + self.figure = gnuplot.figure(self.figure) + gnuplot.plot(plots) + gnuplot.title('<Logger::' .. self.name .. '>') + if self.epsfile then + os.execute('rm -f ' .. self.epsfile) + gnuplot.epsfigure(self.epsfile) + gnuplot.plot(plots) + gnuplot.title('<Logger::' .. self.name .. '>') + gnuplot.plotflush() + end + end +end @@ -3,8 +3,13 @@ require 'torch' optim = {} +-- optimizations torch.include('optim', 'sgd.lua') torch.include('optim', 'cg.lua') torch.include('optim', 'asgd.lua') torch.include('optim', 'fista.lua') torch.include('optim', 'lbfgs.lua') + +-- tools +torch.include('optim', 'ConfusionMatrix.lua') +torch.include('optim', 'Logger.lua') |