Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2016-10-26 22:05:43 +0300
committerGregory Chanan <gchanan@fb.com>2016-10-26 22:05:43 +0300
commita680f13066a23d0736ba5edb9ef9f55992bb3150 (patch)
treef6ccdc5a09ea09ba0dbd8c4b7d4db05f2c20a005 /ClassNLLCriterion.lua
parent07d27ef9755be1d54c2638d6668e9061fe55aa28 (diff)
Convert ClassNLLCriterion targets for all cuda types.
Diffstat (limited to 'ClassNLLCriterion.lua')
-rw-r--r--ClassNLLCriterion.lua8
1 files changed, 4 insertions, 4 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua
index 6dad0ba..1d3f3b7 100644
--- a/ClassNLLCriterion.lua
+++ b/ClassNLLCriterion.lua
@@ -28,13 +28,13 @@ end
function ClassNLLCriterion:updateOutput(input, target)
if type(target) == 'number' then
- if input:type() == 'torch.CudaTensor' then
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda()
else
self.target = self.target:long()
end
self.target[1] = target
- elseif input:type() == 'torch.CudaTensor' then
+ elseif torch.typename(input):find('torch%.Cuda.*Tensor') then
self.target = torch.CudaLongTensor and target:cudaLong() or target
else
self.target = target:long()
@@ -54,13 +54,13 @@ end
function ClassNLLCriterion:updateGradInput(input, target)
if type(target) == 'number' then
- if input:type() == 'torch.CudaTensor' then
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda()
else
self.target = self.target:long()
end
self.target[1] = target
- elseif input:type() == 'torch.CudaTensor' then
+ elseif torch.typename(input):find('torch%.Cuda.*Tensor') then
self.target = torch.CudaLongTensor and target:cudaLong() or target
else
self.target = target:long()