Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2016-09-27 03:48:02 +0300
committersoumith <soumith@fb.com>2016-09-27 03:48:02 +0300
commit2bf9487d73b43bd198597d90c0b24ea917e77b17 (patch)
tree85b68d084c43d3e2a2e25bfff7873e5662b609ca /ClassNLLCriterion.lua
parent959f176d4926ca800990e203c2e34b9ec74a5ebc (diff)
making ClassNLLCriterion targets consistent between cpu and cudaclassnllfix
Diffstat (limited to 'ClassNLLCriterion.lua')
-rw-r--r--ClassNLLCriterion.lua4
1 files changed, 2 insertions, 2 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua
index 8e8acbf..0a57f2e 100644
--- a/ClassNLLCriterion.lua
+++ b/ClassNLLCriterion.lua
@@ -33,7 +33,7 @@ function ClassNLLCriterion:updateOutput(input, target)
end
self.target[1] = target
elseif target:type() == 'torch.CudaTensor' then
- self.target = target
+ self.target = target:cudaLong()
else
self.target = target:long()
end
@@ -54,7 +54,7 @@ function ClassNLLCriterion:updateGradInput(input, target)
if type(target) == 'number' then
self.target[1] = target
elseif target:type() == 'torch.CudaTensor' then
- self.target = target
+ self.target = target:cudaLong()
else
self.target = target:long()
end