Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'L1HingeEmbeddingCriterion.lua')
-rw-r--r--L1HingeEmbeddingCriterion.lua41
1 files changed, 41 insertions, 0 deletions
diff --git a/L1HingeEmbeddingCriterion.lua b/L1HingeEmbeddingCriterion.lua
new file mode 100644
index 0000000..5aa1ae7
--- /dev/null
+++ b/L1HingeEmbeddingCriterion.lua
@@ -0,0 +1,41 @@
+local L1HingeEmbeddingCriterion, parent = torch.class('nn.L1HingeEmbeddingCriterion', 'nn.Module')
+
+function L1HingeEmbeddingCriterion:__init(margin)
+ parent.__init(self)
+ margin=margin or 1
+ self.margin = margin
+ self.gradInput = {torch.Tensor(), torch.Tensor()}
+end
+
+function L1HingeEmbeddingCriterion:updateOutput(input,y)
+ self.output=input[1]:dist(input[2],1);
+ if y==-1 then
+ self.output = math.max(0,self.margin - self.output);
+ end
+ return self.output
+end
+
+
+local function mathsign(t)
+ if t>0 then return 1; end
+ if t<0 then return -1; end
+ return 2*torch.random(2)-3;
+end
+
+function L1HingeEmbeddingCriterion:updateGradInput(input, y)
+ self.gradInput[1]:resizeAs(input[1])
+ self.gradInput[2]:resizeAs(input[2])
+ self.gradInput[1]:copy(input[1])
+ self.gradInput[1]:add(-1, input[2])
+ local dist = self.gradInput[1]:norm(1);
+ self.gradInput[1]:apply(mathsign) -- L1 gradient
+ if y == -1 then -- just to avoid a mul by 1
+ if dist > self.margin then
+ self.gradInput[1]:zero()
+ else
+ self.gradInput[1]:mul(-1)
+ end
+ end
+ self.gradInput[2]:zero():add(-1, self.gradInput[1])
+ return self.gradInput
+end