diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-07-24 06:36:48 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-07-24 06:36:48 +0400 |
commit | a8d2d8f6445bd4069fe812de2856d221e2fad3fb (patch) | |
tree | ea82a5931c8621858950f5fee70b02e0f1f1aff7 | |
parent | f19c60d367103c51f63b566ebe5e2d5775efecfc (diff) |
Added case in ConfusionMatrix, to handle NLL criterions.
-rw-r--r-- | ConfusionMatrix.lua | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/ConfusionMatrix.lua b/ConfusionMatrix.lua index 3486490..ca514f1 100644 --- a/ConfusionMatrix.lua +++ b/ConfusionMatrix.lua @@ -18,8 +18,13 @@ function ConfusionMatrix:add(prediction, target) if type(prediction) == 'number' then -- comparing numbers self.mat[target][prediction] = self.mat[target][prediction] + 1 + elseif type(target) == 'number' then + -- prediction is a vector, then target assumed to be an index + local prediction_1d = torch.Tensor(prediction):resize(self.nclasses) + local _,prediction = lab.max(prediction_1d) + self.mat[target][prediction[1]] = self.mat[target][prediction[1]] + 1 else - -- comparing vectors + -- both prediction and target are vectors local prediction_1d = torch.Tensor(prediction):resize(self.nclasses) local target_1d = torch.Tensor(target):resize(self.nclasses) local _,prediction = lab.max(prediction_1d) |