Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorkoray kavukcuoglu <koray@kavukcuoglu.org>2013-10-17 02:13:52 +0400
committerkoray kavukcuoglu <koray@kavukcuoglu.org>2013-10-17 02:13:52 +0400
commitfd76fcc7787eac37757e9b9ccfda67da1cc9920a (patch)
tree74d31f66f82574e8f268fc9396255f7f4e0a49bf
parent3dcae303d2eaa9c11682102c2c1aaa318e31dd7f (diff)
correct the stochastic case check
-rw-r--r--MarginRankingCriterion.lua6
1 files changed, 3 insertions, 3 deletions
diff --git a/MarginRankingCriterion.lua b/MarginRankingCriterion.lua
index 5012c2a..30c6855 100644
--- a/MarginRankingCriterion.lua
+++ b/MarginRankingCriterion.lua
@@ -8,8 +8,8 @@ function MarginRankingCriterion:__init(margin)
end
function MarginRankingCriterion:updateOutput(input,y)
- if type(input[1]) == "number" then
- self.output=math.max(0, -y*(input[1]-input[2]) + self.margin )
+ if input[1]:size(1) == 1 then
+ self.output=math.max(0, -y*(input[1][1]-input[2][1]) + self.margin )
else
if type(self.output) == "number" then
self.output = input[1]:clone()
@@ -34,7 +34,7 @@ function MarginRankingCriterion:updateOutput(input,y)
end
function MarginRankingCriterion:updateGradInput(input, y)
- if type(input[1]) == "number" then
+ if input[1]:size(1) == 1 then
local dist = -y*(input[1][1]-input[2][1]) + self.margin
if dist < 0 then
self.gradInput[1][1]=0;