Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Lerer <alerer@fb.com>2015-12-24 00:26:26 +0300
committerAdam Lerer <alerer@fb.com>2015-12-24 00:26:26 +0300
commit541cf6140409f3f63c6567ca8ed33779385413d2 (patch)
tree8597828dae0c8701512f664ff660301d660dc052 /Normalize.lua
parentefb781da04eee6a6c90dc7daf4b287c4ba8382a1 (diff)
Remove bmm and baddbmm from Normalize, because they allocate memory, causing sync on CUDA.
Diffstat (limited to 'Normalize.lua')
-rw-r--r--Normalize.lua23
1 file changed, 10 insertions, 13 deletions
diff --git a/Normalize.lua b/Normalize.lua
index 93149a8..8e9a111 100644
--- a/Normalize.lua
+++ b/Normalize.lua
@@ -61,9 +61,7 @@ function Normalize:updateGradInput(input, gradOutput)
self._gradInput = self._gradInput or input.new()
self.cross = self.cross or input.new()
-- compute diagonal term with gradOutput
- self._gradInput:resize(n,d,1)
- gradOutput = gradOutput:view(n,d,1)
-
+ self._gradInput:resize(n,d)
if self.p == math.huge then
-- specialization for the inf case
self._gradInput:cmul(self.norm:view(n,1,1):expand(n,d,1),gradOutput)
@@ -73,7 +71,7 @@ function Normalize:updateGradInput(input, gradOutput)
self.cross:cdiv(self.norm)
self.buffer:scatter(2,self._indices,self.cross)
else
- self._gradInput:cmul(self.normp:view(n,1,1):expand(n,d,1), gradOutput)
+ self._gradInput:cmul(self.normp:view(n,1):expand(n,d), gradOutput)
-- small optimizations for different p
-- buffer = input*|input|^(p-2)
if self.p % 2 ~= 0 then
@@ -92,17 +90,18 @@ function Normalize:updateGradInput(input, gradOutput)
self.buffer:pow(input,self.p-2):cmul(input)
end
end
-
-- compute cross term in two steps
- self.cross:resize(n,1,1)
+ self.cross:resize(n,1)
- local b1 = self.buffer:view(n,d,1)
- local b2 = input:view(n,1,d)
-- instead of having a huge temporary matrix (b1*b2),
-- do the computations as b1*(b2*gradOutput). This avoids redundant
-- computation and also a huge buffer of size n*d^2
- self.cross:bmm(b2, gradOutput)
- self._gradInput:baddbmm(-1, b1, self.cross)
+ self.buffer2 = self.buffer2 or input.new() -- nxd
+ self.buffer2:cmul(input, gradOutput)
+ self.cross:sum(self.buffer2, 2)
+
+ self.buffer:cmul(self.cross:expandAs(self.buffer))
+ self._gradInput:add(-1, self.buffer)
-- reuse cross buffer for normalization
if self.p == math.huge then
@@ -110,10 +109,8 @@ function Normalize:updateGradInput(input, gradOutput)
else
self.cross:cmul(self.normp,self.norm)
end
- self._gradInput:cdiv(self.cross:view(n,1,1):expand(n,d,1))
+ self._gradInput:cdiv(self.cross:expand(n,d))
- self._gradInput = self._gradInput:view(n,d)
-
self.gradInput = self._gradInput:view(input_size)
return self.gradInput
end