Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2016-09-27 03:48:02 +0300
committersoumith <soumith@fb.com>2016-09-27 03:48:02 +0300
commit2bf9487d73b43bd198597d90c0b24ea917e77b17 (patch)
tree85b68d084c43d3e2a2e25bfff7873e5662b609ca
parent959f176d4926ca800990e203c2e34b9ec74a5ebc (diff)
making ClassNLLCriterion targets consistent between cpu and cudaclassnllfix
-rw-r--r--ClassNLLCriterion.lua4
-rw-r--r--SpatialClassNLLCriterion.lua4
2 files changed, 4 insertions, 4 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua
index 8e8acbf..0a57f2e 100644
--- a/ClassNLLCriterion.lua
+++ b/ClassNLLCriterion.lua
@@ -33,7 +33,7 @@ function ClassNLLCriterion:updateOutput(input, target)
end
self.target[1] = target
elseif target:type() == 'torch.CudaTensor' then
- self.target = target
+ self.target = target:cudaLong()
else
self.target = target:long()
end
@@ -54,7 +54,7 @@ function ClassNLLCriterion:updateGradInput(input, target)
if type(target) == 'number' then
self.target[1] = target
elseif target:type() == 'torch.CudaTensor' then
- self.target = target
+ self.target = target:cudaLong()
else
self.target = target:long()
end
diff --git a/SpatialClassNLLCriterion.lua b/SpatialClassNLLCriterion.lua
index 8652e88..c20a2f6 100644
--- a/SpatialClassNLLCriterion.lua
+++ b/SpatialClassNLLCriterion.lua
@@ -33,7 +33,7 @@ function SpatialClassNLLCriterion:updateOutput(input, target)
end
self.target[1] = target
elseif target:type() == 'torch.CudaTensor' then
- self.target = target
+ self.target = target:cudaLong()
else
self.target = target:long()
end
@@ -54,7 +54,7 @@ function SpatialClassNLLCriterion:updateGradInput(input, target)
if type(target) == 'number' then
self.target[1] = target
elseif target:type() == 'torch.CudaTensor' then
- self.target = target
+ self.target = target:cudaLong()
else
self.target = target:long()
end