diff options
author | Soumith Chintala <soumith@gmail.com> | 2014-03-20 20:23:12 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-03-20 20:23:12 +0400 |
commit | 9c511ead8869c69a677ced7b2ea2ca38e2b7934d (patch) | |
tree | cc92bce5570704c8733fac52d8ff55dd07751536 /SaturatedLU.lua | |
parent | 3bdd8dd32db8f4d42a23ef12cd10ec7145c78669 (diff) |
adding a Saturated linear unit (saturate from both sides)
Diffstat (limited to 'SaturatedLU.lua')
-rw-r--r-- | SaturatedLU.lua | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/SaturatedLU.lua b/SaturatedLU.lua new file mode 100644 index 0000000..f49d119 --- /dev/null +++ b/SaturatedLU.lua @@ -0,0 +1,27 @@ +local SaturatedLU, parent = torch.class('nn.SaturatedLU','nn.Module') + +function SaturatedLU:__init(th,v,th2,v2) + parent.__init(self) + self.threshold = th or -1.0 + self.val = v or -1.0 + self.threshold2 = th2 or 1.0 + self.val2 = v2 or 1.0 + if (th and type(th) ~= 'number') or (v and type(v) ~= 'number') + or (th2 and type(th2) ~= 'number') or (v2 and type(v2) ~= 'number') then + error('nn.SaturatedLU(lower-bound, value, upper-bound, value2)') + end +end + +function SaturatedLU:updateOutput(input) + self.output = input:clone() + self.output[self.output:lt(self.threshold)] = self.val + self.output[self.output:gt(self.threshold2)] = self.val2 + return self.output +end + +function SaturatedLU:updateGradInput(input, gradOutput) + self.gradInput = gradOutput:clone() + self.gradInput[input:lt(self.threshold)] = 0 + self.gradInput[input:gt(self.threshold2)] = 0 + return self.gradInput +end
\ No newline at end of file |