diff options
Diffstat (limited to 'SparseLinear.lua')
-rw-r--r-- | SparseLinear.lua | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/SparseLinear.lua b/SparseLinear.lua index f1a2be5..735d0ed 100644 --- a/SparseLinear.lua +++ b/SparseLinear.lua @@ -42,3 +42,20 @@ end function SparseLinear:accGradParameters(input, gradOutput, scale) return input.nn.SparseLinear_accGradParameters(self, input, gradOutput, scale) end + +function SparseLinear:updateGradInput(input, gradOutput) + if self.gradInput then + self.gradInput:resize(input:size()) + self.gradInput:copy(input) + local numNonzero = self.gradInput:size(1) + for e=1,numNonzero do + local g = 0 + local i = self.gradInput[{e,1}] + for j=1,self.output:size(1) do + g = g + self.weight[{j,i}] * gradOutput[j] + end + self.gradInput[{e,2}] = g + end + return self.gradInput + end +end
\ No newline at end of file |