diff options
author | soumith <soumith@fb.com> | 2016-12-21 18:20:34 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-12-21 18:25:21 +0300 |
commit | 3b8e5a8064b4d359ab691369ace524c9b5b87575 (patch) | |
tree | e386005f70c7013716dbf56cba04ac246970932f | |
parent | 6691772199c4df338902aafaca67c95d8b3d6a2b (diff) |
fixing bug in ClassNLLCriterion for single targets
-rw-r--r-- | ClassNLLCriterion.lua | 2 |
1 files changed, 2 insertions, 0 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua index 1d3f3b7..d89f439 100644 --- a/ClassNLLCriterion.lua +++ b/ClassNLLCriterion.lua @@ -33,6 +33,7 @@ function ClassNLLCriterion:updateOutput(input, target) else self.target = self.target:long() end + self.target:resize(1) self.target[1] = target elseif torch.typename(input):find('torch%.Cuda.*Tensor') then self.target = torch.CudaLongTensor and target:cudaLong() or target @@ -59,6 +60,7 @@ function ClassNLLCriterion:updateGradInput(input, target) else self.target = self.target:long() end + self.target:resize(1) self.target[1] = target elseif torch.typename(input):find('torch%.Cuda.*Tensor') then self.target = torch.CudaLongTensor and target:cudaLong() or target |