Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrancisco Massa <fvsmassa@gmail.com>2015-10-24 23:24:52 +0300
committerFrancisco Massa <fvsmassa@gmail.com>2015-10-24 23:24:52 +0300
commit08195f320495a50b480628901fd96071fca8d18f (patch)
treeca888a86a1b3667a8482b92233d13b8b97ee557d /Normalize.lua
parentd1d20dbf94c87e12d343acbac8ae1f67997a607e (diff)
Fix possible division by zero in Normalize
When p < 2, if the input is 0 there was division by zero. Add a small dampening factor to avoid this. Also, small optimizations for different p.
Diffstat (limited to 'Normalize.lua')
-rw-r--r--Normalize.lua19
1 files changed, 18 insertions, 1 deletions
diff --git a/Normalize.lua b/Normalize.lua
index 304caa0..58dc11d 100644
--- a/Normalize.lua
+++ b/Normalize.lua
@@ -53,11 +53,28 @@ function Normalize:updateGradInput(input, gradOutput)
gradOutput = gradOutput:view(n,d,1)
self._gradInput:cmul(self.normp:view(n,1,1):expand(n,d,1), gradOutput)
+ -- small optimizations for different p
+ -- buffer = input*|input|^(p-2)
+ if self.p % 2 ~= 0 then
+ -- for non-even p, need to add absolute value
+ if self.p < 2 then
+ -- add eps to avoid possible division by 0
+ self.buffer:abs(input):add(self.eps):pow(self.p-2):cmul(input)
+ else
+ self.buffer:abs(input):pow(self.p-2):cmul(input)
+ end
+ elseif self.p == 2 then
+ -- special case for p == 2, pow(x,0) = 1
+ self.buffer:copy(input)
+ else
+ -- p is even and > 2, pow(x,p) is always positive
+ self.buffer:pow(input,self.p-2):cmul(input)
+ end
+
-- compute cross term in two steps
self.cross = self.cross or input.new()
self.cross:resize(n,1,1)
- self.buffer:abs(input):pow(self.p-2):cmul(input)
local b1 = self.buffer:view(n,d,1)
local b2 = input:view(n,1,d)
-- instead of having a huge temporary matrix (b1*b2),