blob: ddaa75cee6b9ada9f3a3cf0833422057d43c85c4 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
|
--- nn.SparseCriterion class declaration.
-- Registers the class under the name "nn.SparseCriterion" with "nn.Criterion"
-- as its parent. The two returned values are the class table (on which the
-- methods below are defined) and the parent class, which those methods use
-- to chain to base-class __init/write/read.
local SparseCriterion, parent = torch.class("nn.SparseCriterion", "nn.Criterion")
--- Constructor.
-- Initializes the base nn.Criterion state, then sets the sizeAverage flag.
-- @param sizeAverage optional flag. When omitted (nil) it defaults to true,
--   which preserves the previous behaviour of `nn.SparseCriterion()`. The value
--   is stored in self.sizeAverage and serialized by write()/read().
--   Presumably it is consulted by the C routines to decide whether the loss
--   is averaged; confirm against the C implementation.
function SparseCriterion:__init(sizeAverage)
   parent.__init(self)
   -- Compare against nil explicitly so that an explicit `false` is honoured.
   -- The `sizeAverage or true` idiom would always yield true.
   if sizeAverage == nil then
      self.sizeAverage = true
   else
      self.sizeAverage = sizeAverage
   end
end
--- Forward pass: evaluate the criterion on `input`.
-- Fetches the routine `SparseCriterion_forward` from `input.nn` and calls it
-- with this criterion and the input. Only `input` is accepted; any extra
-- argument (such as a target) is ignored.
-- @param input object whose `nn` table supplies the implementation
-- @return self.output (presumably filled in by the routine; confirm)
function SparseCriterion:forward(input)
   local updateOutput = input.nn.SparseCriterion_forward
   updateOutput(self, input)
   return self.output
end
--- Backward pass: compute the gradient of the criterion with respect to `input`.
-- Fetches the routine `SparseCriterion_backward` from `input.nn` and calls it
-- with this criterion and the input. Only `input` is accepted; any extra
-- argument is ignored.
-- @param input object whose `nn` table supplies the implementation
-- @return self.gradInput (presumably filled in by the routine; confirm)
function SparseCriterion:backward(input)
   local updateGradInput = input.nn.SparseCriterion_backward
   updateGradInput(self, input)
   return self.gradInput
end
--- Serialize this criterion to `file`.
-- Writes the base nn.Criterion state first, then the sizeAverage flag as a
-- boolean. read() must consume the fields in this same order.
-- @param file object providing writeBool (a Torch file handle, presumably)
function SparseCriterion:write(file)
   parent.write(self, file)
   local size_average = self.sizeAverage
   file:writeBool(size_average)
end
--- Deserialize this criterion from `file`.
-- Mirrors write(): it restores the base nn.Criterion state first, then reads
-- the sizeAverage flag back as a boolean.
-- @param file object providing readBool (a Torch file handle, presumably)
function SparseCriterion:read(file)
   parent.read(self, file)
   local size_average = file:readBool()
   self.sizeAverage = size_average
end
|