author | Zeming Lin <misterabc@devgpu029.prn2.facebook.com> | 2016-03-08 00:12:37 +0300
committer | Zeming Lin <misterabc@devgpu013.ash5.facebook.com> | 2016-03-12 00:06:21 +0300
commit | 497e8cdc261d1ca9708722dcf2d723ed5f9f1e14 (patch)
tree | e2fdacf89df4977e03fa82780ef64d09de599e8c /SparseLinear.lua
parent | 3656a5762054897eeea52922d24b9f3c08c27819 (diff)
Adding table input support for batched SparseLinear, implementing gradInput correctly, fixing other bugs
Diffstat (limited to 'SparseLinear.lua')
-rw-r--r-- | SparseLinear.lua | 196
1 file changed, 137 insertions, 59 deletions
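For context, after this change `nn.SparseLinear` accepts a table of per-sample sparse tensors instead of only a padded 2D/3D tensor, and a new third constructor argument `doGradInput` opts into computing `gradInput`. A minimal usage sketch of the new interface; the sizes and values here are illustrative, not taken from the commit:

```lua
require 'nn'

-- 1000 sparse input features, 5 outputs; doGradInput = true enables
-- the gradInput computation added in this commit.
local layer = nn.SparseLinear(1000, 5, true)

-- Batched table input: one nnz x 2 tensor of {featureIndex, value}
-- pairs per sample; samples may have different numbers of non-zeros.
local batch = {
   torch.Tensor{{1, 0.5}, {42, 1.0}},
   torch.Tensor{{7, 2.0}, {1, 0.1}, {999, 3.0}},
}

local out = layer:forward(batch)  -- 2x5 dense output, one row per sample
```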
```diff
diff --git a/SparseLinear.lua b/SparseLinear.lua
index 77ef7c3..6185df9 100644
--- a/SparseLinear.lua
+++ b/SparseLinear.lua
@@ -1,19 +1,25 @@
 local THNN = require 'nn.THNN'
 local SparseLinear, parent = torch.class('nn.SparseLinear', 'nn.Module')
 
-function SparseLinear:__init(inputSize, outputSize)
+local NO_LAST_INPUT = 0
+local ONE_LAST_INPUT = 1
+local ACC_MULTIPLE_TIMES = 2
+
+function SparseLinear:__init(inputSize, outputSize, doGradInput)
    parent.__init(self)
 
    self.weightDecay = 0
+   self.doGradInput = doGradInput or false
    self.weight = torch.Tensor(outputSize, inputSize):zero()
    self.bias = torch.Tensor(outputSize):zero()
    self.gradWeight = torch.Tensor(outputSize, inputSize):zero()
    self.gradBias = torch.Tensor(outputSize):zero()
-   self.lastInput = nil
 
-   if torch.getnumthreads() > 1 and outputSize >= 128 then
-      self.shardBuffer = torch.Tensor(outputSize, torch.getnumthreads())
-   end
+   assert(type(self.doGradInput) == type(true))
+
+   self.lastInput = nil
+   self.sparseUpdate = NO_LAST_INPUT
+   self.formatted_input = nil
 
    -- state
    self.gradInput:resize(inputSize)
@@ -33,78 +39,148 @@ function SparseLinear:reset(stdv)
 end
 
 function SparseLinear:reshapeInput(input)
-   if input:dim() == 2 then
-      return input:view(1, input:size(1), input:size(2)), false
+   if type(input) == 'table' then
+      return input, true, false
    else
-      return input, true
+      if input:dim() == 2 then
+         return {input}, false, false
+      else
+         return input, true, true
+      end
    end
 end
 
 function SparseLinear:updateOutput(input)
-   self.cudaBuffer = self.cudaBuffer or input.new()
-   local input, batchMode = self:reshapeInput(input)
-
-   input.THNN.SparseLinear_updateOutput(
-      input:cdata(),
-      self.output:cdata(),
-      self.weight:cdata(),
-      self.bias:cdata(),
-      self.cudaBuffer:cdata(),
-      THNN.optionalTensor(self.shardBuffer)
-   )
-
-   -- fix output size for batchSize = 1
-   if not batchMode then
-      self.output:set(self.output:view(self.output:size(2)))
-   end
+   local input, batchMode, legacyMode = self:reshapeInput(input)
+   self.legacyMode = legacyMode
 
-   return self.output
-end
+   if legacyMode then
+      input.THNN.SparseLinear_legacyUpdateOutput(
+         input:cdata(),
+         self.output:cdata(),
+         self.weight:cdata(),
+         self.bias:cdata()
+      )
+   else
+      local nbatches = #input
+      if nbatches == 0 then
+         self.output:copy(self.bias)
+         return self.output
+      end
 
-function SparseLinear:accGradParameters(input, gradOutput, scale)
-   local input, batchMode = self:reshapeInput(input)
+      local size = 0
+      local marker = 1
+      self.formatted_input = self.formatted_input or input[1].new()
+
+      for i,v in ipairs(input) do size = size + input[i]:size(1) end
+      self.formatted_input:resize(size, 3)
+      for i,v in ipairs(input) do
+         local buf = self.formatted_input:narrow(1, marker, input[i]:size(1))
+         buf:narrow(2,2,2):copy(input[i])
+         buf:select(2,1):fill(i)
+         marker = marker + input[i]:size(1)
+      end
 
-   self.lastInput = self.lastInput or input.new()
-   self.lastInput:resizeAs(input):copy(input)
-   if not batchMode then
-      gradOutput = gradOutput:view(1, gradOutput:size(1))
+      self.output:resize(nbatches, self.weight:size(1))
+      input[1].THNN.SparseLinear_updateOutput(
+         self.formatted_input:cdata(),
+         self.output:cdata(),
+         self.weight:cdata(),
+         self.bias:cdata()
+      )
+
+      -- fix output size for batchSize = 1
+      if not batchMode then
+         self.output = self.output[1]
+      end
    end
 
-   input.THNN.SparseLinear_accGradParameters(
-      input:cdata(),
-      gradOutput:cdata(),
-      self.gradWeight:cdata(),
-      self.gradBias:cdata(),
-      self.weight:cdata(),
-      self.bias:cdata(),
-      self.weightDecay or 0,
-      scale or 1
-   )
+   return self.output
 end
 
-function SparseLinear:updateGradInput(input, gradOutput)
-   if self.gradInput then
-      local input, batchMode = self:reshapeInput(input)
-      if not batchMode then
-         gradOutput = gradOutput:view(1, gradOutput:size(1))
+function SparseLinear:accGradParameters(input, gradOutput, scale)
+   local input, batchMode, legacyMode = self:reshapeInput(input)
+   self.legacyMode = legacyMode
+
+   if legacyMode then
+      self.lastInput = self.lastInput or input.new()
+      if self.sparseUpdate == NO_LAST_INPUT then
+         self.lastInput:resizeAs(input):copy(input)
+         self.sparseUpdate = ONE_LAST_INPUT
+      elseif self.sparseUpdate == ONE_LAST_INPUT then
+         self.sparseUpdate = ACC_MULTIPLE_TIMES
       end
-      input.THNN.SparseLinear_updateGradInput(
+
+      input.THNN.SparseLinear_legacyAccGradParameters(
         input:cdata(),
         gradOutput:cdata(),
-         self.gradInput:cdata(),
-         self.weight:cdata()
+         self.gradWeight:cdata(),
+         self.gradBias:cdata(),
+         self.weight:cdata(),
+         self.bias:cdata(),
+         self.weightDecay or 0,
+         scale or 1
      )
-
-      -- fix gradInput size for batchSize = 1
+   else
      if not batchMode then
-         self.gradInput:set(self.gradInput:view(self.gradInput:size(2), self.gradInput:size(3)))
+         gradOutput:resize(1, gradOutput:size(1))
      end
-      return self.gradInput
+      input[1].THNN.SparseLinear_accGradParameters(
+         self.formatted_input:cdata(),
+         gradOutput:cdata(),
+         self.gradWeight:cdata(),
+         self.gradBias:cdata(),
+         self.weight:cdata(),
+         self.bias:cdata(),
+         self.weightDecay or 0,
+         scale or 1
+      )
   end
 end
 
+function SparseLinear:updateGradInput(input, gradOutput)
+   if self.legacyMode then
+      if type(self.gradInput) ~= type(gradOutput) then self.gradInput = gradOutput.new() end
+      self.gradInput:resizeAs(input)
+   else
+      self.gradInput = {}
+   end
+   if self.doGradInput then
+      -- GradInput should be dense anyway
+      local gi
+      local batchMode = true
+      if gradOutput:dim() == 1 then
+         gi = self.weight:t()*gradOutput
+         batchMode = false
+      elseif gradOutput:dim() == 2 then
+         gi = gradOutput*self.weight
+      end
+      local ini = self.weight:size(2)
+
+      if self.legacyMode then
+         local batches = self.gradInput:size(1)
+         self.gradInput:resize(batches, ini, 2)
+         self.gradInput:select(3,1):copy(torch.repeatTensor(torch.range(1, ini), batches, 1))
+         self.gradInput:select(3,2):copy(gi)
+      else
+         indicies = torch.range(1, ini)
+         if not batchMode then gi:resize(1, ini) end
+         for i = 1,gi:size(1) do
+            self.gradInput[i] = gradOutput.new(ini, 2)
+            self.gradInput[i]:select(2, 2):copy(gi[i])
+            self.gradInput[i]:select(2, 1):range(1, ini)
+         end
+      end
+   end
+   return self.gradInput
+end
+
+-- These functions do sparse updates / zeros. However, if we accumulated
+-- gradients multiple times, we can't depend on the last input to do sparse
+-- updates.
 function SparseLinear:updateParameters(learningRate)
-   if self.lastInput then
+   if self.lastInput and self.legacyMode and self.sparseUpdate == ONE_LAST_INPUT then
      self.lastInput.THNN.SparseLinear_updateParameters(
         self.weight:cdata(),
         self.bias:cdata(),
@@ -116,22 +192,24 @@ function SparseLinear:updateParameters(learningRate)
   else
      parent.updateParameters(self, learningRate)
   end
+   self.sparseUpdate = 0
 end
 
 function SparseLinear:zeroGradParameters()
-   if self.lastInput then
+   if self.lastInput and self.legacyMode and self.sparseUpdate == ONE_LAST_INPUT then
      self.lastInput.THNN.SparseLinear_zeroGradParameters(
-        self.gradWeight:cdata(),
-        self.gradBias:cdata(),
-        self.lastInput:cdata()
+         self.gradWeight:cdata(),
+         self.gradBias:cdata(),
+         self.lastInput:cdata()
      )
   else
      parent.zeroGradParameters(self)
   end
+   self.sparseUpdate = 0
 end
 
 function SparseLinear:clearState()
   if self.lastInput then self.lastInput:set() end
-  if self.cudaBuffer then self.cudaBuffer:set() end
+  input.THNN.SparseLinear_cudaClearState()
   return parent.clearState(self)
 end
```
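Internally, the new `updateOutput` flattens the table input into a single `formatted_input` buffer of shape `totalNNZ x 3`, one `{sampleIndex, featureIndex, value}` row per non-zero, before handing it to the THNN kernel. A standalone sketch of that packing step; the function name `packSparseBatch` is mine, not from the commit:

```lua
-- Pack a table of nnz_i x 2 sparse samples into one (sum of nnz_i) x 3
-- tensor, mirroring the formatted_input construction in updateOutput.
local function packSparseBatch(input)
   local size = 0
   for i = 1, #input do size = size + input[i]:size(1) end
   local formatted = input[1].new(size, 3)
   local marker = 1
   for i = 1, #input do
      local n = input[i]:size(1)
      local buf = formatted:narrow(1, marker, n)
      buf:select(2, 1):fill(i)           -- column 1: sample index
      buf:narrow(2, 2, 2):copy(input[i]) -- columns 2-3: {featureIndex, value}
      marker = marker + n
   end
   return formatted
end

-- packSparseBatch(batch) on the two-sample batch above yields a 5x3 tensor.
```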
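When `doGradInput` is set, `updateGradInput` computes the input gradient densely: for a batched `gradOutput` of size `nbatches x outputSize`, `gi = gradOutput * weight` gives `nbatches x inputSize`, and each row is then repackaged into the module's `{featureIndex, value}` pair format, pairing every index `1..inputSize` with its gradient (sparse in format only). Roughly, for the non-legacy path, continuing the sketch above:

```lua
-- dL/dX = dL/dY * W, with W of size outputSize x inputSize.
local gradOut = torch.randn(2, 5)
local gi = gradOut * layer.weight  -- 2 x 1000 dense input gradient

-- Repackage each sample as inputSize x 2 {featureIndex, gradient} rows.
local gradInput = {}
for i = 1, gi:size(1) do
   gradInput[i] = torch.Tensor(gi:size(2), 2)
   gradInput[i]:select(2, 1):range(1, gi:size(2)) -- column 1: indices 1..inputSize
   gradInput[i]:select(2, 2):copy(gi[i])          -- column 2: dense gradients
end
```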
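The `sparseUpdate` flag enforces the caveat in the comment above `updateParameters`: the THNN sparse fast paths are only safe when gradients were accumulated exactly once since the last zero, and, as the guards show, only in legacy tensor-input mode, where `lastInput` is recorded. An illustrative legacy-mode training step, assuming the `layer` from the first sketch:

```lua
-- Legacy input: a 3D nbatches x nnz x 2 tensor (same nnz per sample).
local legacyBatch = torch.Tensor{
   {{1, 0.5}, {42, 1.0}},
   {{7, 2.0}, {99, 0.3}},
}
local gradOut = torch.randn(2, 5)

layer:zeroGradParameters()           -- resets sparseUpdate
layer:forward(legacyBatch)
layer:backward(legacyBatch, gradOut) -- sparseUpdate: NO_LAST_INPUT -> ONE_LAST_INPUT
layer:updateParameters(0.01)         -- sparse path touches only rows in lastInput

-- A second backward() before the update would flip sparseUpdate to
-- ACC_MULTIPLE_TIMES, so updateParameters and zeroGradParameters would
-- fall back to the dense parent implementations.
```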