diff options
author | soumith <soumith@fb.com> | 2016-02-20 04:40:57 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-02-20 04:44:51 +0300 |
commit | d1f57af38674a246a2e7677def0c63e59e97e61a (patch) | |
tree | 725d74d381a4cfe58c2d072441caec6215d0368c /MultiMarginCriterion.lua | |
parent | cc8e874a8858c7f5e13c2e1b2a26e5e49ae9f91d (diff) |
adding weights to MultiMarginCriterion
Diffstat (limited to 'MultiMarginCriterion.lua')
-rw-r--r-- | MultiMarginCriterion.lua | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/MultiMarginCriterion.lua b/MultiMarginCriterion.lua index 7ea78ca..3f480cc 100644 --- a/MultiMarginCriterion.lua +++ b/MultiMarginCriterion.lua @@ -1,10 +1,15 @@ +local THNN = require 'nn.THNN' local MultiMarginCriterion, parent = torch.class('nn.MultiMarginCriterion', 'nn.Criterion') -function MultiMarginCriterion:__init(p) +function MultiMarginCriterion:__init(p, weights) assert(p == nil or p == 1 or p == 2, 'only p=1 and p=2 supported') self.p = p or 1 parent.__init(self) self.sizeAverage = true + if weights then + assert(weights:dim() == 1, "weights input should be 1-D Tensor") + self.weights = weights + end end function MultiMarginCriterion:updateOutput(input, target) @@ -21,7 +26,8 @@ function MultiMarginCriterion:updateOutput(input, target) target:cdata(), self.output_tensor:cdata(), self.sizeAverage, - self.p + self.p, + THNN.optionalTensor(self.weights) ) self.output = self.output_tensor[1] return self.output @@ -38,7 +44,8 @@ function MultiMarginCriterion:updateGradInput(input, target) target:cdata(), self.gradInput:cdata(), self.sizeAverage, - self.p + self.p, + THNN.optionalTensor(self.weights) ) return self.gradInput end |