diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2012-07-28 18:49:07 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2012-07-28 18:49:07 +0400 |
commit | 5bd5bc0834c2e1cc8263bfa200e1b33ad6334240 (patch) | |
tree | acd6071067cd0c43f7747f67432ed76603e66880 /ConfusionMatrix.lua | |
parent | 80a1cfef3538d057856a3b0d8b52f85746645daa (diff) |
Added renderer to confusion matrix
Diffstat (limited to 'ConfusionMatrix.lua')
-rw-r--r-- | ConfusionMatrix.lua | 116 |
1 files changed, 116 insertions, 0 deletions
diff --git a/ConfusionMatrix.lua b/ConfusionMatrix.lua index a4d44e5..58d1536 100644 --- a/ConfusionMatrix.lua +++ b/ConfusionMatrix.lua @@ -8,6 +8,7 @@ -- conf:add( neuralnet:forward(sample), label ) -- accumulate errors -- end -- print(conf) -- print matrix +-- image.display(conf:render()) -- render matrix -- local ConfusionMatrix = torch.class('optim.ConfusionMatrix') @@ -113,3 +114,118 @@ function ConfusionMatrix:__tostring__() str = str .. ' + global correct: ' .. (self.totalValid*100) .. '%' return str end + +function ConfusionMatrix:render(sortmode, display, block, legendwidth) + -- args + local confusion = self.mat + local classes = self.classes + local sortmode = sortmode or 'score' -- 'score' or 'occurrence' + local block = block or 25 + local legendwidth = legendwidth or 200 + local display = display or false + + -- legends + local legend = { + ['score'] = 'Confusion matrix [sorted by scores, global accuracy = %0.3f%%, per-class accuracy = %0.3f%%]', + ['occurrence'] = 'Confusiong matrix [sorted by occurences, accuracy = %0.3f%%, per-class accuracy = %0.3f%%]' + } + + -- parse matrix / normalize / count scores + local diag = Tensor(#classes) + local freqs = Tensor(#classes) + local unconf = confusion + local confusion = confusion:clone() + local corrects = 0 + local total = 0 + for target = 1,#classes do + freqs[target] = confusion[target]:sum() + corrects = corrects + confusion[target][target] + total = total + freqs[target] + confusion[target]:div( math.max(confusion[target]:sum(),1) ) + diag[target] = confusion[target][target] + end + + -- accuracies + local accuracy = corrects / total * 100 + local perclass = 0 + local total = 0 + for target = 1,#classes do + if confusion[target]:sum() > 0 then + perclass = perclass + diag[target] + total = total + 1 + end + end + perclass = perclass / total * 100 + freqs:div(unconf:sum()) + + -- sort matrix + if sortmode == 'score' then + _,order = sort(diag,1,true) + elseif sortmode == 'occurrence' then + _,order = sort(freqs,1,true) + else + error('sort mode must be one of: score | occurrence') + end + + -- render matrix + local render = zeros(#classes*block, #classes*block) + for target = 1,#classes do + for prediction = 1,#classes do + render[{ { (target-1)*block+1,target*block }, { (prediction-1)*block+1,prediction*block } }] = confusion[order[target]][order[prediction]] + end + end + + -- add grid + for target = 1,#classes do + render[{ {target*block},{} }] = 0.1 + render[{ {},{target*block} }] = 0.1 + end + + -- create rendering + require 'qtwidget' + require 'qttorch' + local win1 = qtwidget.newimage( (#render)[2]+legendwidth, (#render)[1] ) + image.display{image=render, win=win1} + + -- add legend + for i in ipairs(classes) do + -- background cell + win1:setcolor{r=0,g=0,b=0} + win1:rectangle((#render)[2],(i-1)*block,legendwidth,block) + win1:fill() + + -- legend + win1:setfont(qt.QFont{serif=false, size=fontsize}) + local gscale = diag[order[i]]*0.8+0.2 + win1:setcolor{r=gscale,g=gscale,b=gscale} + win1:moveto((#render)[2]+10,i*block-block/3) + win1:show(classes[order[i]]) + + -- % + win1:setfont(qt.QFont{serif=false, size=fontsize}) + local gscale = freqs[order[i]]/freqs:max()*0.9+0.1 --3/4 + win1:setcolor{r=gscale*0.5+0.2,g=gscale*0.5+0.2,b=gscale*0.8+0.2} + win1:moveto(90+(#render)[2]+10,i*block-block/3) + win1:show(string.format('[%2.2f%% labels]',math.floor(freqs[order[i]]*10000+0.5)/100)) + + for j in ipairs(classes) do + -- scores + local score = confusion[order[j]][order[i]] + local gscale = (1-score)*(score*0.8+0.2) + win1:setcolor{r=gscale,g=gscale,b=gscale} + win1:moveto((i-1)*block+block/5,(j-1)*block+block*2/3) + win1:show(string.format('%02.0f',math.floor(score*100+0.5))) + end + end + + -- generate tensor + local t = win1:image():toTensor() + + -- display + if display then + image.display{image=t, legend=string.format(legend[sortmode],accuracy,perclass)} + end + + -- return rendering + return t +end |