diff options
author | Aysegul Dundar <adundar@purdue.edu> | 2014-12-25 16:40:43 +0300 |
---|---|---|
committer | Aysegul Dundar <adundar@purdue.edu> | 2014-12-25 16:40:43 +0300 |
commit | 3da4628ab84500ccbd76c8ddac46d18bd345464a (patch) | |
tree | 6b629a34c7c7c92daddf862ee3ad51e77ea0942c /ClassNLLCriterion.lua | |
parent | 28f235cbd8f872abf5f89a534ab6066fc3f3dfbe (diff) |
correction in batch mode
Diffstat (limited to 'ClassNLLCriterion.lua')
-rw-r--r-- | ClassNLLCriterion.lua | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua index ae208c1..997ecef 100644 --- a/ClassNLLCriterion.lua +++ b/ClassNLLCriterion.lua @@ -12,9 +12,13 @@ end function ClassNLLCriterion:updateOutput(input, target) if input:type() == 'torch.CudaTensor' and not self.weights then - self._target = self._target or input.new(1) - self._target[1] = target - input.nn.ClassNLLCriterion_updateOutput(self, input, self._target) + if input:dim() == 1 then + self._target = self._target or input.new(1) + self._target[1] = target + input.nn.ClassNLLCriterion_updateOutput(self, input, self._target) + else + input.nn.ClassNLLCriterion_updateOutput(self, input, target) + end self.output = self.outputTensor[1] return self.output end @@ -48,9 +52,13 @@ function ClassNLLCriterion:updateGradInput(input, target) self.gradInput:zero() if input:type() == 'torch.CudaTensor' and not self.weights then - self._target = self._target or input.new(1) - self._target[1] = target - input.nn.ClassNLLCriterion_updateGradInput(self, input, self._target) + if input:dim() == 1 then + self._target = self._target or input.new(1) + self._target[1] = target + input.nn.ClassNLLCriterion_updateGradInput(self, input, self._target) + else + input.nn.ClassNLLCriterion_updateGradInput(self, input, target) + end return self.gradInput end |