Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/optim.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-10-22 18:30:37 +0400
committernicholas-leonard <nick@nikopia.org>2014-10-22 18:30:37 +0400
commitf32eb3c5922965113a26eedf371f79285f2691f0 (patch)
tree5473b42c428133f81ffe6bc08e8d5f22a6cf59bb /ConfusionMatrix.lua
parent13889aefcd8e8d37b56ee98778c5bcb1519a7f2c (diff)
Confusion:batchAdd supports cuda tensors
Diffstat (limited to 'ConfusionMatrix.lua')
-rw-r--r--ConfusionMatrix.lua51
1 files changed, 29 insertions, 22 deletions
diff --git a/ConfusionMatrix.lua b/ConfusionMatrix.lua
index 86d6d18..7d33bb0 100644
--- a/ConfusionMatrix.lua
+++ b/ConfusionMatrix.lua
@@ -24,6 +24,12 @@ function ConfusionMatrix:__init(nclasses, classes)
self.totalValid = 0
self.averageValid = 0
self.classes = classes or {}
+ -- buffers
+ self._target = torch.FloatTensor()
+ self._prediction = torch.FloatTensor()
+ self._max = torch.FloatTensor()
+ self._pred_idx = torch.LongTensor()
+ self._targ_idx = torch.LongTensor()
end
-- takes scalar prediction and target as input
@@ -39,57 +45,58 @@ function ConfusionMatrix:add(prediction, target)
if type(prediction) == 'number' then
-- comparing numbers
self:_add(prediction, target)
- elseif type(target) == 'number' then
- -- prediction is a vector, then target assumed to be an index
- self.prediction_1d = self.prediction_1d or torch.FloatTensor(self.nclasses)
- self.prediction_1d:copy(prediction)
- local _,prediction = self.prediction_1d:max(1)
- self:_add(prediction[1], target)
else
- -- both prediction and target are vectors
- self.prediction_1d = self.prediction_1d or torch.FloatTensor(self.nclasses)
- self.prediction_1d:copy(prediction)
- self.target_1d = self.target_1d or torch.FloatTensor(self.nclasses)
- self.target_1d:copy(target)
- local _,prediction = self.prediction_1d:max(1)
- local _,target = self.target_1d:max(1)
- self:_add(prediction[1], target[1])
+ self._prediction:resize(prediction:size()):copy(prediction)
+ if type(target) == 'number' then
+ -- prediction is a vector, then target assumed to be an index
+ self._max:max(self._pred_idx, self._prediction, 1)
+ self:_add(self._pred_idx[1], target)
+ else
+ -- both prediction and target are vectors
+ self._target:resize(target:size()):copy(target)
+ self._max:max(self._targ_idx, self._target, 1)
+ self._max:max(self._pred_idx, self._prediction, 1)
+ self:_add(self._pred_idx[1], self._targ_idx[1])
+ end
end
end
function ConfusionMatrix:batchAdd(predictions, targets)
local preds, targs, __
+ self._prediction:resize(predictions:size()):copy(predictions)
if predictions:dim() == 1 then
-- predictions is a vector of classes
- preds = predictions
+ preds = self._prediction
elseif predictions:dim() == 2 then
-- prediction is a matrix of class likelihoods
if predictions:size(2) == 1 then
-- or prediction just needs flattening
- preds = predictions:select(2,1)
+ preds = self._prediction:select(2,1)
else
- __,preds = predictions:max(2)
- preds:resize(preds:size(1))
+ self._max:max(self._pred_idx, self._prediction, 2)
+ preds = self._pred_idx:select(2,1)
end
else
error("predictions has invalid number of dimensions")
end
+ self._target:resize(targets:size()):copy(targets)
if targets:dim() == 1 then
-- targets is a vector of classes
- targs = targets
+ targs = self._target
elseif targets:dim() == 2 then
-- targets is a matrix of one-hot rows
if targets:size(2) == 1 then
-- or targets just needs flattening
- targs = targets:select(2,1)
+ targs = self._target:select(2,1)
else
- __,targs = targets:max(2)
- targs:resize(targs:size(1))
+ self._max:max(self._targ_idx, self._target, 2)
+ targs = self._targ_idx:select(2,1)
end
else
error("targets has invalid number of dimensions")
end
+
--loop over each pair of indices
for i = 1,preds:size(1) do
self:_add(preds[i], targs[i])