github.com/torch/nn.git
-rw-r--r--  CosineEmbeddingCriterion.lua  | 151
-rwxr-xr-x  doc/criterion.md              |   5
-rw-r--r--  test.lua                      |  81

3 files changed, 208 insertions, 29 deletions
diff --git a/CosineEmbeddingCriterion.lua b/CosineEmbeddingCriterion.lua
index d5e0e70..d81dfa2 100644
--- a/CosineEmbeddingCriterion.lua
+++ b/CosineEmbeddingCriterion.lua
@@ -5,46 +5,145 @@ function CosineEmbeddingCriterion:__init(margin)
margin = margin or 0
self.margin = margin
self.gradInput = {torch.Tensor(), torch.Tensor()}
+ self.sizeAverage = true
end
-
+
function CosineEmbeddingCriterion:updateOutput(input,y)
+
local input1, input2 = input[1], input[2]
- self.w1 = input1:dot(input2)
- self.w22 = input1:dot(input1)
- self.w2 = math.sqrt(self.w22)
- self.w32 = input2:dot(input2)
- self.w3 = math.sqrt(self.w32)
- self.output = self.w1/self.w2/self.w3
- if y == -1 then
- self.output = math.max(0, self.output - self.margin);
- else
- self.output = 1 - self.output
+
+ -- keep backward compatibility
+ if type(y) == 'number' then
+ self._y = self._y or input1.new(1)
+ self._y[1] = y
+ y = self._y
+ end
+
+ if input1:dim() == 1 then
+ input1 = input1:view(1,-1)
+ input2 = input2:view(1,-1)
+ end
+
+ if not self.buffer then
+ self.buffer = input1.new()
+ self.w1 = input1.new()
+ self.w22 = input1.new()
+ self.w = input1.new()
+ self.w32 = input1.new()
+ self._outputs = input1.new()
+ -- comparison operators return different mask types on CPU (ByteTensor) and CUDA (CudaTensor)
+ if input1:type() == 'torch.CudaTensor' then
+ self._idx = input1.new()
+ else
+ self._idx = torch.ByteTensor()
+ end
+ end
+
+ self.buffer:cmul(input1,input2)
+ self.w1:sum(self.buffer,2)
+
+ local epsilon = 1e-12
+ self.buffer:cmul(input1,input1)
+ self.w22:sum(self.buffer,2):add(epsilon)
+ -- self._outputs is also used as a temporary buffer
+ self._outputs:resizeAs(self.w22):fill(1)
+ self.w22:cdiv(self._outputs, self.w22)
+ self.w:resizeAs(self.w22):copy(self.w22)
+
+ self.buffer:cmul(input2,input2)
+ self.w32:sum(self.buffer,2):add(epsilon)
+ self.w32:cdiv(self._outputs, self.w32)
+ self.w:cmul(self.w32)
+ self.w:sqrt()
+
+ self._outputs:cmul(self.w1,self.w)
+ self._outputs = self._outputs:select(2,1)
+
+ y.eq(self._idx,y,-1)
+ self._outputs[self._idx] = self._outputs[self._idx]:add(-self.margin):cmax(0)
+ y.eq(self._idx,y,1)
+ self._outputs[self._idx] = self._outputs[self._idx]:mul(-1):add(1)
+
+ self.output = self._outputs:sum()
+
+ if self.sizeAverage then
+ self.output = self.output/y:size(1)
end
+
return self.output
end
function CosineEmbeddingCriterion:updateGradInput(input, y)
+
local v1 = input[1]
local v2 = input[2]
- local gw1 = input[1].new()
- local gw2 = input[2].new()
- gw1:resizeAs(v1)
- gw2:resizeAs(v1)
+ local not_batch = false
+
+ -- keep backward compatibility
+ if type(y) == 'number' then
+ self._y = self._y or v1.new(1)
+ self._y[1] = y
+ y = self._y
+ end
+
+ if v1:dim() == 1 then
+ v1 = v1:view(1,-1)
+ v2 = v2:view(1,-1)
+ not_batch = true
+ end
+
+ local gw1 = self.gradInput[1]
+ local gw2 = self.gradInput[2]
+ gw1:resizeAs(v1):copy(v2)
+ gw2:resizeAs(v1):copy(v1)
+
+ self.w = self.w:expandAs(v1)
+ self.buffer:cmul(self.w1,self.w22)
+ self.buffer = self.buffer:expandAs(v1)
+ gw1:addcmul(-1,self.buffer,v1)
+ gw1:cmul(self.w)
- gw1:zero()
- gw2:zero()
+ self.buffer:cmul(self.w1,self.w32)
+ self.buffer = self.buffer:expandAs(v1)
+ gw2:addcmul(-1,self.buffer,v2)
+ gw2:cmul(self.w)
- if self.output > 0 then
- gw1:add(1/(self.w2*self.w3), v2)
- gw1:add(-self.w1/(self.w22*self.w2*self.w3), v1)
+ -- self._idx = self._outputs <= 0
+ y.le(self._idx,self._outputs,0)
+ self._idx = self._idx:view(-1,1):expand(gw1:size())
+ gw1[self._idx] = 0
+ gw2[self._idx] = 0
- gw2:add(1/(self.w2*self.w3), v1)
- gw2:add(-self.w1/(self.w32*self.w2*self.w3), v2)
+ y.eq(self._idx,y,1)
+ self._idx = self._idx:view(-1,1):expand(gw2:size())
+ gw1[self._idx] = gw1[self._idx]:mul(-1)
+ gw2[self._idx] = gw2[self._idx]:mul(-1)
+
+ if self.sizeAverage then
+ gw1:div(y:size(1))
+ gw2:div(y:size(1))
end
- if y == 1 then
- gw1:mul(-1)
- gw2:mul(-1)
+
+ if not_batch then
+ self.gradInput[1] = gw1:select(1,1)
+ self.gradInput[2] = gw2:select(1,1)
end
- self.gradInput = {gw1, gw2}
+
+ -- fix for torch bug
+ -- https://github.com/torch/torch7/issues/289
+ self.buffer:resize()
+
return self.gradInput
end
+
+function CosineEmbeddingCriterion:type(type)
+ self._idx = nil
+ parent.type(self,type)
+ -- comparison operators return different mask types on CPU (ByteTensor) and CUDA (CudaTensor)
+ if type == 'torch.CudaTensor' then
+ self._idx = torch.CudaTensor()
+ else
+ self._idx = torch.ByteTensor()
+ end
+ return self
+end
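For readers following the vectorized buffers above: `w1` holds the per-row dot products, `w22` and `w32` hold the *reciprocal* squared norms (note the `cdiv` into a tensor of ones), and `w = sqrt(w22 * w32) = 1/(||x1||*||x2||)`, so `_outputs = w1 * w` is the per-sample cosine. Below is a minimal, unbatched Lua sketch of the same algebra; `cosineEmbedding` is a hypothetical helper for illustration, not part of the module:

```lua
require 'torch'

local function cosineEmbedding(x1, x2, y, margin)
   local eps = 1e-12
   local w1  = x1:dot(x2)              -- <x1, x2>
   local w22 = 1 / (x1:dot(x1) + eps)  -- reciprocal squared norm of x1
   local w32 = 1 / (x2:dot(x2) + eps)  -- reciprocal squared norm of x2
   local w   = math.sqrt(w22 * w32)    -- 1 / (||x1|| * ||x2||)
   local cos = w1 * w

   -- per-sample loss, as in updateOutput
   local loss = (y == 1) and (1 - cos) or math.max(0, cos - margin)

   -- gradient w.r.t. x1, matching the addcmul/cmul lines in updateGradInput:
   -- d(cos)/dx1 = w * (x2 - (w1 * w22) * x1); it is zeroed where the
   -- per-sample loss is <= 0, and negated for y == 1
   local gw1
   if loss <= 0 then
      gw1 = torch.zeros(x1:size())
   else
      gw1 = (x2 - x1 * (w1 * w22)) * w
      if y == 1 then gw1 = -gw1 end
   end
   return loss, gw1
end

print(cosineEmbedding(torch.Tensor{1, 0}, torch.Tensor{0, 1}, -1, 0.5))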
diff --git a/doc/criterion.md b/doc/criterion.md
index 2928938..104e00a 100755
--- a/doc/criterion.md
+++ b/doc/criterion.md
@@ -493,13 +493,13 @@ The `margin` has a default value of `1`, or can be set in the constructor.
criterion = nn.CosineEmbeddingCriterion([margin])
```
-Creates a criterion that measures the loss given an input `x` = `{x1, x2}`, a table of two `Tensor`s, and a label `y` (1 or -1).
+Creates a criterion that measures the loss given an input `x` = `{x1, x2}`, a table of two `Tensor`s, and a `Tensor` label `y` with values 1 or -1.
This is used for measuring whether two inputs are similar or dissimilar, using the cosine distance, and is typically used for learning nonlinear embeddings or for semi-supervised learning.
`margin` should be a number from `-1` to `1`; a value between `0` and `0.5` is suggested.
`forward` and `backward` have to be called alternately. If `margin` is missing, the default value is `0`.
-The loss function is:
+The loss function for each sample is:
```lua
             ⎧ 1 - cos(x1, x2),              if y ==  1
@@ -507,6 +507,7 @@ loss(x, y) = ⎨
             ⎩ max(0, cos(x1, x2) - margin), if y == -1
```
+For batched inputs, if the internal variable `sizeAverage` is `true`, the loss is averaged over the batch samples; if `sizeAverage` is `false`, the loss is summed over the batch samples. By default, `sizeAverage` is `true`.
<a name="nn.MarginRankingCriterion"></a>
## MarginRankingCriterion ##
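To make the batch semantics documented above concrete, here is a short usage sketch under the patched criterion (the tensor values are arbitrary illustrative data):

```lua
require 'nn'

local crit = nn.CosineEmbeddingCriterion(0.5)

-- a batch of 4 pairs of 3-dimensional embeddings
local x1 = torch.randn(4, 3)
local x2 = torch.randn(4, 3)
local y  = torch.Tensor{1, -1, 1, -1}     -- per-pair similar/dissimilar labels

local loss  = crit:forward({x1, x2}, y)   -- scalar: mean of the 4 per-sample losses
local grads = crit:backward({x1, x2}, y)  -- {dloss/dx1, dloss/dx2}, each 4x3

crit.sizeAverage = false
loss = crit:forward({x1, x2}, y)          -- now the per-sample losses are summed
```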
diff --git a/test.lua b/test.lua
index 5b08c37..3911118 100644
--- a/test.lua
+++ b/test.lua
@@ -800,6 +800,45 @@ local function criterionJacobianTest1D(cri, input, target)
mytester:assertlt(err, precision, 'error in difference between central difference and :backward')
end
+local function criterionJacobianTest1DTable(cri, input0, target)
+ -- assumes input is a tensor, which is split along the first dimension
+ local input = input0:split(1,1)
+ for i=1,#input do
+ input[i] = input[i][1]
+ end
+ local eps = 1e-6
+ local _ = cri:forward(input, target)
+ local dfdx = cri:backward(input, target)
+ -- for each input perturbation, do central difference
+ local centraldiff_dfdx = torch.Tensor():resizeAs(input0)
+ local input_s = input0:storage()
+ local centraldiff_dfdx_s = centraldiff_dfdx:storage()
+ for i=1,input0:nElement() do
+ -- f(xi + h)
+ input_s[i] = input_s[i] + eps
+ local fx1 = cri:forward(input, target)
+ -- f(xi - h)
+ input_s[i] = input_s[i] - 2*eps
+ local fx2 = cri:forward(input, target)
+ -- f'(xi) = (f(xi + h) - f(xi - h)) / 2h
+ local cdfx = (fx1 - fx2) / (2*eps)
+ -- store f' in appropriate place
+ centraldiff_dfdx_s[i] = cdfx
+ -- restore input_s[i] to its original value
+ input_s[i] = input_s[i] + eps
+ end
+ local centraldiff_dfdx_t = centraldiff_dfdx:split(1,1)
+ for i=1,#centraldiff_dfdx_t do
+ centraldiff_dfdx_t[i] = centraldiff_dfdx_t[i][1]
+ end
+ for i=1,#centraldiff_dfdx_t do
+ -- compare centraldiff_dfdx with :backward()
+ local err = (centraldiff_dfdx_t[i] - dfdx[i]):abs():max()
+ mytester:assertlt(err, precision, 'error in difference between central difference and :backward')
+ end
+end
+
+
function nntest.MSECriterion()
local input = torch.rand(10)
local target = input:clone():add(torch.rand(10))
@@ -3826,12 +3865,52 @@ function nntest.CosineEmbeddingCriterion()
local v2 = torch.Tensor{0.5, math.sqrt(3)*0.5}
local crit = nn.CosineEmbeddingCriterion(0.6)
- local output = crit:forward({v1, v2}, -1) -- must be called before backward
+ local output = crit:forward({v1, v2}, -1) -- must be called before backward
local grads = crit:backward({v1, v2}, -1)
local zero = torch.Tensor(2):zero()
equal(grads[1], zero, 'gradient should be zero')
equal(grads[2], zero, 'gradient should be zero')
+
+ -- check jacobians
+ local margin = math.random()*2-1
+ local dim = 5
+ local batch_size = 1
+ local crit = nn.CosineEmbeddingCriterion(margin)
+ local v = torch.rand(2,dim)
+ criterionJacobianTest1DTable(crit,v,1)
+ criterionJacobianTest1DTable(crit,v,-1)
+
+ -- batch with hand-computed values
+ local v1 = torch.Tensor{{1, 0}, {0.5, math.sqrt(3)*0.5}}
+ local v2 = torch.Tensor{{0.5, math.sqrt(3)*0.5}, {1, 0}}
+
+ local t = torch.Tensor{-1,-1}
+ local crit = nn.CosineEmbeddingCriterion(0.6)
+ local output = crit:forward({v1, v2}, t) -- must be called before backward
+ local grads = crit:backward({v1, v2}, t)
+
+ local zero = torch.Tensor(2,2):zero()
+ equal(grads[1], zero, 'gradient should be zero')
+ equal(grads[2], zero, 'gradient should be zero')
+
+ -- batch, sizeAverage true, jacobian
+ local margin = math.random()*2-1
+ local dim = 5
+ local batch_size = 2
+ local crit = nn.CosineEmbeddingCriterion(margin)
+ crit.sizeAverage = true
+ local v = torch.rand(2,batch_size,dim)
+ local t = torch.Tensor(batch_size):random(0,1):mul(2):add(-1)
+ criterionJacobianTest1DTable(crit,v,t)
+
+ -- batch, sizeAverage false, jacobian
+ local margin = math.random()*2-1
+ local crit = nn.CosineEmbeddingCriterion(margin)
+ crit.sizeAverage = false
+ local v = torch.rand(2,batch_size,dim)
+ local t = torch.Tensor(batch_size):random(0,1):mul(2):add(-1)
+ criterionJacobianTest1DTable(crit,v,t)
end
function nntest.HingeEmbeddingCriterion()
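As background for the test helper added above: `criterionJacobianTest1DTable` relies on the standard central-difference approximation, f'(x) ~ (f(x + h) - f(x - h)) / (2h), whose error is O(h^2); this is why `eps = 1e-6` comfortably beats the test `precision`. A standalone sketch of the same check on a plain scalar function (`centralDiff` is a hypothetical helper, not part of the test suite):

```lua
-- standalone sketch of the central-difference check used by the tests
local function centralDiff(f, x, h)
   h = h or 1e-6
   return (f(x + h) - f(x - h)) / (2 * h)
end

-- a hinge, shaped like the y == -1 branch of the criterion
local f = function(x) return math.max(0, x - 0.5) end

local approx = centralDiff(f, 0.8)   -- analytic slope past the margin is 1
assert(math.abs(approx - 1) < 1e-6)
```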