blob: d086f2836b50e6507126631e9b0d0dc965b40ca9 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
local SparseCriterion, parent = torch.class('nn.SparseCriterion', 'nn.Criterion')
function SparseCriterion:__init()
parent.__init(self)
self.sizeAverage = true
end
function SparseCriterion:updateOutput(input)
input.nn.SparseCriterion_updateOutput(self, input)
return self.output
end
function SparseCriterion:updateGradInput(input)
input.nn.SparseCriterion_updateGradInput(self, input)
return self.gradInput
end
|