diff options
author | Andreas Köpf <andreas.koepf@xamla.com> | 2016-01-25 02:50:43 +0300 |
---|---|---|
committer | Andreas Köpf <andreas.koepf@xamla.com> | 2016-02-01 21:42:34 +0300 |
commit | 66f995fd75234fb2590109263592012a14956d09 (patch) | |
tree | e95c3b75d3fdef990b614f38419d63e119832805 /MultiMarginCriterion.lua | |
parent | cadef1fafaa4246d96a976dc409b06c7f57a3627 (diff) |
Add gradWeightBuf & gradWeightBuf2 params to PReLU_accGradParameters
Diffstat (limited to 'MultiMarginCriterion.lua')
-rw-r--r-- | MultiMarginCriterion.lua | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/MultiMarginCriterion.lua b/MultiMarginCriterion.lua index 1506c1e..7ea78ca 100644 --- a/MultiMarginCriterion.lua +++ b/MultiMarginCriterion.lua @@ -9,16 +9,16 @@ end function MultiMarginCriterion:updateOutput(input, target) -- backward compatibility - local _target = target - if not torch.isTensor(_target) then - _target = input.new(1) - _target[1] = target + if not torch.isTensor(target) then + self.target_tensor = self.target_tensor or input.new(1) + self.target_tensor[1] = target + target = self.target_tensor end self.p = self.p or 1 self.output_tensor = self.output_tensor or input.new(1) input.THNN.MultiMarginCriterion_updateOutput( input:cdata(), - _target:cdata(), + target:cdata(), self.output_tensor:cdata(), self.sizeAverage, self.p @@ -28,14 +28,14 @@ function MultiMarginCriterion:updateOutput(input, target) end function MultiMarginCriterion:updateGradInput(input, target) - local _target = target - if not torch.isTensor(_target) then - _target = input.new(1) - _target[1] = target + if not torch.isTensor(target) then + self.target_tensor = self.target_tensor or input.new(1) + self.target_tensor[1] = target + target = self.target_tensor end input.THNN.MultiMarginCriterion_updateGradInput( input:cdata(), - _target:cdata(), + target:cdata(), self.gradInput:cdata(), self.sizeAverage, self.p |