Welcome to mirror list, hosted at ThFree Co, Russian Federation.

PairwiseDistance.lua - github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 6bf43c5533f01cbc3ab6825bcbdf66617c57a5bf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
local PairwiseDistance, parent = torch.class('nn.PairwiseDistance', 'nn.Module')

function PairwiseDistance:__init(p)
   parent.__init(self)

   -- state
   self.gradInput = {}
   self.diff = torch.Tensor()
   self.norm = p or 2 -- Default using Euclidean distance
end 
  
function PairwiseDistance:updateOutput(input)
   self.output:resize(1)
   if input[1]:dim() == 1 then
      self.output:resize(1)
      self.output[1]=input[1]:dist(input[2],self.norm)
   elseif input[1]:dim() == 2 then
      self.diff = self.diff or input[1].new()
      self.diff:resizeAs(input[1])

      local diff = self.diff:zero()
      diff:add(input[1], -1, input[2])
      diff:abs()

      self.output:resize(input[1]:size(1))
      self.output:zero()
      self.output:add(diff:pow(self.norm):sum(2))
      self.output:pow(1./self.norm)
   else
      error('input must be vector or matrix')
   end
 
   return self.output
end

local function mathsign(x) 
   if x==0 then return  2*torch.random(2)-3; end
   if x>0 then return 1; else return -1; end
end

function PairwiseDistance:updateGradInput(input, gradOutput)
   if input[1]:dim() > 2 then
      error('input must be vector or matrix')
   end

   self.gradInput[1] = (self.gradInput[1] or input[1].new()):resize(input[1]:size()) 
   self.gradInput[2] = (self.gradInput[2] or input[2].new()):resize(input[2]:size())
   self.gradInput[1]:copy(input[1])
   self.gradInput[1]:add(-1, input[2]) 
   
   if self.norm==1 then
     self.gradInput[1]:apply(mathsign)
   else
     -- Note: derivative of p-norm:
     -- d/dx_k(||x||_p) = (x_k * abs(x_k)^(p-2)) / (||x||_p)^(p-1)
     if (self.norm > 2) then
        self.gradInput[1]:cmul(self.gradInput[1]:clone():abs():pow(self.norm-2))
     end

     if (input[1]:dim() > 1) then
        self.outExpand = self.outExpand or self.output.new()
        self.outExpand:resize(self.output:size(1), 1)
        self.outExpand:copy(self.output)
        self.outExpand:add(1.0e-6)  -- Prevent divide by zero errors
        self.outExpand:pow(-(self.norm-1))
        self.gradInput[1]:cmul(self.outExpand:expand(self.gradInput[1]:size(1),
           self.gradInput[1]:size(2)))
     else
        self.gradInput[1]:mul(math.pow(self.output[1] + 1e-6, -(self.norm-1)))
     end
   end
   if input[1]:dim() == 1 then
      self.gradInput[1]:mul(gradOutput[1])
   else
      self.grad = self.grad or gradOutput.new()
      self.ones = self.ones or gradOutput.new()

      self.grad:resizeAs(input[1]):zero()
      self.ones:resize(input[1]:size(2)):fill(1)

      self.grad:addr(gradOutput, self.ones)
      self.gradInput[1]:cmul(self.grad)
   end
   self.gradInput[2]:zero():add(-1, self.gradInput[1])
   return self.gradInput
end

function PairwiseDistance:clearState()
   nn.utils.clear(self, 'diff', 'outExpand', 'grad', 'ones')
   return parent.clearState(self)
end