Welcome to mirror list, hosted at ThFree Co, Russian Federation.

Threshold.lua - github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: fbd5c544bd87df25bbbc4d7fdb2df5795dd79f1e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
-- Declare nn.Threshold as a subclass of nn.Module via torch.class.
-- `parent` holds the base class; the methods in this file call
-- parent.__init, parent.write and parent.read to run the nn.Module
-- behaviour before their own.
local Threshold, parent = torch.class('nn.Threshold','nn.Module')

--- Construct a Threshold module.
-- Per its usage text: where input < threshold, the output is `value`.
-- @param th optional number, the threshold (defaults to 1e-6)
-- @param v  optional number, the replacement value (defaults to 0)
-- Raises an xlua usage error when either argument is supplied but is not a
-- number.
function Threshold:__init(th,v)
   -- Validate first so a bad call fails before any module state is set up.
   if (th and type(th) ~= 'number') or (v and type(v) ~= 'number') then
      error(xlua.usage('nn.Threshold',
                          'a threshold module, if input < threshold, then output = value',
                          nil,
                          {type='number', help='threshold'},
                          {type='number', help='value'}))
   end
   parent.__init(self)
   self.threshold = th or 1e-6
   self.val = v or 0
end

--- Forward pass.
-- Dispatches to the Threshold_forward routine found on the input's `nn`
-- table (chosen by the input's tensor type). That routine presumably fills
-- self.output, which is what this method returns.
-- @param input tensor to threshold
-- @return self.output
function Threshold:forward(input)
   local backend = input.nn
   backend.Threshold_forward(self, input)
   local result = self.output
   return result
end

--- Backward pass.
-- Dispatches to the Threshold_backward routine found on the input's `nn`
-- table. That routine presumably fills self.gradInput, which is what this
-- method returns.
-- @param input      tensor given to the forward pass
-- @param gradOutput gradient with respect to the module's output
-- @return self.gradInput
function Threshold:backward(input, gradOutput)
   local backend = input.nn
   backend.Threshold_backward(self, input, gradOutput)
   local result = self.gradInput
   return result
end

--- Serialise the module.
-- Writes the base nn.Module state first, then the two parameters as doubles
-- in the fixed order: threshold, val. Threshold:read consumes them in that
-- same order.
-- @param file torch File object to write to
function Threshold:write(file)
   parent.write(self, file)
   for _, field in ipairs({'threshold', 'val'}) do
      file:writeDouble(self[field])
   end
end

--- Deserialise the module.
-- Reads the base nn.Module state first, then two doubles in the fixed order
-- threshold, val. This mirrors the layout produced by Threshold:write.
-- @param file torch File object to read from
function Threshold:read(file)
   parent.read(self, file)
   for _, field in ipairs({'threshold', 'val'}) do
      self[field] = file:readDouble()
   end
end