Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2017-05-21 20:48:19 +0300
committerGitHub <noreply@github.com>2017-05-21 20:48:19 +0300
commit78aac1a015ebba0655a7fdad8a4a09419b68da67 (patch)
treee534fc3f1f1192102fd4b5b25974abe6d4d7f9f2 /ClassNLLCriterion.lua
parent482537275df7fde77cc4dcc1d93de33cbfafde9f (diff)
Revert "Revert "ClassNLLCriterion supports missing targets""revert-1217-revert-1215-ClassNLLCriterion-missing-target
Diffstat (limited to 'ClassNLLCriterion.lua')
-rw-r--r--ClassNLLCriterion.lua15
1 files changed, 7 insertions, 8 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua
index d89f439..dae0e66 100644
--- a/ClassNLLCriterion.lua
+++ b/ClassNLLCriterion.lua
@@ -1,13 +1,10 @@
local THNN = require 'nn.THNN'
local ClassNLLCriterion, parent = torch.class('nn.ClassNLLCriterion', 'nn.Criterion')
-function ClassNLLCriterion:__init(weights, sizeAverage)
+function ClassNLLCriterion:__init(weights, sizeAverage, ignoreIndex)
parent.__init(self)
- if sizeAverage ~= nil then
- self.sizeAverage = sizeAverage
- else
- self.sizeAverage = true
- end
+ self.sizeAverage = (sizeAverage == nil) and true or sizeAverage
+ self.ignoreIndex = ignoreIndex or -100 -- this target index will be ignored
if weights then
assert(weights:dim() == 1, "weights input should be 1-D Tensor")
self.weights = weights
@@ -47,7 +44,8 @@ function ClassNLLCriterion:updateOutput(input, target)
self.output_tensor:cdata(),
self.sizeAverage,
THNN.optionalTensor(self.weights),
- self.total_weight_tensor:cdata()
+ self.total_weight_tensor:cdata(),
+ self.ignoreIndex
)
self.output = self.output_tensor[1]
return self.output, self.total_weight_tensor[1]
@@ -76,7 +74,8 @@ function ClassNLLCriterion:updateGradInput(input, target)
self.gradInput:cdata(),
self.sizeAverage,
THNN.optionalTensor(self.weights),
- self.total_weight_tensor:cdata()
+ self.total_weight_tensor:cdata(),
+ self.ignoreIndex
)
return self.gradInput