diff options
author | Soumith Chintala <soumith@gmail.com> | 2016-04-14 18:57:52 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2016-04-14 18:57:52 +0300 |
commit | e72d123262c8a8fa81cb9857eb699020ab9fdc3b (patch) | |
tree | 82b79132c159b7adee5deccaa02ad1bdf8c3a4cc | |
parent | aeeacbd628c5464db577c229597986e6f59501aa (diff) | |
parent | 207bb273b7eb0362a77d67a4e2493bfaa0e425db (diff) |
Merge pull request #100 from apaszke/conf_speedup
Improve ConfusionMatrix performance
-rw-r--r-- | ConfusionMatrix.lua | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/ConfusionMatrix.lua b/ConfusionMatrix.lua index 2809fbe..c3c89a5 100644 --- a/ConfusionMatrix.lua +++ b/ConfusionMatrix.lua @@ -25,6 +25,7 @@ function ConfusionMatrix:__init(nclasses, classes) self.averageValid = 0 self.classes = classes or {} -- buffers + self._mat_flat = self.mat:view(-1) self._target = torch.FloatTensor() self._prediction = torch.FloatTensor() self._max = torch.FloatTensor() @@ -101,10 +102,17 @@ function ConfusionMatrix:batchAdd(predictions, targets) error("targets has invalid number of dimensions") end - --loop over each pair of indices - for i = 1,preds:size(1) do - self:_add(preds[i], targs[i]) - end + -- non-positive values are considered missing and therefore ignored + local mask = targs:ge(1) + targs = targs[mask] + preds = preds[mask] + + self._mat_flat = self._mat_flat or self.mat:view(-1) -- for backward compatibility + + assert(self.mat:isContiguous() and self.mat:stride(2) == 1) + local indices = ((targs - 1) * self.mat:stride(1) + preds):typeAs(self.mat) + local ones = torch.ones(1):typeAs(self.mat):expand(indices:size(1)) + self._mat_flat:indexAdd(1, indices, ones) end function ConfusionMatrix:zero() |