diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-05-21 20:48:19 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-05-21 20:48:19 +0300 |
commit | 78aac1a015ebba0655a7fdad8a4a09419b68da67 (patch) | |
tree | e534fc3f1f1192102fd4b5b25974abe6d4d7f9f2 /ClassNLLCriterion.lua | |
parent | 482537275df7fde77cc4dcc1d93de33cbfafde9f (diff) |
Revert "Revert "ClassNLLCriterion supports missing targets""revert-1217-revert-1215-ClassNLLCriterion-missing-target
Diffstat (limited to 'ClassNLLCriterion.lua')
-rw-r--r-- | ClassNLLCriterion.lua | 15 |
1 files changed, 7 insertions, 8 deletions
diff --git a/ClassNLLCriterion.lua b/ClassNLLCriterion.lua index d89f439..dae0e66 100644 --- a/ClassNLLCriterion.lua +++ b/ClassNLLCriterion.lua @@ -1,13 +1,10 @@ local THNN = require 'nn.THNN' local ClassNLLCriterion, parent = torch.class('nn.ClassNLLCriterion', 'nn.Criterion') -function ClassNLLCriterion:__init(weights, sizeAverage) +function ClassNLLCriterion:__init(weights, sizeAverage, ignoreIndex) parent.__init(self) - if sizeAverage ~= nil then - self.sizeAverage = sizeAverage - else - self.sizeAverage = true - end + self.sizeAverage = (sizeAverage == nil) and true or sizeAverage + self.ignoreIndex = ignoreIndex or -100 -- this target index will be ignored if weights then assert(weights:dim() == 1, "weights input should be 1-D Tensor") self.weights = weights @@ -47,7 +44,8 @@ function ClassNLLCriterion:updateOutput(input, target) self.output_tensor:cdata(), self.sizeAverage, THNN.optionalTensor(self.weights), - self.total_weight_tensor:cdata() + self.total_weight_tensor:cdata(), + self.ignoreIndex ) self.output = self.output_tensor[1] return self.output, self.total_weight_tensor[1] @@ -76,7 +74,8 @@ function ClassNLLCriterion:updateGradInput(input, target) self.gradInput:cdata(), self.sizeAverage, THNN.optionalTensor(self.weights), - self.total_weight_tensor:cdata() + self.total_weight_tensor:cdata(), + self.ignoreIndex ) return self.gradInput |