From fd76fcc7787eac37757e9b9ccfda67da1cc9920a Mon Sep 17 00:00:00 2001 From: koray kavukcuoglu Date: Wed, 16 Oct 2013 23:13:52 +0100 Subject: correct the stochastic case check --- MarginRankingCriterion.lua | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/MarginRankingCriterion.lua b/MarginRankingCriterion.lua index 5012c2a..30c6855 100644 --- a/MarginRankingCriterion.lua +++ b/MarginRankingCriterion.lua @@ -8,8 +8,8 @@ function MarginRankingCriterion:__init(margin) end function MarginRankingCriterion:updateOutput(input,y) - if type(input[1]) == "number" then - self.output=math.max(0, -y*(input[1]-input[2]) + self.margin ) + if input[1]:size(1) == 1 then + self.output=math.max(0, -y*(input[1][1]-input[2][1]) + self.margin ) else if type(self.output) == "number" then self.output = input[1]:clone() @@ -34,7 +34,7 @@ function MarginRankingCriterion:updateOutput(input,y) end function MarginRankingCriterion:updateGradInput(input, y) - if type(input[1]) == "number" then + if input[1]:size(1) == 1 then local dist = -y*(input[1][1]-input[2][1]) + self.margin if dist < 0 then self.gradInput[1][1]=0; -- cgit v1.2.3