
github.com/torch/nn.git
author    Soumith Chintala <soumith@gmail.com>  2015-10-20 18:35:26 +0300
committer Soumith Chintala <soumith@gmail.com>  2015-10-20 18:35:26 +0300
commit    f4be0fd1b0daa8680fd8215921a970c20dc11f32 (patch)
tree      90e81c7789b3656b17e2d8c5cd410aff1c7f648b /Normalize.lua
parent    d3da8bad4dd7a28191fd2542392235276187a20e (diff)

Revert "Fix modification of output in nn.Normalize"

Diffstat (limited to 'Normalize.lua'):
 Normalize.lua | 35 ++++++++++++++++++++++---------------
 1 file changed, 20 insertions(+), 15 deletions(-)
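The revert removes the private self._output / self._gradInput buffers added by the reverted commit and goes back to writing into self.output and self.gradInput directly; non-batch (1-d) inputs are now tracked with an is_batch flag, and self.output[1] / self.gradInput[1] are returned in that case. For reference, the forward pass of nn.Normalize divides each input vector by its Lp norm; below is a minimal standalone sketch of that math in Lua/Torch (the lpNormalize helper is illustrative only, not part of the module).

-- Minimal sketch of the forward math (illustrative helper, not module code):
--   y = x / (sum(|x|^p) + eps)^(1/p)
require 'torch'

local function lpNormalize(x, p, eps)
   p = p or 2
   eps = eps or 1e-10                           -- same default as the module
   local normp = torch.abs(x):pow(p):sum() + eps
   return torch.div(x, math.pow(normp, 1/p))
end

local x = torch.Tensor{3, 4}
print(lpNormalize(x, 2))                        -- approximately {0.6, 0.8}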
diff --git a/Normalize.lua b/Normalize.lua
index d3a17aa..41a8ef2 100644
--- a/Normalize.lua
+++ b/Normalize.lua
@@ -6,18 +6,17 @@ function Normalize:__init(p,eps)
    assert(p > 0, p..'-norm not supported')
    self.p = p
    self.eps = eps or 1e-10
-   self._output = torch.Tensor()
-   self._gradInput = torch.Tensor()
 end

 function Normalize:updateOutput(input)
    assert(input:dim() <= 2, 'only 1d layer supported')
-   local input_size = input:size()
+   local is_batch = true
    if input:dim() == 1 then
       input = input:view(1,-1)
+      is_batch = false
    end

-   self._output:resizeAs(input)
+   self.output:resizeAs(input)

    self.norm = self.norm or input.new()
    self.normp = self.normp or input.new()
@@ -30,9 +29,11 @@ function Normalize:updateOutput(input)
    end
    self.normp:sum(self.buffer,2):add(self.eps)
    self.norm:pow(self.normp,1/self.p)
-   self._output:cdiv(input, self.norm:view(-1,1):expandAs(input))
+   self.output:cdiv(input,self.norm:view(-1,1):expandAs(self.output))

-   self.output = self._output:view(input_size)
+   if not is_batch then
+      self.output = self.output[1]
+   end

    return self.output
 end
@@ -40,18 +41,19 @@ function Normalize:updateGradInput(input, gradOutput)
    assert(input:dim() <= 2, 'only 1d layer supported')
    assert(gradOutput:dim() <= 2, 'only 1d layer supported')

-   local input_size = input:size()
+   local is_batch = true
    if input:dim() == 1 then
       input = input:view(1,-1)
+      is_batch = false
    end

    local n = input:size(1) -- batch size
    local d = input:size(2) -- dimensionality of vectors

    -- compute diagonal term with gradOutput
-   self._gradInput:resize(n,d,1)
+   self.gradInput:resize(n,d,1)
    gradOutput = gradOutput:view(n,d,1)
-   self._gradInput:cmul(self.normp:view(n,1,1):expand(n,d,1), gradOutput)
+   self.gradInput:cmul(self.normp:view(n,1,1):expand(n,d,1),gradOutput)

    -- compute cross term in two steps
    self.cross = self.cross or input.new()
@@ -63,16 +65,19 @@ function Normalize:updateGradInput(input, gradOutput)
    -- instead of having a huge temporary matrix (b1*b2),
    -- do the computations as b1*(b2*gradOutput). This avoids redundant
    -- computation and also a huge buffer of size n*d^2
-   self.cross:bmm(b2, gradOutput)
-   self._gradInput:baddbmm(-1, b1, self.cross)
+   self.cross:bmm(b2,gradOutput)
+   self.gradInput:baddbmm(-1,b1, self.cross)

    -- reuse cross buffer for normalization
-   self.cross:cmul(self.normp, self.norm)
-   self._gradInput:cdiv(self.cross:view(n,1,1):expand(n,d,1))
+   self.cross:cmul(self.normp,self.norm)
+   self.gradInput:cdiv(self.cross:view(n,1,1):expand(n,d,1))

-   self._gradInput = self._gradInput:view(n,d)
+   self.gradInput = self.gradInput:view(n,d)

-   self.gradInput = self._gradInput:view(input_size)
+   if not is_batch then
+      self.gradInput = self.gradInput[1]
+   end
+
    return self.gradInput
 end
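
A hedged usage sketch of the module after this revert (assumes a standard torch/nn install; none of this is part of the commit). Because updateOutput now fills self.output in place and returns self.output[1] for 1-d input, the returned tensor is the module's own output buffer rather than a separate copy, which is what the reverted fix had avoided.

-- Hedged usage sketch, assuming a standard torch/nn install.
require 'nn'

local m = nn.Normalize(2)

local x = torch.randn(5)             -- single 1-d vector
local y = m:forward(x)               -- y is m.output itself after this revert
print(torch.norm(y))                 -- close to 1

local xb = torch.randn(3, 5)         -- batch of 3 vectors
local yb = m:forward(xb)
print(yb:norm(2, 2))                 -- each row close to unit L2 norm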