diff options
author | Adam Paszke <adam.paszke@gmail.com> | 2016-04-20 16:54:09 +0300 |
---|---|---|
committer | Adam Paszke <adam.paszke@gmail.com> | 2016-04-20 16:58:02 +0300 |
commit | 816bc1d169cc41cd151cc5009749dc8bafee52d1 (patch) | |
tree | 05fb5fded9511d77b49347dc72fd695aa5c7e2ca /ConfusionMatrix.lua | |
parent | e72d123262c8a8fa81cb9857eb699020ab9fdc3b (diff) |
Fix possible type mismatch in ConfusionMatrix
Diffstat (limited to 'ConfusionMatrix.lua')
-rw-r--r-- | ConfusionMatrix.lua | 2 |
1 files changed, 2 insertions, 0 deletions
diff --git a/ConfusionMatrix.lua b/ConfusionMatrix.lua index c3c89a5..8659a4e 100644 --- a/ConfusionMatrix.lua +++ b/ConfusionMatrix.lua @@ -109,6 +109,8 @@ function ConfusionMatrix:batchAdd(predictions, targets) self._mat_flat = self._mat_flat or self.mat:view(-1) -- for backward compatibility + preds = preds:typeAs(targs) + assert(self.mat:isContiguous() and self.mat:stride(2) == 1) local indices = ((targs - 1) * self.mat:stride(1) + preds):typeAs(self.mat) local ones = torch.ones(1):typeAs(self.mat):expand(indices:size(1)) |