Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2014-03-20 20:23:12 +0400
committerSoumith Chintala <soumith@gmail.com>2014-03-20 20:23:12 +0400
commit9c511ead8869c69a677ced7b2ea2ca38e2b7934d (patch)
treecc92bce5570704c8733fac52d8ff55dd07751536 /SaturatedLU.lua
parent3bdd8dd32db8f4d42a23ef12cd10ec7145c78669 (diff)
adding a Saturated linear unit (saturate from both sides)
Diffstat (limited to 'SaturatedLU.lua')
-rw-r--r--SaturatedLU.lua27
1 files changed, 27 insertions, 0 deletions
diff --git a/SaturatedLU.lua b/SaturatedLU.lua
new file mode 100644
index 0000000..f49d119
--- /dev/null
+++ b/SaturatedLU.lua
@@ -0,0 +1,27 @@
+local SaturatedLU, parent = torch.class('nn.SaturatedLU','nn.Module')
+
+function SaturatedLU:__init(th,v,th2,v2)
+ parent.__init(self)
+ self.threshold = th or -1.0
+ self.val = v or -1.0
+ self.threshold2 = th2 or 1.0
+ self.val2 = v2 or 1.0
+ if (th and type(th) ~= 'number') or (v and type(v) ~= 'number')
+ or (th2 and type(th2) ~= 'number') or (v2 and type(v2) ~= 'number') then
+ error('nn.SaturatedLU(lower-bound, value, upper-bound, value2)')
+ end
+end
+
+function SaturatedLU:updateOutput(input)
+ self.output = input:clone()
+ self.output[self.output:lt(self.threshold)] = self.val
+ self.output[self.output:gt(self.threshold2)] = self.val2
+ return self.output
+end
+
+function SaturatedLU:updateGradInput(input, gradOutput)
+ self.gradInput = gradOutput:clone()
+ self.gradInput[input:lt(self.threshold)] = 0
+ self.gradInput[input:gt(self.threshold2)] = 0
+ return self.gradInput
+end \ No newline at end of file