From 4b3a22b5fc63bce5d7db6d53a587777d7c27ea97 Mon Sep 17 00:00:00 2001 From: Volodymyr Mnih Date: Fri, 21 Jun 2013 11:36:31 +0100 Subject: Added minibatch support to MarginRankingCriterion and PairwiseDistance. --- MarginRankingCriterion.lua | 66 +++++++++++++++++++++++++++++++++++++++------- PairwiseDistance.lua | 52 ++++++++++++++++++++++++++++-------- 2 files changed, 97 insertions(+), 21 deletions(-) diff --git a/MarginRankingCriterion.lua b/MarginRankingCriterion.lua index ec85fb9..5012c2a 100644 --- a/MarginRankingCriterion.lua +++ b/MarginRankingCriterion.lua @@ -8,18 +8,64 @@ function MarginRankingCriterion:__init(margin) end function MarginRankingCriterion:updateOutput(input,y) - self.output=math.max(0, -y*(input[1][1]-input[2][1]) + self.margin ) + if type(input[1]) == "number" then + self.output=math.max(0, -y*(input[1]-input[2]) + self.margin ) + else + if type(self.output) == "number" then + self.output = input[1]:clone() + end + self.output = self.output or input[1]:clone() + self.output:resizeAs(input[1]) + self.output:copy(input[1]) + + self.output:add(-1, input[2]) + self.output:mul(-y) + self.output:add(self.margin) + + self.mask = self.mask or self.output:clone() + self.mask:resizeAs(self.output) + self.mask:copy(self.output) + + self.mask:ge(self.output, 0.0) + self.output:cmul(self.mask) + end + return self.output end function MarginRankingCriterion:updateGradInput(input, y) - local dist = -y*(input[1][1]-input[2][1]) + self.margin - if dist < 0 then - self.gradInput[1][1]=0; - self.gradInput[2][1]=0; - else - self.gradInput[1][1]=-y - self.gradInput[2][1]=y - end - return self.gradInput + if type(input[1]) == "number" then + local dist = -y*(input[1][1]-input[2][1]) + self.margin + if dist < 0 then + self.gradInput[1][1]=0; + self.gradInput[2][1]=0; + else + self.gradInput[1][1]=-y + self.gradInput[2][1]=y + end + else + self.dist = self.dist or input[1].new() + self.dist = self.dist:resizeAs(input[1]):copy(input[1]) + local dist = self.dist + + dist:add(-1, input[2]) + dist:mul(-y) + dist:add(self.margin) + + self.mask = self.mask or input[1].new() + self.mask = self.mask:resizeAs(input[1]):copy(dist) + local mask = self.mask + + mask:ge(dist, 0) + + self.gradInput[1]:resize(dist:size()) + self.gradInput[2]:resize(dist:size()) + + self.gradInput[1]:copy(mask) + self.gradInput[1]:mul(-y) + self.gradInput[2]:copy(mask) + self.gradInput[2]:mul(y) + + end + return self.gradInput end diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua index 638c58f..f108b97 100644 --- a/PairwiseDistance.lua +++ b/PairwiseDistance.lua @@ -5,12 +5,29 @@ function PairwiseDistance:__init(p) -- state self.gradInput = {torch.Tensor(), torch.Tensor()} - self.output = torch.Tensor(1) + self.output = torch.Tensor() self.norm=p end function PairwiseDistance:updateOutput(input) - self.output[1]=input[1]:dist(input[2],self.norm); + if input[1]:dim() == 1 then + self.output[1]=input[1]:dist(input[2],self.norm) + elseif input[1]:dim() == 2 then + self.diff = self.diff or input[1].new() + self.diff:resizeAs(input[1]) + + local diff = self.diff:zero() + --local diff = torch.add(input[1], -1, input[2]) + diff:add(input[1], -1, input[2]) + + self.output:resize(input[1]:size(1)) + self.output:zero() + self.output:add(diff:pow(self.norm):sum(2)) + self.output:pow(1./self.norm) + else + error('input must be vector or matrix') + end + return self.output end @@ -20,14 +37,27 @@ local function mathsign(x) end function PairwiseDistance:updateGradInput(input, gradOutput) - self.gradInput[1]:resizeAs(input[1]) - self.gradInput[2]:resizeAs(input[2]) - self.gradInput[1]:copy(input[1]) - self.gradInput[1]:add(-1, input[2]) - if self.norm==1 then + self.gradInput[1]:resize(input[1]:size()) + self.gradInput[2]:resize(input[2]:size()) + self.gradInput[1]:copy(input[1]) + self.gradInput[1]:add(-1, input[2]) + if self.norm==1 then self.gradInput[1]:apply(mathsign) - end - self.gradInput[1]:mul(gradOutput[1]); - self.gradInput[2]:zero():add(-1, self.gradInput[1]) - return self.gradInput + end + if input[1]:dim() == 1 then + self.gradInput[1]:mul(gradOutput[1]) + elseif input[1]:dim() == 2 then + self.grad = self.grad or gradOutput.new() + self.ones = self.ones or gradOutput.new() + + self.grad:resizeAs(input[1]):zero() + self.ones:resize(input[1]:size(2)):fill(1) + + self.grad:addr(gradOutput, self.ones) + self.gradInput[1]:cmul(self.grad) + else + error('input must be vector or matrix') + end + self.gradInput[2]:zero():add(-1, self.gradInput[1]) + return self.gradInput end -- cgit v1.2.3