Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r--CosineDistance.lua94
-rwxr-xr-xdoc/table.md17
-rw-r--r--test.lua50
3 files changed, 136 insertions(+), 25 deletions(-)
diff --git a/CosineDistance.lua b/CosineDistance.lua
index 061ff92..eb8eb91 100644
--- a/CosineDistance.lua
+++ b/CosineDistance.lua
@@ -3,38 +3,86 @@ local CosineDistance, parent = torch.class('nn.CosineDistance', 'nn.Module')
function CosineDistance:__init()
parent.__init(self)
self.gradInput = {torch.Tensor(), torch.Tensor()}
- self.output=torch.Tensor(1)
end
function CosineDistance:updateOutput(input)
local input1, input2 = input[1], input[2]
- self.w1 = input1:dot(input2)
- self.w22 = input1:dot(input1)
- self.w2 = math.sqrt(self.w22)
- self.w32 = input2:dot(input2)
- self.w3 = math.sqrt(self.w32)
- self.output[1] = self.w1/self.w2/self.w3
+
+ if input1:dim() == 1 then
+ input1 = input1:view(1,-1)
+ input2 = input2:view(1,-1)
+ end
+
+ if not self.buffer then
+ self.buffer = input1.new()
+ self.w1 = input1.new()
+ self.w22 = input1.new()
+ self.w = input1.new()
+ self.w32 = input1.new()
+ self.ones = input1.new()
+ end
+
+ self.buffer:cmul(input1,input2)
+ self.w1:sum(self.buffer,2)
+
+ local epsilon = 1e-12
+ self.buffer:cmul(input1,input1)
+ self.w22:sum(self.buffer,2):add(epsilon)
+ self.ones:resizeAs(self.w22):fill(1)
+ self.w22:cdiv(self.ones, self.w22)
+ self.w:resizeAs(self.w22):copy(self.w22)
+
+ self.buffer:cmul(input2,input2)
+ self.w32:sum(self.buffer,2):add(epsilon)
+ self.w32:cdiv(self.ones, self.w32)
+ self.w:cmul(self.w32)
+ self.w:sqrt()
+
+ self.output:cmul(self.w1,self.w)
+ self.output = self.output:select(2,1)
+
return self.output
end
function CosineDistance:updateGradInput(input, gradOutput)
local v1 = input[1]
local v2 = input[2]
- local gw1 = input[1].new()
- local gw2 = input[2].new()
- gw1:resizeAs(v1)
- gw2:resizeAs(v1)
-
- gw1:zero()
- gw1:add(1/(self.w2*self.w3), v2)
- gw1:add(-self.w1/(self.w22*self.w2*self.w3), v1)
-
- gw2:zero()
- gw2:add(1/(self.w2*self.w3), v1)
- gw2:add(-self.w1/(self.w32*self.w2*self.w3), v2)
-
- gw1:mul(gradOutput[1])
- gw2:mul(gradOutput[1])
- self.gradInput = {gw1, gw2}
+ local not_batch = false
+
+ if v1:dim() == 1 then
+ v1 = v1:view(1,-1)
+ v2 = v2:view(1,-1)
+ not_batch = true
+ end
+
+ local gw1 = self.gradInput[1]
+ local gw2 = self.gradInput[2]
+ gw1:resizeAs(v1):copy(v2)
+ gw2:resizeAs(v1):copy(v1)
+
+ self.w = self.w:expandAs(v1)
+ self.buffer:cmul(self.w1,self.w22)
+ self.buffer = self.buffer:expandAs(v1)
+ gw1:addcmul(-1,self.buffer,v1)
+ gw1:cmul(self.w)
+
+ self.buffer:cmul(self.w1,self.w32)
+ self.buffer = self.buffer:expandAs(v1)
+ gw2:addcmul(-1,self.buffer,v2)
+ gw2:cmul(self.w)
+
+ local go = gradOutput:view(-1,1):expandAs(v1)
+ gw1:cmul(go)
+ gw2:cmul(go)
+
+ if not_batch then
+ self.gradInput[1] = gw1:select(1,1)
+ self.gradInput[2] = gw2:select(1,1)
+ end
+
+ -- fix for torch bug
+ -- https://github.com/torch/torch7/issues/289
+ self.buffer:resize()
+
return self.gradInput
end
diff --git a/doc/table.md b/doc/table.md
index 95ac2b6..91ea209 100755
--- a/doc/table.md
+++ b/doc/table.md
@@ -981,9 +981,9 @@ end
<a name="nn.CosineDistance"/>
## CosineDistance ##
-`module` = `CosineDistance()` creates a module that takes a `table` of two vectors as input and outputs the cosine distance between them.
+`module` = `CosineDistance()` creates a module that takes a `table` of two vectors (or matrices if in batch mode) as input and outputs the cosine distance between them.
-Example:
+Examples:
```lua
mlp = nn.CosineDistance()
x = torch.Tensor({1, 2, 3})
@@ -995,6 +995,19 @@ gives the output:
0.9746
[torch.Tensor of dimension 1]
```
+`CosineDistance` also accepts batches:
+```lua
+mlp = nn.CosineDistance()
+x = torch.Tensor({{1,2,3},{1,2,-3}})
+y = torch.Tensor({{4,5,6},{-4,5,6}})
+print(mlp:forward({x,y}))
+```
+gives the output:
+```lua
+ 0.9746
+-0.3655
+[torch.DoubleTensor of size 2]
+```
A more complicated example:
```lua
diff --git a/test.lua b/test.lua
index 3f37dac..cd48433 100644
--- a/test.lua
+++ b/test.lua
@@ -3724,6 +3724,56 @@ function nntest.BatchMMTransposeBoth()
end
end
+function nntest.CosineDistance()
+ local indim = math.random(1,10)
+ local input = {torch.rand(indim),torch.rand(indim)}
+
+ -- check forward against previous implementation
+ local module = nn.CosineDistance()
+
+ local w1 = input[1]:dot(input[2])
+ local w2 = math.sqrt(input[1]:dot(input[1]))
+ local w3 = math.sqrt(input[2]:dot(input[2]))
+ local output_old = w1/w2/w3
+
+ local output = module:forward(input)
+
+ mytester:assertlt(math.abs(output_old-output[1]),precision,'error on forward ')
+
+
+ -- check gradients
+ -- Note: testJacobian doesn't support table inputs, and rather than re-write
+ -- it so that it does, I'll just use a split table module on the input.
+ -- I assume both SplitTable and Sequential do not have bugs, otherwise this
+ -- test will break.
+ local input = torch.rand(2,indim)
+ local module = nn.Sequential()
+ module:add(nn.SplitTable(1))
+ module:add(nn.CosineDistance())
+
+ local err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, 'error on state ')
+
+ -- IO
+ local ferr,berr = jac.testIO(module,input)
+ mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+ mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+
+ -- batch
+ -- rebuild module to avoid correlated tests
+ local module = nn.Sequential()
+ module:add(nn.SplitTable(1))
+ module:add(nn.CosineDistance())
+
+ local nframes = math.random(1,10)
+ local indim = math.random(1,10)
+ local input = torch.rand(2,nframes,indim)
+
+ local err = jac.testJacobian(module,input)
+ mytester:assertlt(err,precision, 'batch error on state ')
+
+end
+
function nntest.CosineEmbeddingCriterion()
local v1 = torch.Tensor{1, 0}
local v2 = torch.Tensor{0.5, math.sqrt(3)*0.5}