Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2014-12-24 20:30:03 +0300
committersoumith <soumith@fb.com>2014-12-24 20:30:03 +0300
commit28f235cbd8f872abf5f89a534ab6066fc3f3dfbe (patch)
tree331bd18fede13df383b4e6bc837dca02f442e0fa
parent289c6a2d91dfec05cc5c55105353e408b8541334 (diff)
fixing small cuda typing issue for ClassNLLCriterion
-rw-r--r--ClassNLLCriterion.lua14
1 files changed, 9 insertions, 5 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua
index 926e707..ae208c1 100644
--- a/ClassNLLCriterion.lua
+++ b/ClassNLLCriterion.lua
@@ -12,9 +12,11 @@ end
function ClassNLLCriterion:updateOutput(input, target)
if input:type() == 'torch.CudaTensor' and not self.weights then
- input.nn.ClassNLLCriterion_updateOutput(self, input, target)
- self.output = self.outputTensor[1]
- return self.output
+ self._target = self._target or input.new(1)
+ self._target[1] = target
+ input.nn.ClassNLLCriterion_updateOutput(self, input, self._target)
+ self.output = self.outputTensor[1]
+ return self.output
end
if input:dim() == 1 then
@@ -46,8 +48,10 @@ function ClassNLLCriterion:updateGradInput(input, target)
self.gradInput:zero()
if input:type() == 'torch.CudaTensor' and not self.weights then
- input.nn.ClassNLLCriterion_updateGradInput(self, input, target)
- return self.gradInput
+ self._target = self._target or input.new(1)
+ self._target[1] = target
+ input.nn.ClassNLLCriterion_updateGradInput(self, input, self._target)
+ return self.gradInput
end
if input:dim() == 1 then