github.com/torch/nn.git
author    soumith <soumith@fb.com>  2015-01-03 07:42:12 +0300
committer soumith <soumith@fb.com>  2015-01-03 07:42:12 +0300
commit    a38407a57def785acc819066db70f1649da47f03 (patch)
tree      9aaa885fb28188a7c17fca6bcfe9e527f3930904 /SparseLinear.lua
parent    2340b9c068b518cdc20b0c6c1a9b68971f0e97e8 (diff)
speedup and optimizations for SparseLinear
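
The changes below do four things: parameter tensors are zero-initialized on construction, lastInput is cached lazily inside accGradParameters, a per-thread shardBuffer is allocated when Torch runs with multiple threads and the layer is wide enough, and the inner Lua loop of updateGradInput moves into the C backend. For orientation, here is a minimal usage sketch of the module (not part of the commit); the sparse input convention of {featureIndex, featureValue} rows is the one the removed Lua loop reads via input[{e,1}] and input[{e,2}]:

-- Minimal usage sketch for nn.SparseLinear (assumes torch7 + nn).
-- A sparse vector is encoded as an nnz x 2 tensor: column 1 holds
-- feature indices, column 2 the corresponding feature values.
require 'nn'

local inputSize, outputSize = 10000, 256
local module = nn.SparseLinear(inputSize, outputSize)

-- three active features out of 10000
local input = torch.Tensor{{42, 1.0}, {1337, 0.5}, {9001, -2.0}}

local output = module:forward(input)   -- dense vector of size outputSize
local gradOutput = torch.randn(outputSize)
module:zeroGradParameters()
module:backward(input, gradOutput)     -- caches input as self.lastInput
module:updateParameters(0.01)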
Diffstat (limited to 'SparseLinear.lua')
-rw-r--r--  SparseLinear.lua  37
1 file changed, 19 insertions(+), 18 deletions(-)
diff --git a/SparseLinear.lua b/SparseLinear.lua
index 735d0ed..ca15be6 100644
--- a/SparseLinear.lua
+++ b/SparseLinear.lua
@@ -4,11 +4,16 @@ function SparseLinear:__init(inputSize, outputSize)
    parent.__init(self)
 
    self.weightDecay = 0
-   self.weight = torch.Tensor(outputSize, inputSize)
-   self.bias = torch.Tensor(outputSize)
-   self.gradWeight = torch.Tensor(outputSize, inputSize)
-   self.gradBias = torch.Tensor(outputSize)
-   self.lastInput = torch.Tensor()
+   self.weight = torch.Tensor(outputSize, inputSize):zero()
+   self.bias = torch.Tensor(outputSize):zero()
+   self.gradWeight = torch.Tensor(outputSize, inputSize):zero()
+   self.gradBias = torch.Tensor(outputSize):zero()
+   self.lastInput = nil
+
+   if torch.getnumthreads() > 1 and outputSize >= 128 then
+      self.shardBuffer = torch.Tensor(outputSize, torch.getnumthreads())
+   end
+
    -- state
    self.gradInput:resize(inputSize)
    self.output:resize(outputSize)
@@ -20,7 +25,7 @@ function SparseLinear:reset(stdv)
    if stdv then
       stdv = stdv * math.sqrt(3)
    else
-      stdv = 1./math.sqrt(self.weight:size(1))
+      stdv = 1./math.sqrt(self.weight:size(2))
    end
    if nn.oldSeed then
       for i=1,self.weight:size(1) do
@@ -40,22 +45,18 @@ function SparseLinear:updateOutput(input)
 end
 
 function SparseLinear:accGradParameters(input, gradOutput, scale)
+   if not self.lastInput then
+      self.lastInput = input:clone()
+   else
+      self.lastInput:resizeAs(input):copy(input)
+   end
+
    return input.nn.SparseLinear_accGradParameters(self, input, gradOutput, scale)
 end
 
 function SparseLinear:updateGradInput(input, gradOutput)
    if self.gradInput then
-      self.gradInput:resize(input:size())
-      self.gradInput:copy(input)
-      local numNonzero = self.gradInput:size(1)
-      for e=1,numNonzero do
-         local g = 0
-         local i = self.gradInput[{e,1}]
-         for j=1,self.output:size(1) do
-            g = g + self.weight[{j,i}] * gradOutput[j]
-         end
-         self.gradInput[{e,2}] = g
-      end
+      input.nn.SparseLinear_updateGradInput(self, input, gradOutput)
       return self.gradInput
    end
-end
\ No newline at end of file
+end
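
Two remarks on the hunks above, beyond what the commit message says. The reset fix divides by self.weight:size(2) (the fan-in, inputSize) rather than size(1) (the fan-out), matching the usual 1/sqrt(fanIn) initialization that nn.Linear uses. The new shardBuffer is shaped outputSize x numThreads, which suggests that in the C kernel each thread accumulates its partial output into its own column, with a reduction over columns at the end; the sketch below illustrates that reduction pattern in Lua, but the helper name and placement are hypothetical, since the real work happens inside the C backend reached through input.nn.SparseLinear_updateOutput.

-- Hypothetical sketch of the reduction implied by shardBuffer's
-- outputSize x numThreads shape; the actual implementation is in C.
local function reduceShards(output, shardBuffer)
   -- column t holds thread t's partial sums; collapse across threads.
   -- sum(2) yields outputSize x 1; copy() only requires equal nElement.
   output:copy(shardBuffer:sum(2))
end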