Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2013-10-16 20:30:54 +0400
committerClement Farabet <clement.farabet@gmail.com>2013-10-16 20:30:54 +0400
commit3dcae303d2eaa9c11682102c2c1aaa318e31dd7f (patch)
tree6ef46b1bf38b6d2355455da8923c9994a532df04
parentb994633771dd31d2597307bf85068176662f90ac (diff)
parent4b3a22b5fc63bce5d7db6d53a587777d7c27ea97 (diff)
Merge pull request #163 from vladmnih/batch
Added minibatch support to MarginRankingCriterion and PairwiseDistance.
-rw-r--r--MarginRankingCriterion.lua66
-rw-r--r--PairwiseDistance.lua52
2 files changed, 97 insertions, 21 deletions
diff --git a/MarginRankingCriterion.lua b/MarginRankingCriterion.lua
index ec85fb9..5012c2a 100644
--- a/MarginRankingCriterion.lua
+++ b/MarginRankingCriterion.lua
@@ -8,18 +8,64 @@ function MarginRankingCriterion:__init(margin)
end
function MarginRankingCriterion:updateOutput(input,y)
- self.output=math.max(0, -y*(input[1][1]-input[2][1]) + self.margin )
+ if type(input[1]) == "number" then
+ self.output=math.max(0, -y*(input[1]-input[2]) + self.margin )
+ else
+ if type(self.output) == "number" then
+ self.output = input[1]:clone()
+ end
+ self.output = self.output or input[1]:clone()
+ self.output:resizeAs(input[1])
+ self.output:copy(input[1])
+
+ self.output:add(-1, input[2])
+ self.output:mul(-y)
+ self.output:add(self.margin)
+
+ self.mask = self.mask or self.output:clone()
+ self.mask:resizeAs(self.output)
+ self.mask:copy(self.output)
+
+ self.mask:ge(self.output, 0.0)
+ self.output:cmul(self.mask)
+ end
+
return self.output
end
function MarginRankingCriterion:updateGradInput(input, y)
- local dist = -y*(input[1][1]-input[2][1]) + self.margin
- if dist < 0 then
- self.gradInput[1][1]=0;
- self.gradInput[2][1]=0;
- else
- self.gradInput[1][1]=-y
- self.gradInput[2][1]=y
- end
- return self.gradInput
+ if type(input[1]) == "number" then
+ local dist = -y*(input[1][1]-input[2][1]) + self.margin
+ if dist < 0 then
+ self.gradInput[1][1]=0;
+ self.gradInput[2][1]=0;
+ else
+ self.gradInput[1][1]=-y
+ self.gradInput[2][1]=y
+ end
+ else
+ self.dist = self.dist or input[1].new()
+ self.dist = self.dist:resizeAs(input[1]):copy(input[1])
+ local dist = self.dist
+
+ dist:add(-1, input[2])
+ dist:mul(-y)
+ dist:add(self.margin)
+
+ self.mask = self.mask or input[1].new()
+ self.mask = self.mask:resizeAs(input[1]):copy(dist)
+ local mask = self.mask
+
+ mask:ge(dist, 0)
+
+ self.gradInput[1]:resize(dist:size())
+ self.gradInput[2]:resize(dist:size())
+
+ self.gradInput[1]:copy(mask)
+ self.gradInput[1]:mul(-y)
+ self.gradInput[2]:copy(mask)
+ self.gradInput[2]:mul(y)
+
+ end
+ return self.gradInput
end
diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua
index 638c58f..f108b97 100644
--- a/PairwiseDistance.lua
+++ b/PairwiseDistance.lua
@@ -5,12 +5,29 @@ function PairwiseDistance:__init(p)
-- state
self.gradInput = {torch.Tensor(), torch.Tensor()}
- self.output = torch.Tensor(1)
+ self.output = torch.Tensor()
self.norm=p
end
function PairwiseDistance:updateOutput(input)
- self.output[1]=input[1]:dist(input[2],self.norm);
+ if input[1]:dim() == 1 then
+ self.output[1]=input[1]:dist(input[2],self.norm)
+ elseif input[1]:dim() == 2 then
+ self.diff = self.diff or input[1].new()
+ self.diff:resizeAs(input[1])
+
+ local diff = self.diff:zero()
+ --local diff = torch.add(input[1], -1, input[2])
+ diff:add(input[1], -1, input[2])
+
+ self.output:resize(input[1]:size(1))
+ self.output:zero()
+ self.output:add(diff:pow(self.norm):sum(2))
+ self.output:pow(1./self.norm)
+ else
+ error('input must be vector or matrix')
+ end
+
return self.output
end
@@ -20,14 +37,27 @@ local function mathsign(x)
end
function PairwiseDistance:updateGradInput(input, gradOutput)
- self.gradInput[1]:resizeAs(input[1])
- self.gradInput[2]:resizeAs(input[2])
- self.gradInput[1]:copy(input[1])
- self.gradInput[1]:add(-1, input[2])
- if self.norm==1 then
+ self.gradInput[1]:resize(input[1]:size())
+ self.gradInput[2]:resize(input[2]:size())
+ self.gradInput[1]:copy(input[1])
+ self.gradInput[1]:add(-1, input[2])
+ if self.norm==1 then
self.gradInput[1]:apply(mathsign)
- end
- self.gradInput[1]:mul(gradOutput[1]);
- self.gradInput[2]:zero():add(-1, self.gradInput[1])
- return self.gradInput
+ end
+ if input[1]:dim() == 1 then
+ self.gradInput[1]:mul(gradOutput[1])
+ elseif input[1]:dim() == 2 then
+ self.grad = self.grad or gradOutput.new()
+ self.ones = self.ones or gradOutput.new()
+
+ self.grad:resizeAs(input[1]):zero()
+ self.ones:resize(input[1]:size(2)):fill(1)
+
+ self.grad:addr(gradOutput, self.ones)
+ self.gradInput[1]:cmul(self.grad)
+ else
+ error('input must be vector or matrix')
+ end
+ self.gradInput[2]:zero():add(-1, self.gradInput[1])
+ return self.gradInput
end