diff options
author | Gregory Chanan <gchanan@fb.com> | 2016-10-10 20:02:57 +0300 |
---|---|---|
committer | Gregory Chanan <gchanan@fb.com> | 2016-10-20 00:52:37 +0300 |
commit | 8e86879f0ce0dfa325325db69706d9708dbc6098 (patch) | |
tree | 637c036a2813e1be2dc0a9127f814ab2cfdfb6dd /SpatialClassNLLCriterion.lua | |
parent | a8e63f2da3d3d84a7e1eed917572901a9ffba5d9 (diff) |
Generic support for cuda tensor types in SpatialClassNLLCriterion.
Diffstat (limited to 'SpatialClassNLLCriterion.lua')
-rw-r--r-- | SpatialClassNLLCriterion.lua | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/SpatialClassNLLCriterion.lua b/SpatialClassNLLCriterion.lua index 54c3b30..fbd3674 100644 --- a/SpatialClassNLLCriterion.lua +++ b/SpatialClassNLLCriterion.lua @@ -28,13 +28,13 @@ end function SpatialClassNLLCriterion:updateOutput(input, target) if type(target) == 'number' then - if input:type() == 'torch.CudaTensor' then + if torch.typename(input):find('torch%.Cuda.*Tensor') then self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda() else self.target = self.target:long() end self.target[1] = target - elseif input:type() == 'torch.CudaTensor' then + elseif torch.typename(input):find('torch%.Cuda.*Tensor') then self.target = torch.CudaLongTensor and target:cudaLong() or target else self.target = target:long() @@ -54,13 +54,13 @@ end function SpatialClassNLLCriterion:updateGradInput(input, target) if type(target) == 'number' then - if input:type() == 'torch.CudaTensor' then + if torch.typename(input):find('torch%.Cuda.*Tensor') then self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda() else self.target = self.target:long() end self.target[1] = target - elseif input:type() == 'torch.CudaTensor' then + elseif torch.typename(input):find('torch%.Cuda.*Tensor') then self.target = torch.CudaLongTensor and target:cudaLong() or target else self.target = target:long() |