diff options
author | soumith <soumith@fb.com> | 2016-03-01 00:17:10 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2016-03-01 00:17:10 +0300 |
commit | 9de8f9060c119df804013e53a7bce5543a027e4c (patch) | |
tree | 4617f6d0ce8225dcaff88ed17066b48c4de3731d /SoftMarginCriterion.lua | |
parent | 296c6aa199c7299125a8412897c051a6efb73b76 (diff) |
SoftMarginCriterion
Diffstat (limited to 'SoftMarginCriterion.lua')
-rw-r--r-- | SoftMarginCriterion.lua | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/SoftMarginCriterion.lua b/SoftMarginCriterion.lua new file mode 100644 index 0000000..96ccda8 --- /dev/null +++ b/SoftMarginCriterion.lua @@ -0,0 +1,24 @@ +local SoftMarginCriterion, parent = torch.class('nn.SoftMarginCriterion', 'nn.Criterion') + +function SoftMarginCriterion:__init() + parent.__init(self) + self.sizeAverage = true +end + +function SoftMarginCriterion:updateOutput(input, target) + self.output_tensor = self.output_tensor or input.new(1) + input.THNN.SoftMarginCriterion_updateOutput( + input:cdata(), target:cdata(), + self.output_tensor:cdata(), + self.sizeAverage) + self.output = self.output_tensor[1] + return self.output +end + +function SoftMarginCriterion:updateGradInput(input, target) + input.THNN.SoftMarginCriterion_updateGradInput( + input:cdata(), target:cdata(), + self.gradInput:cdata(), + self.sizeAverage) + return self.gradInput +end |