github.com/torch/nn.git

author     Adria Puigdomenech <adriap@google.com>  2015-10-16 17:02:32 +0300
committer  Adria Puigdomenech <adriap@google.com>  2015-10-16 17:02:32 +0300
commit     2c3544204c73ae9d85c508207e67f3228bc2efe7 (patch)
tree       fca1ea26741444eb3b1684a0f09ffa4e08b06e0e /Normalize.lua
parent     2963a98a7b01c539c7438f923edf4acf666bcc74 (diff)
Fix modification of output in nn.Normalize
Diffstat (limited to 'Normalize.lua')
-rw-r--r--  Normalize.lua  35
1 file changed, 15 insertions, 20 deletions
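
For context: nn.Normalize L_p-normalizes each input vector, so for a row x the forward pass below computes y = x / (sum_j |x_j|^p + eps)^(1/p). The patch stops the module from replacing self.output and self.gradInput with narrowed slices of its own buffers for 1-D inputs; results are now written to the private buffers self._output / self._gradInput and exposed as views with the original input size. A minimal usage sketch (tensor sizes and variable names are illustrative; assumes a standard torch/nn install):

require 'nn'

local m = nn.Normalize(2)            -- p = 2, eps defaults to 1e-10
local x = torch.randn(4, 8)          -- batch of 4 vectors of dimension 8
local y = m:forward(x)               -- each row divided by (sum_j x_j^2 + eps)^(1/2)
print(y:norm(2, 2))                  -- row norms are ~1

-- with this patch a 1-D input still comes back 1-D, but as a view of the
-- internal buffer self._output rather than a narrowed slice of self.output
local v = torch.randn(8)
print(m:forward(v):size())           -- 8
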
diff --git a/Normalize.lua b/Normalize.lua
index 41a8ef2..d3a17aa 100644
--- a/Normalize.lua
+++ b/Normalize.lua
@@ -6,17 +6,18 @@ function Normalize:__init(p,eps)
assert(p > 0, p..'-norm not supported')
self.p = p
self.eps = eps or 1e-10
+ self._output = torch.Tensor()
+ self._gradInput = torch.Tensor()
end
function Normalize:updateOutput(input)
assert(input:dim() <= 2, 'only 1d layer supported')
- local is_batch = true
+ local input_size = input:size()
if input:dim() == 1 then
input = input:view(1,-1)
- is_batch = false
end
- self.output:resizeAs(input)
+ self._output:resizeAs(input)
self.norm = self.norm or input.new()
self.normp = self.normp or input.new()
@@ -29,11 +30,9 @@ function Normalize:updateOutput(input)
end
self.normp:sum(self.buffer,2):add(self.eps)
self.norm:pow(self.normp,1/self.p)
- self.output:cdiv(input,self.norm:view(-1,1):expandAs(self.output))
+ self._output:cdiv(input, self.norm:view(-1,1):expandAs(input))
- if not is_batch then
- self.output = self.output[1]
- end
+ self.output = self._output:view(input_size)
return self.output
end
@@ -41,19 +40,18 @@ function Normalize:updateGradInput(input, gradOutput)
assert(input:dim() <= 2, 'only 1d layer supported')
assert(gradOutput:dim() <= 2, 'only 1d layer supported')
- local is_batch = true
+ local input_size = input:size()
if input:dim() == 1 then
input = input:view(1,-1)
- is_batch = false
end
local n = input:size(1) -- batch size
local d = input:size(2) -- dimensionality of vectors
-- compute diagonal term with gradOutput
- self.gradInput:resize(n,d,1)
+ self._gradInput:resize(n,d,1)
gradOutput = gradOutput:view(n,d,1)
- self.gradInput:cmul(self.normp:view(n,1,1):expand(n,d,1),gradOutput)
+ self._gradInput:cmul(self.normp:view(n,1,1):expand(n,d,1), gradOutput)
-- compute cross term in two steps
self.cross = self.cross or input.new()
@@ -65,19 +63,16 @@ function Normalize:updateGradInput(input, gradOutput)
-- instead of having a huge temporary matrix (b1*b2),
-- do the computations as b1*(b2*gradOutput). This avoids redundant
-- computation and also a huge buffer of size n*d^2
- self.cross:bmm(b2,gradOutput)
- self.gradInput:baddbmm(-1,b1, self.cross)
+ self.cross:bmm(b2, gradOutput)
+ self._gradInput:baddbmm(-1, b1, self.cross)
-- reuse cross buffer for normalization
- self.cross:cmul(self.normp,self.norm)
- self.gradInput:cdiv(self.cross:view(n,1,1):expand(n,d,1))
+ self.cross:cmul(self.normp, self.norm)
+ self._gradInput:cdiv(self.cross:view(n,1,1):expand(n,d,1))
- self.gradInput = self.gradInput:view(n,d)
+ self._gradInput = self._gradInput:view(n,d)
- if not is_batch then
- self.gradInput = self.gradInput[1]
- end
-
+ self.gradInput = self._gradInput:view(input_size)
return self.gradInput
end
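
The gradient hunks keep the existing memory-saving trick noted in the comment above: the cross term of the Jacobian is computed as b1*(b2*gradOutput) instead of materializing the n x d x d product (b1*b2). For p = 2 this amounts to gradInput = (normp * gradOutput - x * (x . gradOutput)) / (normp * norm), with normp = sum_j x_j^2 + eps and norm = sqrt(normp). A small single-vector sketch of that identity (p = 2 only; x, go, eps are illustrative names, not the module's internals):

require 'torch'

local x   = torch.randn(8)
local go  = torch.randn(8)                 -- gradient w.r.t. the normalized vector
local eps = 1e-10

local normp = x:dot(x) + eps               -- sum of squares, as in updateOutput
local norm  = math.sqrt(normp)

-- grouping used by the module: x * (x . go), an O(d) computation
local cheap = (go * normp - x * x:dot(go)) / (normp * norm)

-- explicit grouping with the d x d outer product that the comment avoids
local outer = torch.ger(x, x)
local full  = (go * normp - outer * go) / (normp * norm)

print((cheap - full):abs():max())          -- ~0 up to floating-point error
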