diff options
Diffstat (limited to 'ConfusionMatrix.lua')
-rw-r--r-- | ConfusionMatrix.lua | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/ConfusionMatrix.lua b/ConfusionMatrix.lua index 3b5fa9c..1467649 100644 --- a/ConfusionMatrix.lua +++ b/ConfusionMatrix.lua @@ -34,6 +34,8 @@ end -- takes scalar prediction and target as input function ConfusionMatrix:_add(p, t) + assert(p and type(p) == 'number') + assert(t and type(t) == 'number') -- non-positive values are considered missing -- and therefore ignored if t > 0 then @@ -47,12 +49,14 @@ function ConfusionMatrix:add(prediction, target) self:_add(prediction, target) else self._prediction:resize(prediction:size()):copy(prediction) + assert(prediction:dim() == 1) if type(target) == 'number' then -- prediction is a vector, then target assumed to be an index self._max:max(self._pred_idx, self._prediction, 1) self:_add(self._pred_idx[1], target) else -- both prediction and target are vectors + assert(target:dim() == 1) self._target:resize(target:size()):copy(target) self._max:max(self._targ_idx, self._target, 1) self._max:max(self._pred_idx, self._prediction, 1) |