Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'PairwiseDistance.lua')
-rw-r--r--PairwiseDistance.lua33
1 files changed, 33 insertions, 0 deletions
diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua
new file mode 100644
index 0000000..638c58f
--- /dev/null
+++ b/PairwiseDistance.lua
@@ -0,0 +1,33 @@
+local PairwiseDistance, parent = torch.class('nn.PairwiseDistance', 'nn.Module')
+
+function PairwiseDistance:__init(p)
+ parent.__init(self)
+
+ -- state
+ self.gradInput = {torch.Tensor(), torch.Tensor()}
+ self.output = torch.Tensor(1)
+ self.norm=p
+end
+
+function PairwiseDistance:updateOutput(input)
+ self.output[1]=input[1]:dist(input[2],self.norm);
+ return self.output
+end
+
+local function mathsign(x)
+ if x==0 then return 2*torch.random(2)-3; end
+ if x>0 then return 1; else return -1; end
+end
+
+function PairwiseDistance:updateGradInput(input, gradOutput)
+ self.gradInput[1]:resizeAs(input[1])
+ self.gradInput[2]:resizeAs(input[2])
+ self.gradInput[1]:copy(input[1])
+ self.gradInput[1]:add(-1, input[2])
+ if self.norm==1 then
+ self.gradInput[1]:apply(mathsign)
+ end
+ self.gradInput[1]:mul(gradOutput[1]);
+ self.gradInput[2]:zero():add(-1, self.gradInput[1])
+ return self.gradInput
+end