Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAysegul Dundar <adundar@purdue.edu>2014-12-25 16:40:43 +0300
committerAysegul Dundar <adundar@purdue.edu>2014-12-25 16:40:43 +0300
commit3da4628ab84500ccbd76c8ddac46d18bd345464a (patch)
tree6b629a34c7c7c92daddf862ee3ad51e77ea0942c /ClassNLLCriterion.lua
parent28f235cbd8f872abf5f89a534ab6066fc3f3dfbe (diff)
correction in batch mode
Diffstat (limited to 'ClassNLLCriterion.lua')
-rw-r--r--ClassNLLCriterion.lua20
1 files changed, 14 insertions, 6 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua
index ae208c1..997ecef 100644
--- a/ClassNLLCriterion.lua
+++ b/ClassNLLCriterion.lua
@@ -12,9 +12,13 @@ end
function ClassNLLCriterion:updateOutput(input, target)
if input:type() == 'torch.CudaTensor' and not self.weights then
- self._target = self._target or input.new(1)
- self._target[1] = target
- input.nn.ClassNLLCriterion_updateOutput(self, input, self._target)
+ if input:dim() == 1 then
+ self._target = self._target or input.new(1)
+ self._target[1] = target
+ input.nn.ClassNLLCriterion_updateOutput(self, input, self._target)
+ else
+ input.nn.ClassNLLCriterion_updateOutput(self, input, target)
+ end
self.output = self.outputTensor[1]
return self.output
end
@@ -48,9 +52,13 @@ function ClassNLLCriterion:updateGradInput(input, target)
self.gradInput:zero()
if input:type() == 'torch.CudaTensor' and not self.weights then
- self._target = self._target or input.new(1)
- self._target[1] = target
- input.nn.ClassNLLCriterion_updateGradInput(self, input, self._target)
+ if input:dim() == 1 then
+ self._target = self._target or input.new(1)
+ self._target[1] = target
+ input.nn.ClassNLLCriterion_updateGradInput(self, input, self._target)
+ else
+ input.nn.ClassNLLCriterion_updateGradInput(self, input, target)
+ end
return self.gradInput
end