github.com/torch/nn.git
author     Soumith Chintala <soumith@gmail.com>   2015-07-30 12:12:30 +0300
committer  Soumith Chintala <soumith@gmail.com>   2015-07-30 12:12:30 +0300
commit     13e2cdb09363aa25d32dcd66b3204ad16320eb7f (patch)
tree       a1fb67153ffaf28f8b63188876bcec26af2b50b6
parent     85e9542cf2051c3ebb07984ccb439f075b36da67 (diff)
parent     38d18fd48078f41f606e6f4a46773fea1f6683f4 (diff)
Merge pull request #313 from mys007/batchmode_hingeemb
HingeEmbeddingCriterion: batch mode
-rw-r--r--   HingeEmbeddingCriterion.lua   42
-rwxr-xr-x   doc/criterion.md               8
-rw-r--r--   test.lua                      13
3 files changed, 47 insertions, 16 deletions
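
This merge switches `nn.HingeEmbeddingCriterion` from a single scalar pair (`input[1]` and a numeric `y`) to batch mode, where `input` and `y` are tensors of matching size. A minimal usage sketch of the new batch interface (the tensor values below are chosen purely for illustration):

```lua
require 'nn'

-- A batch of four pairwise distances and their labels (1 = similar, -1 = dissimilar).
local crit = nn.HingeEmbeddingCriterion(1)           -- margin defaults to 1 if omitted
local input  = torch.Tensor{0.5, 1.2, 0.1, 0.9}
local target = torch.Tensor{1, -1, 1, -1}

local loss = crit:forward(input, target)             -- averaged over all 4 elements (sizeAverage = true)
local grad = crit:backward(input, target)            -- gradient w.r.t. input, same shape as input
print(loss, grad)
```

Scalar targets still work: a plain number `y` is wrapped into a one-element tensor internally (`self.ty`), as the diff below shows.
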
diff --git a/HingeEmbeddingCriterion.lua b/HingeEmbeddingCriterion.lua
index b075847..fe8f1a6 100644
--- a/HingeEmbeddingCriterion.lua
+++ b/HingeEmbeddingCriterion.lua
@@ -2,24 +2,42 @@ local HingeEmbeddingCriterion, parent = torch.class('nn.HingeEmbeddingCriterion'
function HingeEmbeddingCriterion:__init(margin)
parent.__init(self)
- margin = margin or 1
- self.margin = margin
- self.gradInput = torch.Tensor(1)
+ self.margin = margin or 1
+ self.sizeAverage = true
end
function HingeEmbeddingCriterion:updateOutput(input,y)
- self.output = input[1]
- if y == -1 then
- self.output = math.max(0,self.margin - self.output);
+ self.buffer = self.buffer or input.new()
+ if not torch.isTensor(y) then
+ self.ty = self.ty or input.new():resize(1)
+ self.ty[1]=y
+ y=self.ty
end
+
+ self.buffer:resizeAs(input):copy(input)
+ self.buffer[torch.eq(y, -1)] = 0
+ self.output = self.buffer:sum()
+
+ self.buffer:fill(self.margin):add(-1, input)
+ self.buffer:cmax(0)
+ self.buffer[torch.eq(y, 1)] = 0
+ self.output = self.output + self.buffer:sum()
+
+ if (self.sizeAverage == nil or self.sizeAverage == true) then
+ self.output = self.output / input:nElement()
+ end
+
return self.output
end
function HingeEmbeddingCriterion:updateGradInput(input, y)
- self.gradInput[1] = y
- local dist = input[1]
- if y == -1 and dist > self.margin then
- self.gradInput[1] = 0;
- end
- return self.gradInput
+ if not torch.isTensor(y) then self.ty[1]=y; y=self.ty end
+ self.gradInput:resizeAs(input):copy(y)
+ self.gradInput[torch.cmul(torch.eq(y, -1), torch.gt(input, self.margin))] = 0
+
+ if (self.sizeAverage == nil or self.sizeAverage == true) then
+ self.gradInput:mul(1 / input:nElement())
+ end
+
+ return self.gradInput
end
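
The new `updateOutput` avoids a per-element loop by masking: elements with `y == -1` are zeroed out of the plain `x` term, and elements with `y == 1` are zeroed out of the `max(0, margin - x)` term, so the two partial sums together cover every element exactly once. A standalone sketch of the same masking idea (the function name and argument order are illustrative, not part of the module):

```lua
-- Batch hinge-embedding loss written out with explicit masks,
-- mirroring the buffer-based computation in updateOutput above.
local function hingeEmbeddingLoss(x, y, margin, sizeAverage)
   margin = margin or 1
   local pos = x:clone()
   pos[torch.eq(y, -1)] = 0                      -- keep x_i only where y_i == 1
   local neg = x:clone():mul(-1):add(margin)     -- margin - x_i
   neg:cmax(0)                                   -- max(0, margin - x_i)
   neg[torch.eq(y, 1)] = 0                       -- keep it only where y_i == -1
   local loss = pos:sum() + neg:sum()
   if sizeAverage ~= false then
      loss = loss / x:nElement()                 -- default behaviour (sizeAverage = true)
   end
   return loss
end
```

`updateGradInput` uses the same trick: it combines the `y == -1` mask with `input > margin` via `torch.cmul` to zero the gradient wherever the hinge is inactive.
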
diff --git a/doc/criterion.md b/doc/criterion.md
index d0c9366..64e6d63 100755
--- a/doc/criterion.md
+++ b/doc/criterion.md
@@ -401,12 +401,12 @@ Creates a criterion that measures the loss given an input `x` which is a 1-dime
This is usually used for measuring whether two inputs are similar or dissimilar, e.g. using the L1 pairwise distance, and is typically used for learning nonlinear embeddings or semi-supervised learning.
```lua
- ⎧ x, if y == 1
-loss(x, y) = ⎨
- ⎩ max(0, margin - x), if y == -1
+ ⎧ x_i, if y_i == 1
+loss(x, y) = 1/n ⎨
+ ⎩ max(0, margin - x_i), if y_i == -1
```
-The `margin` has a default value of `1`, or can be set in the constructor.
+If `x` and `y` are `n`-dimensional `Tensor`s, the sum operation still operates over all the elements, and divides by `n` (this can be avoided if one sets the internal variable `sizeAverage` to `false`). The `margin` has a default value of `1`, or can be set in the constructor.
### Example
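
As a concrete check of the formula and the `1/n` averaging, the new unit test below uses `margin = 2`, `x = {0.3, 2.1, 1.8, 0}` and `y = {1, -1, -1, 1}`: the per-element terms are `0.3`, `max(0, 2 - 2.1) = 0`, `max(0, 2 - 1.8) = 0.2` and `0`, so the averaged loss is `(0.3 + 0.2) / 4 = 0.125`.
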
diff --git a/test.lua b/test.lua
index f5e5295..f00c7c0 100644
--- a/test.lua
+++ b/test.lua
@@ -3830,6 +3830,19 @@ function nntest.CosineEmbeddingCriterion()
equal(grads[2], zero, 'gradient should be zero')
end
+function nntest.HingeEmbeddingCriterion()
+ local x = torch.Tensor{0.3,2.1,1.8,0}
+ local y = torch.Tensor{1,-1,-1,1}
+ local expgrads = torch.Tensor{1,0,-1,1} / 4
+
+ local crit = nn.HingeEmbeddingCriterion(2)
+ local output = crit:forward(x, y) -- must be called before backward
+ local grads = crit:backward(x, y)
+
+ mytester:assert(math.abs(output - (0.3 + 0.2) / 4) < 1e-10)
+ equal(grads, expgrads)
+end
+
function nntest.Replicate()
local vector = torch.rand(3)
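
The expected gradients in the new test follow directly from `updateGradInput`: `1` and `1` for the two `y == 1` elements, `0` for `x = 2.1 > margin`, `-1` for `x = 1.8 <= margin`, all divided by `n = 4`. An independent sanity check is a central finite-difference estimate; a minimal sketch (the helper name `fdGradient` is made up here and is not part of the test suite):

```lua
-- Numerically estimate d(loss)/d(x_i) and compare with crit:backward.
local function fdGradient(crit, x, y, eps)
   eps = eps or 1e-6
   local grad = x:clone():zero()
   for i = 1, x:nElement() do
      local orig = x[i]
      x[i] = orig + eps
      local fp = crit:forward(x, y)
      x[i] = orig - eps
      local fm = crit:forward(x, y)
      x[i] = orig
      grad[i] = (fp - fm) / (2 * eps)
   end
   return grad
end

local crit = nn.HingeEmbeddingCriterion(2)
local x = torch.Tensor{0.3, 2.1, 1.8, 0}
local y = torch.Tensor{1, -1, -1, 1}
crit:forward(x, y)
print(crit:backward(x, y))     -- analytic:   0.25  0  -0.25  0.25
print(fdGradient(crit, x, y))  -- numeric estimate should agree closely
```
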