github.com/torch/nn.git

author     Jonathan Tompson <tompson@cims.nyu.edu>   2013-10-21 22:12:18 +0400
committer  Jonathan Tompson <tompson@cims.nyu.edu>   2013-10-21 22:12:18 +0400
commit     3ce35f0bac29d4a9645ba71ee9a6cde6d93091d2 (patch)
tree       cf4a9fee48413be62968c0bd7a20ff0c83a139e7 /PairwiseDistance.lua
parent     7286784cbd0af841501bfadb98a4a2cee28199b6 (diff)
Going back on Clement's suggestion. It was a good idea, but we're needlessly losing performance in the 1D case with the input clone.
Diffstat (limited to 'PairwiseDistance.lua')
-rw-r--r--  PairwiseDistance.lua  78
1 file changed, 39 insertions(+), 39 deletions(-)
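
For context, a minimal usage sketch of the module this commit touches, assuming a working torch/nn install; the tensor sizes and the p = 2 norm are illustrative choices, not taken from this commit. In the batch (2D) branch each row pair yields one distance, which should match what the scalar (1D) branch computes directly via dist():

require 'nn'

local p = 2                                -- norm passed to PairwiseDistance(p); illustrative choice
local pdist = nn.PairwiseDistance(p)

-- Batch (2D) case: two 5x3 inputs give one distance per row pair.
local a = torch.randn(5, 3)
local b = torch.randn(5, 3)
local out = pdist:forward({a, b})          -- 1D tensor with 5 entries

-- Each entry matches the plain p-norm distance of the paired rows,
-- i.e. what the 1D branch of updateOutput computes via dist().
for i = 1, a:size(1) do
   assert(math.abs(out[i] - a[i]:dist(b[i], p)) < 1e-6)
end

-- Gradients are returned for both members of the input pair.
local gradIn = pdist:backward({a, b}, torch.ones(5))
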
diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua
index f9f77f2..79569c9 100644
--- a/PairwiseDistance.lua
+++ b/PairwiseDistance.lua
@@ -11,26 +11,27 @@ function PairwiseDistance:__init(p)
 end
 function PairwiseDistance:updateOutput(input)
-   if input[1]:dim() > 2 then
-      error('input must be vector or matrix')
-   end
    if input[1]:dim() == 1 then
-      -- Reshape the input so it always looks like a batch (avoids multiple
-      -- code-paths). (Clement's good suggestion)
-      input[1]:resize(1,input[1]:size(1))
-      input[2]:resize(1,input[2]:size(1))
-   end
-
-   self.diff:resizeAs(input[1])
-   self.diff:zero()
-   self.diff:add(input[1], -1, input[2])
-   self.diff:abs()
+      self.output[1]=input[1]:dist(input[2],self.norm)
+   elseif input[1]:dim() == 2 then
+      self.diff = self.diff or input[1].new()
+      self.diff:resizeAs(input[1])
-   self.output:resize(input[1]:size(1))
-   self.output:zero()
-   self.output:add(self.diff:pow(self.norm):sum(2))
-   self.output:pow(1./self.norm)
+      local diff = self.diff:zero()
+      diff:add(input[1], -1, input[2])
+      diff:abs()
+      self.output:resize(input[1]:size(1))
+      self.output:zero()
+      self.output:add(diff:pow(self.norm):sum(2))
+      self.output:pow(1./self.norm)
+   else
+      error('input must be vector or matrix')
+   end
+   if input[1]:dim() > 2 then
+      error('input must be vector or matrix')
+   end
+
    return self.output
 end
@@ -44,13 +45,6 @@ function PairwiseDistance:updateGradInput(input, gradOutput)
       error('input must be vector or matrix')
    end
-   if input[1]:dim() == 1 then
-      -- Reshape the input so it always looks like a batch (avoids multiple
-      -- code-paths). (Clement's good suggestion)
-      input[1]:resize(1,input[1]:size(1))
-      input[2]:resize(1,input[2]:size(1))
-   end
-
    self.gradInput[1]:resize(input[1]:size())
    self.gradInput[2]:resize(input[2]:size())
    self.gradInput[1]:copy(input[1])
@@ -65,24 +59,30 @@ function PairwiseDistance:updateGradInput(input, gradOutput)
          self.gradInput[1]:cmul(self.gradInput[1]:clone():abs():pow(self.norm-2))
       end
-      self.outExpand = self.outExpand or self.output.new()
-      self.outExpand:resize(self.output:size(1), 1)
-      self.outExpand:copy(self.output)
-      self.outExpand:add(1.0e-6) -- Prevent divide by zero errors
-      self.outExpand:pow(-(self.norm-1))
-      self.gradInput[1]:cmul(self.outExpand:expand(self.gradInput[1]:size(1),
-         self.gradInput[1]:size(2)))
+      if (input[1]:dim() > 1) then
+         self.outExpand = self.outExpand or self.output.new()
+         self.outExpand:resize(self.output:size(1), 1)
+         self.outExpand:copy(self.output)
+         self.outExpand:add(1.0e-6) -- Prevent divide by zero errors
+         self.outExpand:pow(-(self.norm-1))
+         self.gradInput[1]:cmul(self.outExpand:expand(self.gradInput[1]:size(1),
+            self.gradInput[1]:size(2)))
+      else
+         self.gradInput[1]:mul(math.pow(self.output[1] + 1e-6, -(self.norm-1)))
+      end
    end
-
-   self.grad = self.grad or gradOutput.new()
-   self.ones = self.ones or gradOutput.new()
+   if input[1]:dim() == 1 then
+      self.gradInput[1]:mul(gradOutput[1])
+   else
+      self.grad = self.grad or gradOutput.new()
+      self.ones = self.ones or gradOutput.new()
-   self.grad:resizeAs(input[1]):zero()
-   self.ones:resize(input[1]:size(2)):fill(1)
+      self.grad:resizeAs(input[1]):zero()
+      self.ones:resize(input[1]:size(2)):fill(1)
-   self.grad:addr(gradOutput, self.ones)
-   self.gradInput[1]:cmul(self.grad)
-
+      self.grad:addr(gradOutput, self.ones)
+      self.gradInput[1]:cmul(self.grad)
+   end
    self.gradInput[2]:zero():add(-1, self.gradInput[1])
    return self.gradInput
 end
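
As a rough sanity check on the rewritten gradient path, a finite-difference comparison for the batch case; this is only a sketch under the same assumptions as the earlier snippet (nn.PairwiseDistance from torch/nn; the sizes, index, and epsilon are arbitrary illustrative choices):

require 'nn'

local p, eps = 2, 1e-6
local pdist = nn.PairwiseDistance(p)
local a, b = torch.randn(4, 3), torch.randn(4, 3)
local gradOut = torch.randn(4)

pdist:forward({a, b})
local gradIn = pdist:backward({a, b}, gradOut)

-- Scalar objective: sum of the weighted pairwise distances.
local function loss()
   return pdist:forward({a, b}):clone():cmul(gradOut):sum()
end

-- Perturb a single entry of the first input and compare the numeric
-- derivative against the analytic gradient returned by backward().
local i, j = 2, 1
local orig = a[i][j]
a[i][j] = orig + eps; local fplus  = loss()
a[i][j] = orig - eps; local fminus = loss()
a[i][j] = orig
print((fplus - fminus) / (2 * eps), gradIn[1][i][j])   -- should agree to a few decimals

The small mismatch left over comes from the 1e-6 term the module adds to the distance before inverting it, which exists only to avoid division by zero.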