Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/optim.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2012-07-28 18:49:07 +0400
committerClement Farabet <clement.farabet@gmail.com>2012-07-28 18:49:07 +0400
commit5bd5bc0834c2e1cc8263bfa200e1b33ad6334240 (patch)
treeacd6071067cd0c43f7747f67432ed76603e66880 /ConfusionMatrix.lua
parent80a1cfef3538d057856a3b0d8b52f85746645daa (diff)
Added renderer to confusion matrix
Diffstat (limited to 'ConfusionMatrix.lua')
-rw-r--r--ConfusionMatrix.lua116
1 files changed, 116 insertions, 0 deletions
diff --git a/ConfusionMatrix.lua b/ConfusionMatrix.lua
index a4d44e5..58d1536 100644
--- a/ConfusionMatrix.lua
+++ b/ConfusionMatrix.lua
@@ -8,6 +8,7 @@
-- conf:add( neuralnet:forward(sample), label ) -- accumulate errors
-- end
-- print(conf) -- print matrix
+-- image.display(conf:render()) -- render matrix
--
local ConfusionMatrix = torch.class('optim.ConfusionMatrix')
@@ -113,3 +114,118 @@ function ConfusionMatrix:__tostring__()
str = str .. ' + global correct: ' .. (self.totalValid*100) .. '%'
return str
end
+
+function ConfusionMatrix:render(sortmode, display, block, legendwidth)
+ -- args
+ local confusion = self.mat
+ local classes = self.classes
+ local sortmode = sortmode or 'score' -- 'score' or 'occurrence'
+ local block = block or 25
+ local legendwidth = legendwidth or 200
+ local display = display or false
+
+ -- legends
+ local legend = {
+ ['score'] = 'Confusion matrix [sorted by scores, global accuracy = %0.3f%%, per-class accuracy = %0.3f%%]',
+ ['occurrence'] = 'Confusiong matrix [sorted by occurences, accuracy = %0.3f%%, per-class accuracy = %0.3f%%]'
+ }
+
+ -- parse matrix / normalize / count scores
+ local diag = Tensor(#classes)
+ local freqs = Tensor(#classes)
+ local unconf = confusion
+ local confusion = confusion:clone()
+ local corrects = 0
+ local total = 0
+ for target = 1,#classes do
+ freqs[target] = confusion[target]:sum()
+ corrects = corrects + confusion[target][target]
+ total = total + freqs[target]
+ confusion[target]:div( math.max(confusion[target]:sum(),1) )
+ diag[target] = confusion[target][target]
+ end
+
+ -- accuracies
+ local accuracy = corrects / total * 100
+ local perclass = 0
+ local total = 0
+ for target = 1,#classes do
+ if confusion[target]:sum() > 0 then
+ perclass = perclass + diag[target]
+ total = total + 1
+ end
+ end
+ perclass = perclass / total * 100
+ freqs:div(unconf:sum())
+
+ -- sort matrix
+ if sortmode == 'score' then
+ _,order = sort(diag,1,true)
+ elseif sortmode == 'occurrence' then
+ _,order = sort(freqs,1,true)
+ else
+ error('sort mode must be one of: score | occurrence')
+ end
+
+ -- render matrix
+ local render = zeros(#classes*block, #classes*block)
+ for target = 1,#classes do
+ for prediction = 1,#classes do
+ render[{ { (target-1)*block+1,target*block }, { (prediction-1)*block+1,prediction*block } }] = confusion[order[target]][order[prediction]]
+ end
+ end
+
+ -- add grid
+ for target = 1,#classes do
+ render[{ {target*block},{} }] = 0.1
+ render[{ {},{target*block} }] = 0.1
+ end
+
+ -- create rendering
+ require 'qtwidget'
+ require 'qttorch'
+ local win1 = qtwidget.newimage( (#render)[2]+legendwidth, (#render)[1] )
+ image.display{image=render, win=win1}
+
+ -- add legend
+ for i in ipairs(classes) do
+ -- background cell
+ win1:setcolor{r=0,g=0,b=0}
+ win1:rectangle((#render)[2],(i-1)*block,legendwidth,block)
+ win1:fill()
+
+ -- legend
+ win1:setfont(qt.QFont{serif=false, size=fontsize})
+ local gscale = diag[order[i]]*0.8+0.2
+ win1:setcolor{r=gscale,g=gscale,b=gscale}
+ win1:moveto((#render)[2]+10,i*block-block/3)
+ win1:show(classes[order[i]])
+
+ -- %
+ win1:setfont(qt.QFont{serif=false, size=fontsize})
+ local gscale = freqs[order[i]]/freqs:max()*0.9+0.1 --3/4
+ win1:setcolor{r=gscale*0.5+0.2,g=gscale*0.5+0.2,b=gscale*0.8+0.2}
+ win1:moveto(90+(#render)[2]+10,i*block-block/3)
+ win1:show(string.format('[%2.2f%% labels]',math.floor(freqs[order[i]]*10000+0.5)/100))
+
+ for j in ipairs(classes) do
+ -- scores
+ local score = confusion[order[j]][order[i]]
+ local gscale = (1-score)*(score*0.8+0.2)
+ win1:setcolor{r=gscale,g=gscale,b=gscale}
+ win1:moveto((i-1)*block+block/5,(j-1)*block+block*2/3)
+ win1:show(string.format('%02.0f',math.floor(score*100+0.5)))
+ end
+ end
+
+ -- generate tensor
+ local t = win1:image():toTensor()
+
+ -- display
+ if display then
+ image.display{image=t, legend=string.format(legend[sortmode],accuracy,perclass)}
+ end
+
+ -- return rendering
+ return t
+end