diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-09-05 01:58:37 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-09-05 01:58:37 +0400 |
commit | ecd615e257bc819a85eb1d97d9ebac3a6de989f0 (patch) | |
tree | 6366f3704ac0fabe48a1e46e5be5e6ac242de2d7 | |
parent | d82237992e1f8ec21ff937a7117fa439e2e8c068 (diff) |
Added serializing methods for confusion matrix.
-rw-r--r-- | ConfusionMatrix.lua | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/ConfusionMatrix.lua b/ConfusionMatrix.lua index 96c9aa4..ed3f000 100644 --- a/ConfusionMatrix.lua +++ b/ConfusionMatrix.lua @@ -11,7 +11,7 @@ function ConfusionMatrix:__init(nclasses, classes) self.nclasses = nclasses self.totalValid = 0 self.averageValid = 0 - self.classes = classes + self.classes = classes or {} end function ConfusionMatrix:add(prediction, target) @@ -74,7 +74,7 @@ function ConfusionMatrix:__tostring__() for p = 1,nclasses do str = str .. '' .. string.format('%8d', self.mat[t][p]) end - if self.classes then + if self.classes and self.classes[1] then if t == nclasses then str = str .. ']] ' .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n' else @@ -92,3 +92,21 @@ function ConfusionMatrix:__tostring__() str = str .. ' + global correct: ' .. (self.totalValid*100) .. '%' return str end + +function ConfusionMatrix:write(file) + file:writeObject(self.mat) + file:writeObject(self.valids) + file:writeInt(self.nclasses) + file:writeInt(self.totalValid) + file:writeInt(self.averageValid) + file:writeObject(self.classes) +end + +function ConfusionMatrix:read(file) + self.mat = file:readObject() + self.valids = file:readObject() + self.nclasses = file:readInt() + self.totalValid = file:readInt() + self.averageValid = file:readInt() + self.classes = file:readObject() +end |