Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2016-12-21 18:20:34 +0300
committersoumith <soumith@fb.com>2016-12-21 18:25:21 +0300
commit3b8e5a8064b4d359ab691369ace524c9b5b87575 (patch)
treee386005f70c7013716dbf56cba04ac246970932f
parent6691772199c4df338902aafaca67c95d8b3d6a2b (diff)
fixing bug in ClassNLLCriterion for single targets
-rw-r--r--ClassNLLCriterion.lua2
1 files changed, 2 insertions, 0 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua
index 1d3f3b7..d89f439 100644
--- a/ClassNLLCriterion.lua
+++ b/ClassNLLCriterion.lua
@@ -33,6 +33,7 @@ function ClassNLLCriterion:updateOutput(input, target)
else
self.target = self.target:long()
end
+ self.target:resize(1)
self.target[1] = target
elseif torch.typename(input):find('torch%.Cuda.*Tensor') then
self.target = torch.CudaLongTensor and target:cudaLong() or target
@@ -59,6 +60,7 @@ function ClassNLLCriterion:updateGradInput(input, target)
else
self.target = self.target:long()
end
+ self.target:resize(1)
self.target[1] = target
elseif torch.typename(input):find('torch%.Cuda.*Tensor') then
self.target = torch.CudaLongTensor and target:cudaLong() or target