diff options
author | Adam Lerer <alerer@fb.com> | 2015-12-24 00:26:26 +0300 |
---|---|---|
committer | Adam Lerer <alerer@fb.com> | 2015-12-24 00:26:26 +0300 |
commit | 541cf6140409f3f63c6567ca8ed33779385413d2 (patch) | |
tree | 8597828dae0c8701512f664ff660301d660dc052 /Normalize.lua | |
parent | efb781da04eee6a6c90dc7daf4b287c4ba8382a1 (diff) |
Remove bmm and baddbmm from Normalize, because they allocate memory, causing sync on CUDA.
Diffstat (limited to 'Normalize.lua')
-rw-r--r-- | Normalize.lua | 23 |
1 file changed, 10 insertions, 13 deletions
diff --git a/Normalize.lua b/Normalize.lua index 93149a8..8e9a111 100644 --- a/Normalize.lua +++ b/Normalize.lua @@ -61,9 +61,7 @@ function Normalize:updateGradInput(input, gradOutput) self._gradInput = self._gradInput or input.new() self.cross = self.cross or input.new() -- compute diagonal term with gradOutput - self._gradInput:resize(n,d,1) - gradOutput = gradOutput:view(n,d,1) - + self._gradInput:resize(n,d) if self.p == math.huge then -- specialization for the inf case self._gradInput:cmul(self.norm:view(n,1,1):expand(n,d,1),gradOutput) @@ -73,7 +71,7 @@ function Normalize:updateGradInput(input, gradOutput) self.cross:cdiv(self.norm) self.buffer:scatter(2,self._indices,self.cross) else - self._gradInput:cmul(self.normp:view(n,1,1):expand(n,d,1), gradOutput) + self._gradInput:cmul(self.normp:view(n,1):expand(n,d), gradOutput) -- small optimizations for different p -- buffer = input*|input|^(p-2) if self.p % 2 ~= 0 then @@ -92,17 +90,18 @@ function Normalize:updateGradInput(input, gradOutput) self.buffer:pow(input,self.p-2):cmul(input) end end - -- compute cross term in two steps - self.cross:resize(n,1,1) + self.cross:resize(n,1) - local b1 = self.buffer:view(n,d,1) - local b2 = input:view(n,1,d) -- instead of having a huge temporary matrix (b1*b2), -- do the computations as b1*(b2*gradOutput). 
This avoids redundant -- computation and also a huge buffer of size n*d^2 - self.cross:bmm(b2, gradOutput) - self._gradInput:baddbmm(-1, b1, self.cross) + self.buffer2 = self.buffer2 or input.new() -- nxd + self.buffer2:cmul(input, gradOutput) + self.cross:sum(self.buffer2, 2) + + self.buffer:cmul(self.cross:expandAs(self.buffer)) + self._gradInput:add(-1, self.buffer) -- reuse cross buffer for normalization if self.p == math.huge then @@ -110,10 +109,8 @@ function Normalize:updateGradInput(input, gradOutput) else self.cross:cmul(self.normp,self.norm) end - self._gradInput:cdiv(self.cross:view(n,1,1):expand(n,d,1)) + self._gradInput:cdiv(self.cross:expand(n,d)) - self._gradInput = self._gradInput:view(n,d) - self.gradInput = self._gradInput:view(input_size) return self.gradInput end |