github.com/torch/optim.git
author     Clement Farabet <clement.farabet@gmail.com>   2012-04-02 23:18:23 +0400
committer  Clement Farabet <clement.farabet@gmail.com>   2012-04-02 23:18:23 +0400
commit     42cb83a21e09903fd7f3a8e1d754cedfbf5a7faf (patch)
tree       5187b6e8143afe4d8a053e1b85bd14f405013a5f
parent     572359a16248a7c2f931847911c0269b02634b77 (diff)
Moved some stable classes from nnx to optim.
They make more sense here.
-rw-r--r--  ConfusionMatrix.lua  115
-rw-r--r--  Logger.lua           130
-rw-r--r--  init.lua               5
3 files changed, 250 insertions, 0 deletions
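
For context, the two classes being moved are the ones typically paired in a training loop. A rough usage sketch under optim (the network `net` and dataset `trainData` below are placeholders, not part of this commit):

require 'optim'

local classes   = {'cat','dog','person'}
local confusion = optim.ConfusionMatrix(classes)   -- per-class hit counts
local logger    = optim.Logger('train.log')        -- hypothetical log path

confusion:zero()
for i = 1,trainData:size() do
   local input, target = trainData[i][1], trainData[i][2]
   confusion:add(net:forward(input), target)       -- output vector vs. target index
end
print(confusion)                                   -- updates and prints the accuracies
logger:add{['% global correct'] = confusion.totalValid * 100}
logger:style{['% global correct'] = '-'}
logger:plot()
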
diff --git a/ConfusionMatrix.lua b/ConfusionMatrix.lua
new file mode 100644
index 0000000..a4d44e5
--- /dev/null
+++ b/ConfusionMatrix.lua
@@ -0,0 +1,115 @@
+----------------------------------------------------------------------
+-- A Confusion Matrix class
+--
+-- Example:
+-- conf = optim.ConfusionMatrix( {'cat','dog','person'} ) -- new matrix
+-- conf:zero() -- reset matrix
+-- for i = 1,N do
+-- conf:add( neuralnet:forward(sample), label ) -- accumulate errors
+-- end
+-- print(conf) -- print matrix
+--
+local ConfusionMatrix = torch.class('optim.ConfusionMatrix')
+
+function ConfusionMatrix:__init(nclasses, classes)
+ if type(nclasses) == 'table' then
+ classes = nclasses
+ nclasses = #classes
+ end
+ self.mat = torch.FloatTensor(nclasses,nclasses):zero()
+ self.valids = torch.FloatTensor(nclasses):zero()
+ self.unionvalids = torch.FloatTensor(nclasses):zero()
+ self.nclasses = nclasses
+ self.totalValid = 0
+ self.averageValid = 0
+ self.classes = classes or {}
+end
+
+function ConfusionMatrix:add(prediction, target)
+ if type(prediction) == 'number' then
+ -- comparing numbers
+ self.mat[target][prediction] = self.mat[target][prediction] + 1
+ elseif type(target) == 'number' then
+ -- prediction is a vector, then target assumed to be an index
+ local prediction_1d = torch.FloatTensor(self.nclasses):copy(prediction)
+ local _,prediction = prediction_1d:max(1)
+ self.mat[target][prediction[1]] = self.mat[target][prediction[1]] + 1
+ else
+ -- both prediction and target are vectors
+ local prediction_1d = torch.FloatTensor(self.nclasses):copy(prediction)
+ local target_1d = torch.FloatTensor(self.nclasses):copy(target)
+ local _,prediction = prediction_1d:max(1)
+ local _,target = target_1d:max(1)
+ self.mat[target[1]][prediction[1]] = self.mat[target[1]][prediction[1]] + 1
+ end
+end
+
+function ConfusionMatrix:zero()
+ self.mat:zero()
+ self.valids:zero()
+ self.unionvalids:zero()
+ self.totalValid = 0
+ self.averageValid = 0
+end
+
+function ConfusionMatrix:updateValids()
+ local total = 0
+ for t = 1,self.nclasses do
+ self.valids[t] = self.mat[t][t] / self.mat:select(1,t):sum()
+ self.unionvalids[t] = self.mat[t][t] / (self.mat:select(1,t):sum()+self.mat:select(2,t):sum()-self.mat[t][t])
+ total = total + self.mat[t][t]
+ end
+ self.totalValid = total / self.mat:sum()
+ self.averageValid = 0
+ self.averageUnionValid = 0
+ local nvalids = 0
+ local nunionvalids = 0
+ for t = 1,self.nclasses do
+ if not sys.isNaN(self.valids[t]) then
+ self.averageValid = self.averageValid + self.valids[t]
+ nvalids = nvalids + 1
+ end
+ if not sys.isNaN(self.valids[t]) and not sys.isNaN(self.unionvalids[t]) then
+ self.averageUnionValid = self.averageUnionValid + self.unionvalids[t]
+ nunionvalids = nunionvalids + 1
+ end
+ end
+ self.averageValid = self.averageValid / nvalids
+ self.averageUnionValid = self.averageUnionValid / nunionvalids
+end
+
+function ConfusionMatrix:__tostring__()
+ self:updateValids()
+ local str = 'ConfusionMatrix:\n'
+ local nclasses = self.nclasses
+ str = str .. '['
+ for t = 1,nclasses do
+ local pclass = self.valids[t] * 100
+ pclass = string.format('%2.3f', pclass)
+ if t == 1 then
+ str = str .. '['
+ else
+ str = str .. ' ['
+ end
+ for p = 1,nclasses do
+ str = str .. '' .. string.format('%8d', self.mat[t][p])
+ end
+ if self.classes and self.classes[1] then
+ if t == nclasses then
+ str = str .. ']] ' .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n'
+ else
+ str = str .. '] ' .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n'
+ end
+ else
+ if t == nclasses then
+ str = str .. ']] ' .. pclass .. '% \n'
+ else
+ str = str .. '] ' .. pclass .. '% \n'
+ end
+ end
+ end
+ str = str .. ' + average row correct: ' .. (self.averageValid*100) .. '% \n'
+ str = str .. ' + average rowUcol correct (VOC measure): ' .. (self.averageUnionValid*100) .. '% \n'
+ str = str .. ' + global correct: ' .. (self.totalValid*100) .. '%'
+ return str
+end
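
The add() method above accepts three argument shapes; a minimal sketch of each, assuming a 3-class matrix:

local conf = optim.ConfusionMatrix(3)              -- three classes, no names
conf:add(2, 2)                                     -- predicted index vs. target index
conf:add(torch.FloatTensor{0.1, 0.8, 0.1}, 2)      -- score vector vs. target index (argmax is taken)
conf:add(torch.FloatTensor{0.1, 0.8, 0.1},         -- score vector vs. one-hot/score target
         torch.FloatTensor{0, 1, 0})
print(conf)                                        -- __tostring__ calls updateValids() and formats the matrix
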
diff --git a/Logger.lua b/Logger.lua
new file mode 100644
index 0000000..1e065bf
--- /dev/null
+++ b/Logger.lua
@@ -0,0 +1,130 @@
+----------------------------------------------------------------------
+-- Logger: a simple class to log symbols during training,
+-- and automate plot generation
+--
+-- Example:
+-- logger = optim.Logger('somefile.log') -- file to save stuff
+--
+-- for i = 1,N do -- log some symbols during
+-- train_error = ... -- training/testing
+-- test_error = ...
+-- logger:add{['training error'] = train_error,
+-- ['test error'] = test_error}
+-- end
+--
+-- logger:style{['training error'] = '-', -- define styles for plots
+-- ['test error'] = '-'}
+-- logger:plot() -- and plot
+--
+local Logger = torch.class('optim.Logger')
+
+function Logger:__init(filename, timestamp)
+ if filename then
+ self.name = filename
+ os.execute('mkdir -p "' .. sys.dirname(filename) .. '"')
+ if timestamp then
+ -- append timestamp to create unique log file
+ filename = filename .. '-'..os.date("%Y_%m_%d_%X")
+ end
+ self.file = io.open(filename,'w')
+ self.epsfile = self.name .. '.eps'
+ else
+ self.file = io.stdout
+ self.name = 'stdout'
+ print('<Logger> warning: no path provided, logging to std out')
+ end
+ self.empty = true
+ self.symbols = {}
+ self.styles = {}
+ self.figure = nil
+end
+
+function Logger:add(symbols)
+ -- (1) first time ? print symbols' names on first row
+ if self.empty then
+ self.empty = false
+ self.nsymbols = #symbols
+ for k,val in pairs(symbols) do
+ self.file:write(k .. '\t')
+ self.symbols[k] = {}
+ self.styles[k] = {'+'}
+ end
+ self.file:write('\n')
+ end
+ -- (2) print all symbols on one row
+ for k,val in pairs(symbols) do
+ if type(val) == 'number' then
+ self.file:write(string.format('%11.4e',val) .. '\t')
+ elseif type(val) == 'string' then
+ self.file:write(val .. '\t')
+ else
+ xlua.error('can only log numbers and strings', 'Logger')
+ end
+ end
+ self.file:write('\n')
+ self.file:flush()
+ -- (3) save symbols in internal table
+ for k,val in pairs(symbols) do
+ table.insert(self.symbols[k], val)
+ end
+end
+
+function Logger:style(symbols)
+ for name,style in pairs(symbols) do
+ if type(style) == 'string' then
+ self.styles[name] = {style}
+ elseif type(style) == 'table' then
+ self.styles[name] = style
+ else
+ xlua.error('style should be a string or a table of strings','Logger')
+ end
+ end
+end
+
+function Logger:plot(...)
+ if not xlua.require('gnuplot') then
+ if not self.warned then
+ print('<Logger> warning: cannot plot with this version of Torch')
+ self.warned = true
+ end
+ return
+ end
+ local plotit = false
+ local plots = {}
+ local plotsymbol =
+ function(name,list)
+ if #list > 1 then
+ local nelts = #list
+ local plot_y = torch.Tensor(nelts)
+ for i = 1,nelts do
+ plot_y[i] = list[i]
+ end
+ for _,style in ipairs(self.styles[name]) do
+ table.insert(plots, {name, plot_y, style})
+ end
+ plotit = true
+ end
+ end
+ local args = {...}
+ if not args[1] then -- plot all symbols
+ for name,list in pairs(self.symbols) do
+ plotsymbol(name,list)
+ end
+ else -- plot given symbols
+ for i,name in ipairs(args) do
+ plotsymbol(name,self.symbols[name])
+ end
+ end
+ if plotit then
+ self.figure = gnuplot.figure(self.figure)
+ gnuplot.plot(plots)
+ gnuplot.title('<Logger::' .. self.name .. '>')
+ if self.epsfile then
+ os.execute('rm -f ' .. self.epsfile)
+ gnuplot.epsfigure(self.epsfile)
+ gnuplot.plot(plots)
+ gnuplot.title('<Logger::' .. self.name .. '>')
+ gnuplot.plotflush()
+ end
+ end
+end
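
A minimal sketch of the Logger API defined above (the file name is arbitrary; plot() warns once and returns if gnuplot is not available):

local logger = optim.Logger('accuracy.log')        -- header row is written on the first add()
for epoch = 1,10 do
   local train_err = 1 / epoch                     -- placeholder values
   local test_err  = 1.5 / epoch
   logger:add{['training error'] = train_err, ['test error'] = test_err}
end
logger:style{['training error'] = '-', ['test error'] = '-'}
logger:plot()                                      -- also writes accuracy.log.eps when gnuplot is present
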
diff --git a/init.lua b/init.lua
index 34bb793..287d0c2 100644
--- a/init.lua
+++ b/init.lua
@@ -3,8 +3,13 @@ require 'torch'
 optim = {}
+-- optimizations
 torch.include('optim', 'sgd.lua')
 torch.include('optim', 'cg.lua')
 torch.include('optim', 'asgd.lua')
 torch.include('optim', 'fista.lua')
 torch.include('optim', 'lbfgs.lua')
+
+-- tools
+torch.include('optim', 'ConfusionMatrix.lua')
+torch.include('optim', 'Logger.lua')
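
With the two new include lines, both classes are exposed directly under the optim table once the package is loaded; a quick check, assuming the updated package is installed:

require 'optim'
print(optim.ConfusionMatrix)                       -- class moved from nnx
print(optim.Logger)                                -- idem
local conf = optim.ConfusionMatrix{'cat','dog','person'}
local log  = optim.Logger()                        -- no path: logs to stdout with a warning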