Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-07-24 06:36:48 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-07-24 06:36:48 +0400
commita8d2d8f6445bd4069fe812de2856d221e2fad3fb (patch)
treeea82a5931c8621858950f5fee70b02e0f1f1aff7
parentf19c60d367103c51f63b566ebe5e2d5775efecfc (diff)
Added case in ConfusionMatrix, to handle NLL criterions.
-rw-r--r--ConfusionMatrix.lua7
1 files changed, 6 insertions, 1 deletions
diff --git a/ConfusionMatrix.lua b/ConfusionMatrix.lua
index 3486490..ca514f1 100644
--- a/ConfusionMatrix.lua
+++ b/ConfusionMatrix.lua
@@ -18,8 +18,13 @@ function ConfusionMatrix:add(prediction, target)
if type(prediction) == 'number' then
-- comparing numbers
self.mat[target][prediction] = self.mat[target][prediction] + 1
+ elseif type(target) == 'number' then
+ -- prediction is a vector, then target assumed to be an index
+ local prediction_1d = torch.Tensor(prediction):resize(self.nclasses)
+ local _,prediction = lab.max(prediction_1d)
+ self.mat[target][prediction[1]] = self.mat[target][prediction[1]] + 1
else
- -- comparing vectors
+ -- both prediction and target are vectors
local prediction_1d = torch.Tensor(prediction):resize(self.nclasses)
local target_1d = torch.Tensor(target):resize(self.nclasses)
local _,prediction = lab.max(prediction_1d)