Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndreas Köpf <andreas.koepf@xamla.com>2016-01-25 02:50:43 +0300
committerAndreas Köpf <andreas.koepf@xamla.com>2016-02-01 21:42:34 +0300
commit66f995fd75234fb2590109263592012a14956d09 (patch)
treee95c3b75d3fdef990b614f38419d63e119832805 /MultiMarginCriterion.lua
parentcadef1fafaa4246d96a976dc409b06c7f57a3627 (diff)
Add gradWeightBuf & gradWeightBuf2 params to PReLU_accGradParameters
Diffstat (limited to 'MultiMarginCriterion.lua')
-rw-r--r--MultiMarginCriterion.lua20
1 files changed, 10 insertions, 10 deletions
diff --git a/MultiMarginCriterion.lua b/MultiMarginCriterion.lua
index 1506c1e..7ea78ca 100644
--- a/MultiMarginCriterion.lua
+++ b/MultiMarginCriterion.lua
@@ -9,16 +9,16 @@ end
function MultiMarginCriterion:updateOutput(input, target)
-- backward compatibility
- local _target = target
- if not torch.isTensor(_target) then
- _target = input.new(1)
- _target[1] = target
+ if not torch.isTensor(target) then
+ self.target_tensor = self.target_tensor or input.new(1)
+ self.target_tensor[1] = target
+ target = self.target_tensor
end
self.p = self.p or 1
self.output_tensor = self.output_tensor or input.new(1)
input.THNN.MultiMarginCriterion_updateOutput(
input:cdata(),
- _target:cdata(),
+ target:cdata(),
self.output_tensor:cdata(),
self.sizeAverage,
self.p
@@ -28,14 +28,14 @@ function MultiMarginCriterion:updateOutput(input, target)
end
function MultiMarginCriterion:updateGradInput(input, target)
- local _target = target
- if not torch.isTensor(_target) then
- _target = input.new(1)
- _target[1] = target
+ if not torch.isTensor(target) then
+ self.target_tensor = self.target_tensor or input.new(1)
+ self.target_tensor[1] = target
+ target = self.target_tensor
end
input.THNN.MultiMarginCriterion_updateGradInput(
input:cdata(),
- _target:cdata(),
+ target:cdata(),
self.gradInput:cdata(),
self.sizeAverage,
self.p