Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2016-03-01 00:17:10 +0300
committersoumith <soumith@fb.com>2016-03-01 00:17:10 +0300
commit9de8f9060c119df804013e53a7bce5543a027e4c (patch)
tree4617f6d0ce8225dcaff88ed17066b48c4de3731d /SoftMarginCriterion.lua
parent296c6aa199c7299125a8412897c051a6efb73b76 (diff)
SoftMarginCriterion
Diffstat (limited to 'SoftMarginCriterion.lua')
-rw-r--r--SoftMarginCriterion.lua24
1 files changed, 24 insertions, 0 deletions
diff --git a/SoftMarginCriterion.lua b/SoftMarginCriterion.lua
new file mode 100644
index 0000000..96ccda8
--- /dev/null
+++ b/SoftMarginCriterion.lua
@@ -0,0 +1,24 @@
+local SoftMarginCriterion, parent = torch.class('nn.SoftMarginCriterion', 'nn.Criterion')
+
+function SoftMarginCriterion:__init()
+ parent.__init(self)
+ self.sizeAverage = true
+end
+
+function SoftMarginCriterion:updateOutput(input, target)
+ self.output_tensor = self.output_tensor or input.new(1)
+ input.THNN.SoftMarginCriterion_updateOutput(
+ input:cdata(), target:cdata(),
+ self.output_tensor:cdata(),
+ self.sizeAverage)
+ self.output = self.output_tensor[1]
+ return self.output
+end
+
+function SoftMarginCriterion:updateGradInput(input, target)
+ input.THNN.SoftMarginCriterion_updateGradInput(
+ input:cdata(), target:cdata(),
+ self.gradInput:cdata(),
+ self.sizeAverage)
+ return self.gradInput
+end