Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorsoumith <soumith@fb.com>2016-02-20 04:40:57 +0300
committersoumith <soumith@fb.com>2016-02-20 04:44:51 +0300
commitd1f57af38674a246a2e7677def0c63e59e97e61a (patch)
tree725d74d381a4cfe58c2d072441caec6215d0368c /MultiMarginCriterion.lua
parentcc8e874a8858c7f5e13c2e1b2a26e5e49ae9f91d (diff)
adding weights to MultiMarginCriterion
Diffstat (limited to 'MultiMarginCriterion.lua')
-rw-r--r--MultiMarginCriterion.lua13
1 files changed, 10 insertions, 3 deletions
diff --git a/MultiMarginCriterion.lua b/MultiMarginCriterion.lua
index 7ea78ca..3f480cc 100644
--- a/MultiMarginCriterion.lua
+++ b/MultiMarginCriterion.lua
@@ -1,10 +1,15 @@
+local THNN = require 'nn.THNN'
local MultiMarginCriterion, parent = torch.class('nn.MultiMarginCriterion', 'nn.Criterion')
-function MultiMarginCriterion:__init(p)
+function MultiMarginCriterion:__init(p, weights)
assert(p == nil or p == 1 or p == 2, 'only p=1 and p=2 supported')
self.p = p or 1
parent.__init(self)
self.sizeAverage = true
+ if weights then
+ assert(weights:dim() == 1, "weights input should be 1-D Tensor")
+ self.weights = weights
+ end
end
function MultiMarginCriterion:updateOutput(input, target)
@@ -21,7 +26,8 @@ function MultiMarginCriterion:updateOutput(input, target)
target:cdata(),
self.output_tensor:cdata(),
self.sizeAverage,
- self.p
+ self.p,
+ THNN.optionalTensor(self.weights)
)
self.output = self.output_tensor[1]
return self.output
@@ -38,7 +44,8 @@ function MultiMarginCriterion:updateGradInput(input, target)
target:cdata(),
self.gradInput:cdata(),
self.sizeAverage,
- self.p
+ self.p,
+ THNN.optionalTensor(self.weights)
)
return self.gradInput
end