Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2016-10-10 20:02:57 +0300
committerGregory Chanan <gchanan@fb.com>2016-10-20 00:52:37 +0300
commit8e86879f0ce0dfa325325db69706d9708dbc6098 (patch)
tree637c036a2813e1be2dc0a9127f814ab2cfdfb6dd /SpatialClassNLLCriterion.lua
parenta8e63f2da3d3d84a7e1eed917572901a9ffba5d9 (diff)
Generic support for cuda tensor types in SpatialClassNLLCriterion.
Diffstat (limited to 'SpatialClassNLLCriterion.lua')
-rw-r--r--SpatialClassNLLCriterion.lua8
1 files changed, 4 insertions, 4 deletions
diff --git a/SpatialClassNLLCriterion.lua b/SpatialClassNLLCriterion.lua
index 54c3b30..fbd3674 100644
--- a/SpatialClassNLLCriterion.lua
+++ b/SpatialClassNLLCriterion.lua
@@ -28,13 +28,13 @@ end
function SpatialClassNLLCriterion:updateOutput(input, target)
if type(target) == 'number' then
- if input:type() == 'torch.CudaTensor' then
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda()
else
self.target = self.target:long()
end
self.target[1] = target
- elseif input:type() == 'torch.CudaTensor' then
+ elseif torch.typename(input):find('torch%.Cuda.*Tensor') then
self.target = torch.CudaLongTensor and target:cudaLong() or target
else
self.target = target:long()
@@ -54,13 +54,13 @@ end
function SpatialClassNLLCriterion:updateGradInput(input, target)
if type(target) == 'number' then
- if input:type() == 'torch.CudaTensor' then
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
self.target = torch.CudaLongTensor and self.target:cudaLong() or self.target:cuda()
else
self.target = self.target:long()
end
self.target[1] = target
- elseif input:type() == 'torch.CudaTensor' then
+ elseif torch.typename(input):find('torch%.Cuda.*Tensor') then
self.target = torch.CudaLongTensor and target:cudaLong() or target
else
self.target = target:long()