Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGregory Chanan <gchanan@fb.com>2016-10-25 20:19:53 +0300
committerGregory Chanan <gchanan@fb.com>2016-10-25 20:19:53 +0300
commit766fd1517e12452fbcce9da356b70f8ee3f34f3b (patch)
tree84cd4326b88ede411002e32b0fa822ef1b616f22 /MultiMarginCriterion.lua
parent7caafece75fe1b5b2bb7a92c85db8cde3f5b8e69 (diff)
Add integer indexing for MultiMarginCriterion.
Diffstat (limited to 'MultiMarginCriterion.lua')
-rw-r--r--MultiMarginCriterion.lua14
1 files changed, 12 insertions, 2 deletions
diff --git a/MultiMarginCriterion.lua b/MultiMarginCriterion.lua
index 1a22bde..e312238 100644
--- a/MultiMarginCriterion.lua
+++ b/MultiMarginCriterion.lua
@@ -16,10 +16,15 @@ end
function MultiMarginCriterion:updateOutput(input, target)
-- backward compatibility
if not torch.isTensor(target) then
- self.target_tensor = self.target_tensor or input.new(1)
+ self.target_tensor = self.target_tensor or torch.LongTensor(1)
self.target_tensor[1] = target
target = self.target_tensor
end
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
+ target = torch.CudaLongTensor and target:cudaLong() or target
+ else
+ target = target:long()
+ end
self.p = self.p or 1
self.output_tensor = self.output_tensor or input.new(1)
input.THNN.MultiMarginCriterion_updateOutput(
@@ -37,10 +42,15 @@ end
function MultiMarginCriterion:updateGradInput(input, target)
if not torch.isTensor(target) then
- self.target_tensor = self.target_tensor or input.new(1)
+ self.target_tensor = self.target_tensor or torch.LongTensor(1)
self.target_tensor[1] = target
target = self.target_tensor
end
+ if torch.typename(input):find('torch%.Cuda.*Tensor') then
+ target = torch.CudaLongTensor and target:cudaLong() or target
+ else
+ target = target:long()
+ end
input.THNN.MultiMarginCriterion_updateGradInput(
input:cdata(),
target:cdata(),