diff options
-rw-r--r-- | CosineDistance.lua | 94 | ||||
-rwxr-xr-x | doc/table.md | 17 | ||||
-rw-r--r-- | test.lua | 50 |
3 files changed, 136 insertions, 25 deletions
diff --git a/CosineDistance.lua b/CosineDistance.lua index 061ff92..eb8eb91 100644 --- a/CosineDistance.lua +++ b/CosineDistance.lua @@ -3,38 +3,86 @@ local CosineDistance, parent = torch.class('nn.CosineDistance', 'nn.Module') function CosineDistance:__init() parent.__init(self) self.gradInput = {torch.Tensor(), torch.Tensor()} - self.output=torch.Tensor(1) end function CosineDistance:updateOutput(input) local input1, input2 = input[1], input[2] - self.w1 = input1:dot(input2) - self.w22 = input1:dot(input1) - self.w2 = math.sqrt(self.w22) - self.w32 = input2:dot(input2) - self.w3 = math.sqrt(self.w32) - self.output[1] = self.w1/self.w2/self.w3 + + if input1:dim() == 1 then + input1 = input1:view(1,-1) + input2 = input2:view(1,-1) + end + + if not self.buffer then + self.buffer = input1.new() + self.w1 = input1.new() + self.w22 = input1.new() + self.w = input1.new() + self.w32 = input1.new() + self.ones = input1.new() + end + + self.buffer:cmul(input1,input2) + self.w1:sum(self.buffer,2) + + local epsilon = 1e-12 + self.buffer:cmul(input1,input1) + self.w22:sum(self.buffer,2):add(epsilon) + self.ones:resizeAs(self.w22):fill(1) + self.w22:cdiv(self.ones, self.w22) + self.w:resizeAs(self.w22):copy(self.w22) + + self.buffer:cmul(input2,input2) + self.w32:sum(self.buffer,2):add(epsilon) + self.w32:cdiv(self.ones, self.w32) + self.w:cmul(self.w32) + self.w:sqrt() + + self.output:cmul(self.w1,self.w) + self.output = self.output:select(2,1) + return self.output end function CosineDistance:updateGradInput(input, gradOutput) local v1 = input[1] local v2 = input[2] - local gw1 = input[1].new() - local gw2 = input[2].new() - gw1:resizeAs(v1) - gw2:resizeAs(v1) - - gw1:zero() - gw1:add(1/(self.w2*self.w3), v2) - gw1:add(-self.w1/(self.w22*self.w2*self.w3), v1) - - gw2:zero() - gw2:add(1/(self.w2*self.w3), v1) - gw2:add(-self.w1/(self.w32*self.w2*self.w3), v2) - - gw1:mul(gradOutput[1]) - gw2:mul(gradOutput[1]) - self.gradInput = {gw1, gw2} + local not_batch = false + + if v1:dim() == 1 then + v1 = v1:view(1,-1) + v2 = v2:view(1,-1) + not_batch = true + end + + local gw1 = self.gradInput[1] + local gw2 = self.gradInput[2] + gw1:resizeAs(v1):copy(v2) + gw2:resizeAs(v1):copy(v1) + + self.w = self.w:expandAs(v1) + self.buffer:cmul(self.w1,self.w22) + self.buffer = self.buffer:expandAs(v1) + gw1:addcmul(-1,self.buffer,v1) + gw1:cmul(self.w) + + self.buffer:cmul(self.w1,self.w32) + self.buffer = self.buffer:expandAs(v1) + gw2:addcmul(-1,self.buffer,v2) + gw2:cmul(self.w) + + local go = gradOutput:view(-1,1):expandAs(v1) + gw1:cmul(go) + gw2:cmul(go) + + if not_batch then + self.gradInput[1] = gw1:select(1,1) + self.gradInput[2] = gw2:select(1,1) + end + + -- fix for torch bug + -- https://github.com/torch/torch7/issues/289 + self.buffer:resize() + return self.gradInput end diff --git a/doc/table.md b/doc/table.md index 95ac2b6..91ea209 100755 --- a/doc/table.md +++ b/doc/table.md @@ -981,9 +981,9 @@ end <a name="nn.CosineDistance"/> ## CosineDistance ## -`module` = `CosineDistance()` creates a module that takes a `table` of two vectors as input and outputs the cosine distance between them. +`module` = `CosineDistance()` creates a module that takes a `table` of two vectors (or matrices if in batch mode) as input and outputs the cosine distance between them. -Example: +Examples: ```lua mlp = nn.CosineDistance() x = torch.Tensor({1, 2, 3}) @@ -995,6 +995,19 @@ gives the output: 0.9746 [torch.Tensor of dimension 1] ``` +`CosineDistance` also accepts batches: +```lua +mlp = nn.CosineDistance() +x = torch.Tensor({{1,2,3},{1,2,-3}}) +y = torch.Tensor({{4,5,6},{-4,5,6}}) +print(mlp:forward({x,y})) +``` +gives the output: +```lua + 0.9746 +-0.3655 +[torch.DoubleTensor of size 2] +``` A more complicated example: ```lua @@ -3724,6 +3724,56 @@ function nntest.BatchMMTransposeBoth() end end +function nntest.CosineDistance() + local indim = math.random(1,10) + local input = {torch.rand(indim),torch.rand(indim)} + + -- check forward against previous implementation + local module = nn.CosineDistance() + + local w1 = input[1]:dot(input[2]) + local w2 = math.sqrt(input[1]:dot(input[1])) + local w3 = math.sqrt(input[2]:dot(input[2])) + local output_old = w1/w2/w3 + + local output = module:forward(input) + + mytester:assertlt(math.abs(output_old-output[1]),precision,'error on forward ') + + + -- check gradients + -- Note: testJacobian doesn't support table inputs, and rather than re-write + -- it so that it does, I'll just use a split table module on the input. + -- I assume both SplitTable and Sequential do not have bugs, otherwise this + -- test will break. + local input = torch.rand(2,indim) + local module = nn.Sequential() + module:add(nn.SplitTable(1)) + module:add(nn.CosineDistance()) + + local err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, 'error on state ') + + -- IO + local ferr,berr = jac.testIO(module,input) + mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') + mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') + + -- batch + -- rebuild module to avoid correlated tests + local module = nn.Sequential() + module:add(nn.SplitTable(1)) + module:add(nn.CosineDistance()) + + local nframes = math.random(1,10) + local indim = math.random(1,10) + local input = torch.rand(2,nframes,indim) + + local err = jac.testJacobian(module,input) + mytester:assertlt(err,precision, 'batch error on state ') + +end + function nntest.CosineEmbeddingCriterion() local v1 = torch.Tensor{1, 0} local v2 = torch.Tensor{0.5, math.sqrt(3)*0.5} |