diff options
author | soumith <soumith@fb.com> | 2016-09-27 03:48:02 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-09-27 03:48:02 +0300 |
commit | 2bf9487d73b43bd198597d90c0b24ea917e77b17 (patch) | |
tree | 85b68d084c43d3e2a2e25bfff7873e5662b609ca | |
parent | 959f176d4926ca800990e203c2e34b9ec74a5ebc (diff) |
making ClassNLLCriterion targets consistent between cpu and cudaclassnllfix
-rw-r--r-- | ClassNLLCriterion.lua | 4 | ||||
-rw-r--r-- | SpatialClassNLLCriterion.lua | 4 |
2 files changed, 4 insertions, 4 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua index 8e8acbf..0a57f2e 100644 --- a/ClassNLLCriterion.lua +++ b/ClassNLLCriterion.lua @@ -33,7 +33,7 @@ function ClassNLLCriterion:updateOutput(input, target) end self.target[1] = target elseif target:type() == 'torch.CudaTensor' then - self.target = target + self.target = target:cudaLong() else self.target = target:long() end @@ -54,7 +54,7 @@ function ClassNLLCriterion:updateGradInput(input, target) if type(target) == 'number' then self.target[1] = target elseif target:type() == 'torch.CudaTensor' then - self.target = target + self.target = target:cudaLong() else self.target = target:long() end diff --git a/SpatialClassNLLCriterion.lua b/SpatialClassNLLCriterion.lua index 8652e88..c20a2f6 100644 --- a/SpatialClassNLLCriterion.lua +++ b/SpatialClassNLLCriterion.lua @@ -33,7 +33,7 @@ function SpatialClassNLLCriterion:updateOutput(input, target) end self.target[1] = target elseif target:type() == 'torch.CudaTensor' then - self.target = target + self.target = target:cudaLong() else self.target = target:long() end @@ -54,7 +54,7 @@ function SpatialClassNLLCriterion:updateGradInput(input, target) if type(target) == 'number' then self.target[1] = target elseif target:type() == 'torch.CudaTensor' then - self.target = target + self.target = target:cudaLong() else self.target = target:long() end |