diff options
author | Gregory Chanan <gchanan@fb.com> | 2016-10-26 22:05:43 +0300 |
---|---|---|
committer | Gregory Chanan <gchanan@fb.com> | 2016-10-26 22:05:43 +0300 |
commit | a680f13066a23d0736ba5edb9ef9f55992bb3150 (patch) | |
tree | f6ccdc5a09ea09ba0dbd8c4b7d4db05f2c20a005 /ClassNLLCriterion.lua | |
parent | 07d27ef9755be1d54c2638d6668e9061fe55aa28 (diff) |
Convert ClassNLLCriterion targets for all cuda types.
Diffstat (limited to 'ClassNLLCriterion.lua')
-rw-r--r-- | ClassNLLCriterion.lua | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua index 6dad0ba..1d3f3b7 100644 --- a/ClassNLLCriterion.lua +++ b/ClassNLLCriterion.lua @@ -28,13 +28,13 @@ end function ClassNLLCriterion:updateOutput(input, target) if type(target) == 'number' then - if input:type() == 'torch.CudaTensor' then + if torch.typename(input):find('torch%.Cuda.*Tensor') then self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda() else self.target = self.target:long() end self.target[1] = target - elseif input:type() == 'torch.CudaTensor' then + elseif torch.typename(input):find('torch%.Cuda.*Tensor') then self.target = torch.CudaLongTensor and target:cudaLong() or target else self.target = target:long() @@ -54,13 +54,13 @@ end function ClassNLLCriterion:updateGradInput(input, target) if type(target) == 'number' then - if input:type() == 'torch.CudaTensor' then + if torch.typename(input):find('torch%.Cuda.*Tensor') then self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda() else self.target = self.target:long() end self.target[1] = target - elseif input:type() == 'torch.CudaTensor' then + elseif torch.typename(input):find('torch%.Cuda.*Tensor') then self.target = torch.CudaLongTensor and target:cudaLong() or target else self.target = target:long() |