Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'SparseLinear.lua')
-rw-r--r--SparseLinear.lua17
1 files changed, 17 insertions, 0 deletions
diff --git a/SparseLinear.lua b/SparseLinear.lua
index f1a2be5..735d0ed 100644
--- a/SparseLinear.lua
+++ b/SparseLinear.lua
@@ -42,3 +42,20 @@ end
function SparseLinear:accGradParameters(input, gradOutput, scale)
return input.nn.SparseLinear_accGradParameters(self, input, gradOutput, scale)
end
+
+function SparseLinear:updateGradInput(input, gradOutput)
+ if self.gradInput then
+ self.gradInput:resize(input:size())
+ self.gradInput:copy(input)
+ local numNonzero = self.gradInput:size(1)
+ for e=1,numNonzero do
+ local g = 0
+ local i = self.gradInput[{e,1}]
+ for j=1,self.output:size(1) do
+ g = g + self.weight[{j,i}] * gradOutput[j]
+ end
+ self.gradInput[{e,2}] = g
+ end
+ return self.gradInput
+ end
+end \ No newline at end of file