diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2012-04-02 23:18:23 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2012-04-02 23:18:23 +0400 |
commit | 42cb83a21e09903fd7f3a8e1d754cedfbf5a7faf (patch) | |
tree | 5187b6e8143afe4d8a053e1b85bd14f405013a5f /ConfusionMatrix.lua | |
parent | 572359a16248a7c2f931847911c0269b02634b77 (diff) |
Moved some stable classes from nnx to optim.
They make more sense here.
Diffstat (limited to 'ConfusionMatrix.lua')
-rw-r--r-- | ConfusionMatrix.lua | 115 |
1 files changed, 115 insertions, 0 deletions
diff --git a/ConfusionMatrix.lua b/ConfusionMatrix.lua new file mode 100644 index 0000000..a4d44e5 --- /dev/null +++ b/ConfusionMatrix.lua @@ -0,0 +1,115 @@ +---------------------------------------------------------------------- +-- A Confusion Matrix class +-- +-- Example: +-- conf = optim.ConfusionMatrix( {'cat','dog','person'} ) -- new matrix +-- conf:zero() -- reset matrix +-- for i = 1,N do +-- conf:add( neuralnet:forward(sample), label ) -- accumulate errors +-- end +-- print(conf) -- print matrix +-- +local ConfusionMatrix = torch.class('optim.ConfusionMatrix') + +function ConfusionMatrix:__init(nclasses, classes) + if type(nclasses) == 'table' then + classes = nclasses + nclasses = #classes + end + self.mat = torch.FloatTensor(nclasses,nclasses):zero() + self.valids = torch.FloatTensor(nclasses):zero() + self.unionvalids = torch.FloatTensor(nclasses):zero() + self.nclasses = nclasses + self.totalValid = 0 + self.averageValid = 0 + self.classes = classes or {} +end + +function ConfusionMatrix:add(prediction, target) + if type(prediction) == 'number' then + -- comparing numbers + self.mat[target][prediction] = self.mat[target][prediction] + 1 + elseif type(target) == 'number' then + -- prediction is a vector, then target assumed to be an index + local prediction_1d = torch.FloatTensor(self.nclasses):copy(prediction) + local _,prediction = prediction_1d:max(1) + self.mat[target][prediction[1]] = self.mat[target][prediction[1]] + 1 + else + -- both prediction and target are vectors + local prediction_1d = torch.FloatTensor(self.nclasses):copy(prediction) + local target_1d = torch.FloatTensor(self.nclasses):copy(target) + local _,prediction = prediction_1d:max(1) + local _,target = target_1d:max(1) + self.mat[target[1]][prediction[1]] = self.mat[target[1]][prediction[1]] + 1 + end +end + +function ConfusionMatrix:zero() + self.mat:zero() + self.valids:zero() + self.unionvalids:zero() + self.totalValid = 0 + self.averageValid = 0 +end + +function ConfusionMatrix:updateValids() + local total = 0 + for t = 1,self.nclasses do + self.valids[t] = self.mat[t][t] / self.mat:select(1,t):sum() + self.unionvalids[t] = self.mat[t][t] / (self.mat:select(1,t):sum()+self.mat:select(2,t):sum()-self.mat[t][t]) + total = total + self.mat[t][t] + end + self.totalValid = total / self.mat:sum() + self.averageValid = 0 + self.averageUnionValid = 0 + local nvalids = 0 + local nunionvalids = 0 + for t = 1,self.nclasses do + if not sys.isNaN(self.valids[t]) then + self.averageValid = self.averageValid + self.valids[t] + nvalids = nvalids + 1 + end + if not sys.isNaN(self.valids[t]) and not sys.isNaN(self.unionvalids[t]) then + self.averageUnionValid = self.averageUnionValid + self.unionvalids[t] + nunionvalids = nunionvalids + 1 + end + end + self.averageValid = self.averageValid / nvalids + self.averageUnionValid = self.averageUnionValid / nunionvalids +end + +function ConfusionMatrix:__tostring__() + self:updateValids() + local str = 'ConfusionMatrix:\n' + local nclasses = self.nclasses + str = str .. '[' + for t = 1,nclasses do + local pclass = self.valids[t] * 100 + pclass = string.format('%2.3f', pclass) + if t == 1 then + str = str .. '[' + else + str = str .. ' [' + end + for p = 1,nclasses do + str = str .. '' .. string.format('%8d', self.mat[t][p]) + end + if self.classes and self.classes[1] then + if t == nclasses then + str = str .. ']] ' .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n' + else + str = str .. '] ' .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n' + end + else + if t == nclasses then + str = str .. ']] ' .. pclass .. '% \n' + else + str = str .. '] ' .. pclass .. '% \n' + end + end + end + str = str .. ' + average row correct: ' .. (self.averageValid*100) .. '% \n' + str = str .. ' + average rowUcol correct (VOC measure): ' .. (self.averageUnionValid*100) .. '% \n' + str = str .. ' + global correct: ' .. (self.totalValid*100) .. '%' + return str +end |