Welcome to mirror list, hosted at ThFree Co, Russian Federation.

SaturatedLU.lua - github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: f49d119a4caf32e1fb84110caa1a71c87ebddc92 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
local SaturatedLU, parent = torch.class('nn.SaturatedLU','nn.Module')

function SaturatedLU:__init(th,v,th2,v2)
   parent.__init(self)
   self.threshold = th or -1.0
   self.val = v or -1.0
   self.threshold2 = th2 or 1.0
   self.val2 = v2 or 1.0
   if (th and type(th) ~= 'number') or (v and type(v) ~= 'number')
      or (th2 and type(th2) ~= 'number') or (v2 and type(v2) ~= 'number') then
	 error('nn.SaturatedLU(lower-bound, value, upper-bound, value2)')
   end
end

function SaturatedLU:updateOutput(input)
   self.output = input:clone()
   self.output[self.output:lt(self.threshold)] = self.val
   self.output[self.output:gt(self.threshold2)] = self.val2
   return self.output
end

function SaturatedLU:updateGradInput(input, gradOutput)
   self.gradInput = gradOutput:clone()
   self.gradInput[input:lt(self.threshold)] = 0
   self.gradInput[input:gt(self.threshold2)] = 0
   return self.gradInput
end