From 3ce35f0bac29d4a9645ba71ee9a6cde6d93091d2 Mon Sep 17 00:00:00 2001 From: Jonathan Tompson Date: Mon, 21 Oct 2013 14:12:18 -0400 Subject: going back on Clement's suggestion. It was a good idea, but we're needlessly loosing performance in the 1D case with the input clone. --- PairwiseDistance.lua | 78 ++++++++++++++++++++++++++-------------------------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua index f9f77f2..79569c9 100644 --- a/PairwiseDistance.lua +++ b/PairwiseDistance.lua @@ -11,26 +11,27 @@ function PairwiseDistance:__init(p) end function PairwiseDistance:updateOutput(input) - if input[1]:dim() > 2 then - error('input must be vector or matrix') - end if input[1]:dim() == 1 then - -- Reshape the input so it always looks like a batch (avoids multiple - -- code-paths). (Clement's good suggestion) - input[1]:resize(1,input[1]:size(1)) - input[2]:resize(1,input[2]:size(1)) - end - - self.diff:resizeAs(input[1]) - self.diff:zero() - self.diff:add(input[1], -1, input[2]) - self.diff:abs() + self.output[1]=input[1]:dist(input[2],self.norm) + elseif input[1]:dim() == 2 then + self.diff = self.diff or input[1].new() + self.diff:resizeAs(input[1]) - self.output:resize(input[1]:size(1)) - self.output:zero() - self.output:add(self.diff:pow(self.norm):sum(2)) - self.output:pow(1./self.norm) + local diff = self.diff:zero() + diff:add(input[1], -1, input[2]) + diff:abs() + self.output:resize(input[1]:size(1)) + self.output:zero() + self.output:add(diff:pow(self.norm):sum(2)) + self.output:pow(1./self.norm) + else + error('input must be vector or matrix') + end + if input[1]:dim() > 2 then + error('input must be vector or matrix') + end + return self.output end @@ -44,13 +45,6 @@ function PairwiseDistance:updateGradInput(input, gradOutput) error('input must be vector or matrix') end - if input[1]:dim() == 1 then - -- Reshape the input so it always looks like a batch (avoids multiple - -- code-paths). (Clement's good suggestion) - input[1]:resize(1,input[1]:size(1)) - input[2]:resize(1,input[2]:size(1)) - end - self.gradInput[1]:resize(input[1]:size()) self.gradInput[2]:resize(input[2]:size()) self.gradInput[1]:copy(input[1]) @@ -65,24 +59,30 @@ function PairwiseDistance:updateGradInput(input, gradOutput) self.gradInput[1]:cmul(self.gradInput[1]:clone():abs():pow(self.norm-2)) end - self.outExpand = self.outExpand or self.output.new() - self.outExpand:resize(self.output:size(1), 1) - self.outExpand:copy(self.output) - self.outExpand:add(1.0e-6) -- Prevent divide by zero errors - self.outExpand:pow(-(self.norm-1)) - self.gradInput[1]:cmul(self.outExpand:expand(self.gradInput[1]:size(1), - self.gradInput[1]:size(2))) + if (input[1]:dim() > 1) then + self.outExpand = self.outExpand or self.output.new() + self.outExpand:resize(self.output:size(1), 1) + self.outExpand:copy(self.output) + self.outExpand:add(1.0e-6) -- Prevent divide by zero errors + self.outExpand:pow(-(self.norm-1)) + self.gradInput[1]:cmul(self.outExpand:expand(self.gradInput[1]:size(1), + self.gradInput[1]:size(2))) + else + self.gradInput[1]:mul(math.pow(self.output[1] + 1e-6, -(self.norm-1))) + end end - - self.grad = self.grad or gradOutput.new() - self.ones = self.ones or gradOutput.new() + if input[1]:dim() == 1 then + self.gradInput[1]:mul(gradOutput[1]) + else + self.grad = self.grad or gradOutput.new() + self.ones = self.ones or gradOutput.new() - self.grad:resizeAs(input[1]):zero() - self.ones:resize(input[1]:size(2)):fill(1) + self.grad:resizeAs(input[1]):zero() + self.ones:resize(input[1]:size(2)):fill(1) - self.grad:addr(gradOutput, self.ones) - self.gradInput[1]:cmul(self.grad) - + self.grad:addr(gradOutput, self.ones) + self.gradInput[1]:cmul(self.grad) + end self.gradInput[2]:zero():add(-1, self.gradInput[1]) return self.gradInput end -- cgit v1.2.3