diff options
-rw-r--r-- | ConfusionMatrix.lua | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/ConfusionMatrix.lua b/ConfusionMatrix.lua index 9a221ff..cb286ff 100644 --- a/ConfusionMatrix.lua +++ b/ConfusionMatrix.lua @@ -33,15 +33,15 @@ function ConfusionMatrix:add(prediction, target) elseif type(target) == 'number' then -- prediction is a vector, then target assumed to be an index self.prediction_1d = self.prediction_1d or torch.FloatTensor(self.nclasses) - self.prediction_1d[{}] = prediction + self.prediction_1d:copy(prediction) local _,prediction = self.prediction_1d:max(1) self.mat[target][prediction[1]] = self.mat[target][prediction[1]] + 1 else -- both prediction and target are vectors self.prediction_1d = self.prediction_1d or torch.FloatTensor(self.nclasses) - self.prediction_1d[{}] = prediction + self.prediction_1d:copy(prediction) self.target_1d = self.target_1d or torch.FloatTensor(self.nclasses) - self.target_1d[{}] = target + self.target_1d:copy(target) local _,prediction = self.prediction_1d:max(1) local _,target = self.target_1d:max(1) self.mat[target[1]][prediction[1]] = self.mat[target[1]][prediction[1]] + 1 |