diff options
author | soumith <soumith@fb.com> | 2015-01-03 07:42:12 +0300 |
---|---|---|
committer | soumith <soumith@fb.com> | 2015-01-03 07:42:12 +0300 |
commit | a38407a57def785acc819066db70f1649da47f03 (patch) | |
tree | 9aaa885fb28188a7c17fca6bcfe9e527f3930904 /SparseLinear.lua | |
parent | 2340b9c068b518cdc20b0c6c1a9b68971f0e97e8 (diff) |
speedup and optimizations for SparseLinear
Diffstat (limited to 'SparseLinear.lua')
-rw-r--r-- | SparseLinear.lua | 37 |
1 file changed, 19 insertions, 18 deletions
diff --git a/SparseLinear.lua b/SparseLinear.lua index 735d0ed..ca15be6 100644 --- a/SparseLinear.lua +++ b/SparseLinear.lua @@ -4,11 +4,16 @@ function SparseLinear:__init(inputSize, outputSize) parent.__init(self) self.weightDecay = 0 - self.weight = torch.Tensor(outputSize, inputSize) - self.bias = torch.Tensor(outputSize) - self.gradWeight = torch.Tensor(outputSize, inputSize) - self.gradBias = torch.Tensor(outputSize) - self.lastInput = torch.Tensor() + self.weight = torch.Tensor(outputSize, inputSize):zero() + self.bias = torch.Tensor(outputSize):zero() + self.gradWeight = torch.Tensor(outputSize, inputSize):zero() + self.gradBias = torch.Tensor(outputSize):zero() + self.lastInput = nil + + if torch.getnumthreads() > 1 and outputSize >= 128 then + self.shardBuffer = torch.Tensor(outputSize, torch.getnumthreads()) + end + -- state self.gradInput:resize(inputSize) self.output:resize(outputSize) @@ -20,7 +25,7 @@ function SparseLinear:reset(stdv) if stdv then stdv = stdv * math.sqrt(3) else - stdv = 1./math.sqrt(self.weight:size(1)) + stdv = 1./math.sqrt(self.weight:size(2)) end if nn.oldSeed then for i=1,self.weight:size(1) do @@ -40,22 +45,18 @@ function SparseLinear:updateOutput(input) end function SparseLinear:accGradParameters(input, gradOutput, scale) + if not self.lastInput then + self.lastInput = input:clone() + else + self.lastInput:resizeAs(input):copy(input) + end + return input.nn.SparseLinear_accGradParameters(self, input, gradOutput, scale) end function SparseLinear:updateGradInput(input, gradOutput) if self.gradInput then - self.gradInput:resize(input:size()) - self.gradInput:copy(input) - local numNonzero = self.gradInput:size(1) - for e=1,numNonzero do - local g = 0 - local i = self.gradInput[{e,1}] - for j=1,self.output:size(1) do - g = g + self.weight[{j,i}] * gradOutput[j] - end - self.gradInput[{e,2}] = g - end + input.nn.SparseLinear_updateGradInput(self, input, gradOutput) return self.gradInput end -end
\ No newline at end of file +end |