diff options
author | soumith <soumith@fb.com> | 2014-12-24 20:30:03 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2014-12-24 20:30:03 +0300 |
commit | 28f235cbd8f872abf5f89a534ab6066fc3f3dfbe (patch) | |
tree | 331bd18fede13df383b4e6bc837dca02f442e0fa | |
parent | 289c6a2d91dfec05cc5c55105353e408b8541334 (diff) |
fixing small cuda typing issue for ClassNLLCriterion
-rw-r--r-- | ClassNLLCriterion.lua | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua index 926e707..ae208c1 100644 --- a/ClassNLLCriterion.lua +++ b/ClassNLLCriterion.lua @@ -12,9 +12,11 @@ end function ClassNLLCriterion:updateOutput(input, target) if input:type() == 'torch.CudaTensor' and not self.weights then - input.nn.ClassNLLCriterion_updateOutput(self, input, target) - self.output = self.outputTensor[1] - return self.output + self._target = self._target or input.new(1) + self._target[1] = target + input.nn.ClassNLLCriterion_updateOutput(self, input, self._target) + self.output = self.outputTensor[1] + return self.output end if input:dim() == 1 then @@ -46,8 +48,10 @@ function ClassNLLCriterion:updateGradInput(input, target) self.gradInput:zero() if input:type() == 'torch.CudaTensor' and not self.weights then - input.nn.ClassNLLCriterion_updateGradInput(self, input, target) - return self.gradInput + self._target = self._target or input.new(1) + self._target[1] = target + input.nn.ClassNLLCriterion_updateGradInput(self, input, self._target) + return self.gradInput end if input:dim() == 1 then |