diff options
author | Adria Puigdomenech <adriap@google.com> | 2015-10-16 17:02:32 +0300 |
---|---|---|
committer | Adria Puigdomenech <adriap@google.com> | 2015-10-16 17:02:32 +0300 |
commit | 2c3544204c73ae9d85c508207e67f3228bc2efe7 (patch) | |
tree | fca1ea26741444eb3b1684a0f09ffa4e08b06e0e /Normalize.lua | |
parent | 2963a98a7b01c539c7438f923edf4acf666bcc74 (diff) |
Fix modification of output in nn.Normalize
Diffstat (limited to 'Normalize.lua')
-rw-r--r-- | Normalize.lua | 35 |
1 files changed, 15 insertions, 20 deletions
diff --git a/Normalize.lua b/Normalize.lua index 41a8ef2..d3a17aa 100644 --- a/Normalize.lua +++ b/Normalize.lua @@ -6,17 +6,18 @@ function Normalize:__init(p,eps) assert(p > 0, p..'-norm not supported') self.p = p self.eps = eps or 1e-10 + self._output = torch.Tensor() + self._gradInput = torch.Tensor() end function Normalize:updateOutput(input) assert(input:dim() <= 2, 'only 1d layer supported') - local is_batch = true + local input_size = input:size() if input:dim() == 1 then input = input:view(1,-1) - is_batch = false end - self.output:resizeAs(input) + self._output:resizeAs(input) self.norm = self.norm or input.new() self.normp = self.normp or input.new() @@ -29,11 +30,9 @@ function Normalize:updateOutput(input) end self.normp:sum(self.buffer,2):add(self.eps) self.norm:pow(self.normp,1/self.p) - self.output:cdiv(input,self.norm:view(-1,1):expandAs(self.output)) + self._output:cdiv(input, self.norm:view(-1,1):expandAs(input)) - if not is_batch then - self.output = self.output[1] - end + self.output = self._output:view(input_size) return self.output end @@ -41,19 +40,18 @@ function Normalize:updateGradInput(input, gradOutput) assert(input:dim() <= 2, 'only 1d layer supported') assert(gradOutput:dim() <= 2, 'only 1d layer supported') - local is_batch = true + local input_size = input:size() if input:dim() == 1 then input = input:view(1,-1) - is_batch = false end local n = input:size(1) -- batch size local d = input:size(2) -- dimensionality of vectors -- compute diagonal term with gradOutput - self.gradInput:resize(n,d,1) + self._gradInput:resize(n,d,1) gradOutput = gradOutput:view(n,d,1) - self.gradInput:cmul(self.normp:view(n,1,1):expand(n,d,1),gradOutput) + self._gradInput:cmul(self.normp:view(n,1,1):expand(n,d,1), gradOutput) -- compute cross term in two steps self.cross = self.cross or input.new() @@ -65,19 +63,16 @@ function Normalize:updateGradInput(input, gradOutput) -- instead of having a huge temporary matrix (b1*b2), -- do the computations as b1*(b2*gradOutput). This avoids redundant -- computation and also a huge buffer of size n*d^2 - self.cross:bmm(b2,gradOutput) - self.gradInput:baddbmm(-1,b1, self.cross) + self.cross:bmm(b2, gradOutput) + self._gradInput:baddbmm(-1, b1, self.cross) -- reuse cross buffer for normalization - self.cross:cmul(self.normp,self.norm) - self.gradInput:cdiv(self.cross:view(n,1,1):expand(n,d,1)) + self.cross:cmul(self.normp, self.norm) + self._gradInput:cdiv(self.cross:view(n,1,1):expand(n,d,1)) - self.gradInput = self.gradInput:view(n,d) + self._gradInput = self._gradInput:view(n,d) - if not is_batch then - self.gradInput = self.gradInput[1] - end - + self.gradInput = self._gradInput:view(input_size) return self.gradInput end |