Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/optim.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2016-04-14 18:57:52 +0300
committerSoumith Chintala <soumith@gmail.com>2016-04-14 18:57:52 +0300
commite72d123262c8a8fa81cb9857eb699020ab9fdc3b (patch)
tree82b79132c159b7adee5deccaa02ad1bdf8c3a4cc
parentaeeacbd628c5464db577c229597986e6f59501aa (diff)
parent207bb273b7eb0362a77d67a4e2493bfaa0e425db (diff)
Merge pull request #100 from apaszke/conf_speedup
Improve ConfusionMatrix performance
-rw-r--r--ConfusionMatrix.lua16
1 files changed, 12 insertions, 4 deletions
diff --git a/ConfusionMatrix.lua b/ConfusionMatrix.lua
index 2809fbe..c3c89a5 100644
--- a/ConfusionMatrix.lua
+++ b/ConfusionMatrix.lua
@@ -25,6 +25,7 @@ function ConfusionMatrix:__init(nclasses, classes)
self.averageValid = 0
self.classes = classes or {}
-- buffers
+ self._mat_flat = self.mat:view(-1)
self._target = torch.FloatTensor()
self._prediction = torch.FloatTensor()
self._max = torch.FloatTensor()
@@ -101,10 +102,17 @@ function ConfusionMatrix:batchAdd(predictions, targets)
error("targets has invalid number of dimensions")
end
- --loop over each pair of indices
- for i = 1,preds:size(1) do
- self:_add(preds[i], targs[i])
- end
+ -- non-positive values are considered missing and therefore ignored
+ local mask = targs:ge(1)
+ targs = targs[mask]
+ preds = preds[mask]
+
+ self._mat_flat = self._mat_flat or self.mat:view(-1) -- for backward compatibility
+
+ assert(self.mat:isContiguous() and self.mat:stride(2) == 1)
+ local indices = ((targs - 1) * self.mat:stride(1) + preds):typeAs(self.mat)
+ local ones = torch.ones(1):typeAs(self.mat):expand(indices:size(1))
+ self._mat_flat:indexAdd(1, indices, ones)
end
function ConfusionMatrix:zero()