author    | Soumith Chintala <soumith@gmail.com> | 2015-10-20 18:35:26 +0300
committer | Soumith Chintala <soumith@gmail.com> | 2015-10-20 18:35:26 +0300
commit    | f4be0fd1b0daa8680fd8215921a970c20dc11f32 (patch)
tree      | 90e81c7789b3656b17e2d8c5cd410aff1c7f648b /Normalize.lua
parent    | d3da8bad4dd7a28191fd2542392235276187a20e (diff)
Revert "Fix modification of output in nn.Normalize"
Diffstat (limited to 'Normalize.lua')
-rw-r--r-- | Normalize.lua | 35
1 file changed, 20 insertions, 15 deletions
diff --git a/Normalize.lua b/Normalize.lua
index d3a17aa..41a8ef2 100644
--- a/Normalize.lua
+++ b/Normalize.lua
@@ -6,18 +6,17 @@ function Normalize:__init(p,eps)
    assert(p > 0, p..'-norm not supported')
    self.p = p
    self.eps = eps or 1e-10
-   self._output = torch.Tensor()
-   self._gradInput = torch.Tensor()
 end
 
 function Normalize:updateOutput(input)
    assert(input:dim() <= 2, 'only 1d layer supported')
-   local input_size = input:size()
+   local is_batch = true
    if input:dim() == 1 then
       input = input:view(1,-1)
+      is_batch = false
    end
 
-   self._output:resizeAs(input)
+   self.output:resizeAs(input)
 
    self.norm = self.norm or input.new()
    self.normp = self.normp or input.new()
@@ -30,9 +29,11 @@ function Normalize:updateOutput(input)
    end
    self.normp:sum(self.buffer,2):add(self.eps)
    self.norm:pow(self.normp,1/self.p)
-   self._output:cdiv(input, self.norm:view(-1,1):expandAs(input))
+   self.output:cdiv(input,self.norm:view(-1,1):expandAs(self.output))
 
-   self.output = self._output:view(input_size)
+   if not is_batch then
+      self.output = self.output[1]
+   end
 
    return self.output
 end
@@ -40,18 +41,19 @@ function Normalize:updateGradInput(input, gradOutput)
    assert(input:dim() <= 2, 'only 1d layer supported')
    assert(gradOutput:dim() <= 2, 'only 1d layer supported')
 
-   local input_size = input:size()
+   local is_batch = true
    if input:dim() == 1 then
       input = input:view(1,-1)
+      is_batch = false
    end
 
    local n = input:size(1) -- batch size
    local d = input:size(2) -- dimensionality of vectors
 
    -- compute diagonal term with gradOutput
-   self._gradInput:resize(n,d,1)
+   self.gradInput:resize(n,d,1)
    gradOutput = gradOutput:view(n,d,1)
-   self._gradInput:cmul(self.normp:view(n,1,1):expand(n,d,1), gradOutput)
+   self.gradInput:cmul(self.normp:view(n,1,1):expand(n,d,1),gradOutput)
 
    -- compute cross term in two steps
    self.cross = self.cross or input.new()
@@ -63,16 +65,19 @@ function Normalize:updateGradInput(input, gradOutput)
    -- instead of having a huge temporary matrix (b1*b2),
    -- do the computations as b1*(b2*gradOutput). This avoids redundant
    -- computation and also a huge buffer of size n*d^2
-   self.cross:bmm(b2, gradOutput)
-   self._gradInput:baddbmm(-1, b1, self.cross)
+   self.cross:bmm(b2,gradOutput)
+   self.gradInput:baddbmm(-1,b1, self.cross)
 
    -- reuse cross buffer for normalization
-   self.cross:cmul(self.normp, self.norm)
-   self._gradInput:cdiv(self.cross:view(n,1,1):expand(n,d,1))
+   self.cross:cmul(self.normp,self.norm)
+   self.gradInput:cdiv(self.cross:view(n,1,1):expand(n,d,1))
 
-   self._gradInput = self._gradInput:view(n,d)
+   self.gradInput = self.gradInput:view(n,d)
 
-   self.gradInput = self._gradInput:view(input_size)
+   if not is_batch then
+      self.gradInput = self.gradInput[1]
+   end
+
    return self.gradInput
 end
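The in-code comments in updateGradInput describe the backward pass as a diagonal term plus a cross term, scaled by normp*norm. Written out as a worked equation, this is a sketch under the assumption that the layer computes y = x / ||x||_p per input vector, with self.norm holding N (the eps-smoothed p-norm) and self.normp holding N^p; the buffers b1 and b2 are defined outside the hunks shown and are assumed here to hold |x|^(p-2) ⊙ x and x respectively:

```latex
% Sketch only: per input vector x, with N = self.norm and N^p = self.normp,
% the layer computes y_i = x_i / N.  Differentiating and applying the chain rule:
\[
  \frac{\partial y_i}{\partial x_j}
    = \frac{\delta_{ij}}{N} - \frac{x_i \, |x_j|^{p-2} x_j}{N^{p+1}},
\qquad
  \frac{\partial L}{\partial x_j}
    = \frac{N^{p} \, \frac{\partial L}{\partial y_j}
            - |x_j|^{p-2} x_j \sum_i x_i \, \frac{\partial L}{\partial y_i}}
           {N^{p+1}}.
\]
```

Evaluating the cross term as b1*(b2*gradOutput) first reduces b2*gradOutput to a single inner product per batch element, which is what the in-code comment means by avoiding the n*d^2 buffer that the outer product b1*b2 would require.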
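For reference, a minimal usage sketch of the module this commit touches (assuming the torch/nn package is installed; tensor sizes are arbitrary). With the reverted code the module writes directly into self.output and self.gradInput, so both buffers are reused across calls:

```lua
require 'nn'

-- Normalize each input vector to unit L2 norm (p = 2).
local m = nn.Normalize(2)

-- Single vector (1D input): output keeps the input's shape.
local x = torch.randn(5)
local y = m:forward(x)
print(y:norm())           -- ~1

-- Mini-batch (2D input): every row is normalized independently.
local batch = torch.randn(4, 5)
local out = m:forward(batch)
print(out:norm(2, 2))     -- each row's norm is ~1

-- out aliases the module's output buffer and is overwritten by the next
-- :forward() call; clone it if the values must be kept around.
local kept = out:clone()

-- Backward pass with an arbitrary upstream gradient.
local gradOutput = torch.randn(4, 5)
local gradInput = m:backward(batch, gradOutput)
```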