diff options
author | Francisco Massa <fvsmassa@gmail.com> | 2015-10-24 23:24:52 +0300 |
---|---|---|
committer | Francisco Massa <fvsmassa@gmail.com> | 2015-10-24 23:24:52 +0300 |
commit | 08195f320495a50b480628901fd96071fca8d18f (patch) | |
tree | ca888a86a1b3667a8482b92233d13b8b97ee557d /Normalize.lua | |
parent | d1d20dbf94c87e12d343acbac8ae1f67997a607e (diff) |
Fix possible division by zero in Normalize
When p < 2, if the input is 0 there was division by zero. Add a small dampening factor to avoid this. Also, small optimizations for different p.
Diffstat (limited to 'Normalize.lua')
-rw-r--r-- | Normalize.lua | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/Normalize.lua b/Normalize.lua index 304caa0..58dc11d 100644 --- a/Normalize.lua +++ b/Normalize.lua @@ -53,11 +53,28 @@ function Normalize:updateGradInput(input, gradOutput) gradOutput = gradOutput:view(n,d,1) self._gradInput:cmul(self.normp:view(n,1,1):expand(n,d,1), gradOutput) + -- small optimizations for different p + -- buffer = input*|input|^(p-2) + if self.p % 2 ~= 0 then + -- for non-even p, need to add absolute value + if self.p < 2 then + -- add eps to avoid possible division by 0 + self.buffer:abs(input):add(self.eps):pow(self.p-2):cmul(input) + else + self.buffer:abs(input):pow(self.p-2):cmul(input) + end + elseif self.p == 2 then + -- special case for p == 2, pow(x,0) = 1 + self.buffer:copy(input) + else + -- p is even and > 2, pow(x,p) is always positive + self.buffer:pow(input,self.p-2):cmul(input) + end + -- compute cross term in two steps self.cross = self.cross or input.new() self.cross:resize(n,1,1) - self.buffer:abs(input):pow(self.p-2):cmul(input) local b1 = self.buffer:view(n,d,1) local b2 = input:view(n,1,d) -- instead of having a huge temporary matrix (b1*b2), |